
Commit 002092a

Author: Michael Abbott
Message: two plus
1 parent dd3b0a8 · commit 002092a

File tree: 2 files changed, +112 −74 lines


README.md

Lines changed: 23 additions & 10 deletions
@@ -4,7 +4,7 @@ It would be nice if [Flux](https://github.com/FluxML/Flux.jl) worked with `mapsl
 or with something generalising that. This package has some quick attempts:
 
 ```julia
-mat = rand(1:99, 3,10)
+mat = rand(1:9, 3,10)
 fun(x) = 2 .+ x.^2
 mapslices(fun, mat, dims=1)
 
@@ -31,32 +31,45 @@ mat1k = rand(3,1000);
 
 @btime mapslices(fun, $mat1k, dims=1) # 1.017 ms
 @btime mapcols(fun, $mat1k) # 399.016 μs
-@btime MapCols{3}(fun, $mat1k) # 46.733 μs
-@btime MapCols(fun, $mat1k) # 59.471 μs without size
+@btime MapCols{3}(fun, $mat1k) # 15.564 μs
+@btime MapCols(fun, $mat1k) # 16.774 μs without size
 
 @btime ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), $mat1k); # 372.705 ms
 @btime Tracker.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 70.203 ms
-@btime Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 255.032 μs, 690.09 KiB
+@btime Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 146.561 μs, 330.51 KiB
 @btime Zygote.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 20.018 ms, 3.82 MiB
-@btime Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 354.112 μs
+@btime Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 245.550 μs
 ```
 
 Of course `mapslices()` does things other than columns of matrices.
 Most of which can be done better with `eachslice()` and `reduce(hcat,...)`,
-maybe with some thought one could just write gradients for those.
+maybe with some thought one could just write gradients for those...
 
-Perhaps done. The views of `eachcol()` have quite inefficient gradients,
-but `collecteachcol()` is efficient:
+Perhaps this is done. The views of `eachcol()` have quite inefficient gradients,
+because for each `view()` they make a fresh `zero(A)`, but `collecteachcol()` is efficient:
 
 ```julia
 @btime Zygote.gradient(m -> sum(sin, mapcols4(fun, m)), $mat1k); # 45.616 ms, 49.49 MiB
 @btime Zygote.gradient(m -> sum(sin, mapcols6(fun, m)), $mat1k); # 18.655 ms, 3.37 MiB
 ```
 
-<!--
 Or for the slice/glue functions in [TensorCast](https://github.com/mcabbott/TensorCast.jl),
 which now does some mapslices things (and will soon do many more) by chaining such functions.
--->
+
+```julia
+using TensorCast
+@cast [i,j] := fun(mat[:,j])[i] # same as mapcols
+
+tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i]
+Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1]
+
+@btime tcm($mat1k) # 407.176 μs
+@btime Zygote.gradient(m -> sum(sin, tcm(m)), $mat1k) # 19.086 ms
+
+ten = rand(1:9, 3,10,2)
+@cast zed[i,j,k] := fun(ten[i,:,k])[j]
+Zygote.gradient(m -> sum(sin, @cast zed[i,j,k] := fun(m[i,:,k])[j] nolazy), ten)[1]
+```
 
 Issues about mapslices:
 * https://github.com/FluxML/Zygote.jl/issues/92
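
The `MapCols` timings above rest on the forward-mode trick visible in the source diff below: each length-d column is perturbed by a static vector of `ForwardDiff.Dual` numbers, one partial slot per row, so a single call to `fun` returns both the value and that column's Jacobian. A minimal self-contained sketch of the seeding, outside the package (the names `d`, `col`, `dualcol` are illustrative, mirroring the diff):

```julia
using ForwardDiff, StaticArrays

fun(x) = 2 .+ x.^2  # the README's example function

d = Val(3)          # column length, known at compile time
# One Dual per row, each carrying a unit perturbation in its own slot:
dualcol = SVector(ntuple(j -> ForwardDiff.Dual(0, ntuple(i -> i==j ? 1 : 0, d)...), d))

col = SVector(1.0, 2.0, 3.0)
y = fun(col .+ dualcol)

ForwardDiff.value.(y)     # the plain result fun(col)
ForwardDiff.partials.(y)  # slot i of row r holds ∂y[r]/∂col[i]; here diagonal, 2 .* col
```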

src/SliceMap.jl

Lines changed: 89 additions & 64 deletions
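The central change in this file is that the forward pass now reinterprets the matrix's memory as a vector of static columns instead of reshaping via TensorCast's `@cast`, which avoids a copy. A small sketch of that pattern on its own (the 3×4 `M` and the squaring function are illustrative, not the package's API):

```julia
using StaticArrays

M = rand(3, 4)
# View the same memory as 4 static length-3 columns, without copying:
A = reinterpret(SVector{3, Float64}, vec(M))
B = map(col -> col .^ 2, A)  # each result is again an SVector{3}
reduce(hcat, B)              # glue back into a 3×4 matrix
```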
@@ -1,7 +1,8 @@
 
 module SliceMap
 
-export MapCols, mapcols
+export MapCols, mapcols, maprows
+
 
 #========== Gradient Macro ==========#
 
@@ -26,6 +27,7 @@ function trackergrad(ex)
     MacroTools.@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
 end
 
+
 #========== Reverse, Eachslice ==========#
 
 using WeightedArrays
@@ -38,15 +40,18 @@ All further arguments are scalar constants, i.e. they do not get sliced/iterated
 nor are their gradients tracked.
 """
 mapcols(f::Function, M::AbstractMatrix, args...) =
-    reduce(hcat, [ rvec(f(col, args...)) for col in eachcol(M) ])
+    reduce(hcat, [ surevec(f(col, args...)) for col in eachcol(M) ])
 
 mapcols(f::Function, M::WeightedMatrix, args...) =
     Weighted(mapcols(f, M.array, args...), M.weights, M.opt)
 
+surevec(x::Number) = [x] # to allow f vector -> scalar, as mapslices does
+surevec(A) = vec(A) # to allow f vector -> matrix, by reshaping
+
 mapcols(f::Function, M::TrackedMatrix, args...) = track(mapcols, f, M, args...)
 
-@gradadjoint function mapcols(f::Function, M::AbstractMatrix, args...)
-    res = [ Tracker.forward(x -> rvec(f(x, args...)), col) for col in eachcol(data(M)) ]
+@grad function mapcols(f::Function, M::AbstractMatrix, args...)
+    res = [ Tracker.forward(x -> surevec(f(x, args...)), col) for col in eachcol(data(M)) ]
     fwd = reduce(hcat, data.(first.(res)))
     function back(Δ)
         cols = [ data((last(res[c]))(Δcol)[1]) for (c, Δcol) in enumerate(eachcol(data(Δ))) ]
@@ -56,25 +61,26 @@ mapcols(f::Function, M::TrackedMatrix, args...) = track(mapcols, f, M, args...)
     fwd, back
 end
 
-# @gradadjoint not yet working
-Zygote.@adjoint function mapcols(f::Function, M::Matrix, args...)
-    res = [ Zygote.forward(x -> rvec(f(x, args...)), col) for col in eachcol(data(M)) ]
-    fwd = reduce(hcat, data.(first.(res)))
+@adjoint function mapcols(f::Function, M::Matrix, args...)
+    res = [ Zygote.forward(x -> surevec(f(x, args...)), col) for col in eachcol(M) ]
+    fwd = reduce(hcat, first.(res))
     function back(Δ)
-        cols = [ data((last(res[c]))(Δcol)[1]) for (c, Δcol) in enumerate(eachcol(data(Δ))) ]
+        cols = [ (last(res[c]))(Δcol)[1] for (c, Δcol) in enumerate(eachcol(Δ)) ]
         ∇M = reduce(hcat, cols)
         (nothing, ∇M, map(_->nothing, args)...)
     end
     fwd, back
 end
 
 maprows(f::Function, M::AbstractMatrix, args...) =
-    reduce(vcat, [ tvec(f(col, args...)) for col in eachrow(M) ])
+    reduce(vcat, [ surerow(f(col, args...)) for col in eachrow(M) ])
+
+surerow(x) = transpose(surevec(x))
 
 maprows(f::Function, M::TrackedMatrix, args...) = track(maprows, f, M, args...)
 
-@gradadjoint function maprows(f::Function, M::AbstractMatrix, args...)
-    res = [ Tracker.forward(x -> tvec(f(x, args...)), row) for row in eachrow(data(M)) ]
+@grad function maprows(f::Function, M::AbstractMatrix, args...)
+    res = [ Tracker.forward(x -> surerow(f(x, args...)), row) for row in eachrow(data(M)) ]
    fwd = reduce(vcat, data.(first.(res)))
    function back(Δ)
        rows = [ data((last(res[r]))(Δrow)[1]) for (r, Δrow) in enumerate(eachrow(data(Δ))) ]
@@ -87,7 +93,7 @@ end
 
 #========== Forward, Static ==========#
 
-using TensorCast, StaticArrays, WeightedArrays
+using StaticArrays, ForwardDiff, WeightedArrays
 
 struct MapCols{d} end
 
@@ -106,48 +112,72 @@ Takes `m.weights` along for the ride.
 MapCols(f::Function, M::WeightedArrays.MaybeWeightedMatrix, args...) =
     MapCols{size(M,1)}(f, M, args...)
 
-MapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} = 
+MapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
     Weighted(MapCols{d}(f, M.array, args...), M.weights, M.opt)
 
-function MapCols{d}(f::Function, M::Matrix, args...) where {d}
-    @cast A[c]{r:d} := M[r,c] assert
-    reduce(hcat, [ rvec(f(acol, args...)) for acol in A ])
+MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} = _MapCols(f, M, Val(d), args...)
+
+function _MapCols(f::Function, M::Matrix{T}, ::Val{d}, args...) where {T,d}
+    d == size(M,1) || error("expected M with $d columns")
+    # @cast A[c]{r:d} := M[r,c] assert
+    A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(M))
+    B = map(col -> surevec(f(col, args...)), A)
+    reduce(hcat, B)
+    # maybestaticgluecols(B)
+end
 
-# TODO: call some function which static-glues if possible...
-# TensorCast.auto_glue(map(col -> rvec(f(col, args...)), A), (:,*))
+# surevec(x::MArray) = Array(x) # avoid making a huge MArray, ad
 
-# TODO: can I thread this? Is it even safe to do so?
-# https://github.com/mohamed82008/KissThreading.jl
+function maybestaticgluecols(B)
+    TB = eltype(B)
+    if TB <: SArray
+        C = collect(reshape(reinterpret(eltype(TB), B),:,length(B)))
+    elseif TB <: MArray
+        C = reduce(hcat, Array.(B))
+    else
+        C = reduce(hcat, B)
+    end
 end
 
-rvec(x::Number) = [x] # to allow for f vector -> scalar, as mapslices does
-rvec(x::StaticArray) = vec(Array(x)) # to avoid creating a giant staticarray, as reduce(hcat would otherwise do
-rvec(A) = vec(A) # LinearAlgebra.
+# surevecS(x::Number) = @SVector [x]
+# surevecS(A) = vec(A) # like surevec
+
+_MapCols(f::Function, M::TrackedMatrix, dval, args...) = track(_MapCols, f, M, dval, args...)
 
-tvec(x) = transpose(rvec(x))
+@grad _MapCols(f::Function, M::TrackedMatrix, dval, args...) = ∇MapCols(f, M, dval, args...)
 
-using ForwardDiff
+@adjoint _MapCols(f::Function, M::Matrix, dval, args...) = ∇MapCols(f, M, dval, args...)
 
-MapCols{d}(f::Function, M::TrackedMatrix, args...) where {d} = track(MapCols, f, M, Val(d), args...)
+function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) where {T,d}
 
-@grad function MapCols(f::Function, M::TrackedMatrix, dval::Val{d}, args...) where {d}
+    d == size(M,1) || error("expected M with $d columns")
+    # @cast A[c]{r:d} := data(M)[r,c]
+    A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(data(M)))
 
-    @cast A[c]{r:d} := M.data[r,c]
     dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, dval)...), dval))
 
-    C = [ rvec(f(acol .+ dualcol, args...)) for acol in A ]
+    # C = [ surevec(f(col .+ dualcol, args...)) for col in A ]
+    C = map(col -> surevec(f(col .+ dualcol, args...)), A)
 
-    Z = reduce(hcat, [ ForwardDiff.value.(full) for full in C ]) # full is not an SVector here
+    # Z = reduce(hcat, [ ForwardDiff.value.(full) for full in C ])
+    Z = reduce(hcat, map(col -> ForwardDiff.value.(col), C))
 
     function back(ΔZ)
-        ∇M = similar(data(M)) .+ zero(first(data(ΔZ)))
+        # accum = zero(eltype(data(ΔZ)))
+        # ∇M = similar(data(M)) .+ zero(first(data(ΔZ)))
+        ∇M = zeros(eltype(data(ΔZ)), size(M))
        @inbounds for c=1:size(M,2)
            part = ForwardDiff.partials.(C[c])
            for r=1:d
-                ∇M[r,c] = 0
+                # ∇M[r,c] = 0
+                # accum = 0
                for i=1:size(ΔZ,1)
                    ∇M[r,c] += data(ΔZ)[i,c] * part[i].values[r]
+                    # parti = ForwardDiff.partials(C[c][i])
+                    # ∇M[r,c] += data(ΔZ)[i,c] * parti.values[r]
+                    # accum += data(ΔZ)[i,c] * part[i].values[r]
                end
+                # ∇M[r,c] = accum
            end
        end
        (nothing, ∇M, nothing, map(_->nothing, args)...)
@@ -156,37 +186,11 @@ MapCols{d}(f::Function, M::TrackedMatrix, args...) where {d} = track(MapCols, f,
     Z, back
 end
 
-# TODO make a _MapCols which always takes Val(d), then unite these
-
-Zygote.@adjoint function MapCols{d}(f::Function, M::Matrix, args...) where {d} # no dval!
-
-    @cast A[c]{r:d} := M[r,c]
-    dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, Val(d))...), Val(d)))
-
-    C = [ rvec(f(acol .+ dualcol, args...)) for acol in A ]
-
-    Z = reduce(hcat, [ ForwardDiff.value.(full) for full in C ])
-
-    function back(ΔZ)
-        ∇M = similar(data(M)) .+ zero(first(data(ΔZ)))
-        @inbounds for c=1:size(M,2)
-            part = ForwardDiff.partials.(C[c])
-            for r=1:d
-                ∇M[r,c] = 0
-                for i=1:size(ΔZ,1)
-                    ∇M[r,c] += data(ΔZ)[i,c] * part[i].values[r]
-                end
-            end
-        end
-        (nothing, ∇M, map(_->nothing, args)...) # changed!
-    end
-
-    Z, back
-end
 
 #========== Gradient for eachslice / reduce ==========#
 
-export gluecol, mapcols2, mapcols4, mapcols5, mapcols6, mapcols7
+export gluecol, collecteachcol
+export mapcols2, mapcols4, mapcols5, mapcols6, mapcols7
 
 gluecol(V::AbstractVector{<:AbstractVector}) = reduce(hcat, V)
 
@@ -235,14 +239,14 @@ end
 # dy = (f = (A = [47.9325 51.3781
 # Which means this works... but uses as much memory as gradient of array of views:
 
-Zygote.@adjoint function eachcol(x::AbstractMatrix)
+#=Zygote.@adjoint function eachcol(x::AbstractMatrix)
     eachcol(x), dy -> (dy.f.A,) #= begin
         @show typeof(dy) dy
         dx = zero(x) .+ 0.0 # zeros(eltype(dy), size(x))
         foreach(copyto!, eachcol(dx), dy)
        (dx,)
    end =#
-end
+end=#
 
 # @adjoint eachcol(x) = eachcol(x), dy -> (dy.f.A,)
 
@@ -254,7 +258,7 @@ end
 
 collecteachcol(x) = collect(eachcol(x))
 
-Zygote.@adjoint function collecteachcol(x)
+@adjoint function collecteachcol(x)
     collecteachcol(x), dy -> begin
         dx = _zero(x)
         foreach(copyto!, collecteachcol(dx), dy)
@@ -274,4 +278,25 @@ end
 # reduce(hcat, res)
 # end
 
+# Following a suggestion? Doesn't help.
+# @adjoint Base.collect(x) = collect(x), Δ -> (Δ,)
+
+
+#========== Gradients for TensorCast's functions ==========#
+
+using TensorCast
+
+@adjoint function TensorCast.sliceview(A::AbstractArray, code::Tuple)
+    TensorCast.sliceview(A, code), Δ -> begin
+        dA = _zero(A)
+        foreach(copyto!, TensorCast.sliceview(dA, code), Δ)
+        (dA, nothing)
+    end
+end
+
+@adjoint function TensorCast.red_glue(A::AbstractArray, code::Tuple)
+    TensorCast.red_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
+end
+
+
 end # module
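
The TensorCast adjoints added at the end of the file encode one symmetry: gluing slices and slicing the glued array are adjoint operations, so the pullback of `red_glue` simply slices Δ, and the pullback of `sliceview` copies Δ's slices into a zeroed array. A toy sketch of the same pattern, with a hypothetical `glue` standing in for TensorCast's function:

```julia
using Zygote

# Hypothetical stand-in for a glue-columns function:
glue(cols) = reduce(hcat, cols)

# Pullback of gluing = slicing the cotangent Δ back into columns:
Zygote.@adjoint glue(cols) = glue(cols), Δ -> ([Δ[:, j] for j in axes(Δ, 2)],)

cols = [rand(3) for _ in 1:4]
# Each gradient entry is cos.(cols[j]), since the derivative of sum(sin, x) is cos.(x):
Zygote.gradient(cs -> sum(sin, glue(cs)), cols)[1]
```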
