@@ -3,22 +3,49 @@ module SliceMap
3
3
4
4
export MapCols, mapcols
5
5
6
+ #= ========= Gradient Macro ==========#
7
+
8
+ using MacroTools, Tracker, Zygote
9
+ using Tracker: TrackedMatrix, track, @grad , data
10
+ using Zygote: @adjoint , _zero
11
+
12
"""
    @gradadjoint function f(args...) ... end

Define a gradient rule from a single function definition.

At present this only emits the Tracker rule produced by `trackergrad(ex)`.
Also emitting the Zygote rule through `Zygote.gradm(ex)` was tried and does not work yet.
"""
macro gradadjoint(ex)
    quote
        # $(Zygote.gradm(ex)) # this doesn't work
        $(trackergrad(ex))
    end
end
18
+
19
+ # Copied from https://github.com/FluxML/Tracker.jl/blob/master/src/Tracker.jl#L55
20
"""
    trackergrad(ex)

Rewrite the function definition `ex` into a definition of a `Tracker._forward` method.

`ex` may be a long- or short-form definition, with or without a `where` clause.
The function name is turned into a leading `::typeof(name)` argument (unless it is
already a `::` expression), and the resulting expression is returned escaped.

Throws an error if `ex` is not a function definition.
"""
function trackergrad(ex)
    @capture(shortdef(ex), (name_(args__) = body_) |
                           (name_(args__) where {T__} = body_)) ||
        error("Need a function definition")
    # `T` is only bound when the `where` alternative matched.
    T === nothing && (T = [])
    # Dispatch on the function itself: `f` becomes `::typeof(f)`.
    isexpr(name, :(::)) || (name = :(::typeof($name)))
    # A keyword `:parameters` block, if present, must stay in first position.
    # The emptiness check keeps zero-argument definitions from indexing `args[1]`.
    haskw = !isempty(args) && isexpr(args[1], :parameters)
    insert!(args, 1 + haskw, name)
    return MacroTools.@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end
28
+
6
29
#= ========= Reverse, Eachslice ==========#
7
30
31
+ using WeightedArrays
32
+
8
33
"""
    mapcols(f, M::AbstractMatrix, args...)

Apply `f` to each column of `M` and concatenate the results horizontally, i.e.
`reduce(hcat, f(c, args...) for c in eachcol(M))`.

When `M::TrackedMatrix`, it saves the backward function for each slice.
All further arguments are scalar constants, i.e. they do not get sliced/iterated
(unlike `map`) nor are their gradients tracked.
"""
mapcols(f::Function, M::AbstractMatrix, args...) =
    reduce(hcat, [rvec(f(col, args...)) for col in eachcol(M)])
15
42
16
- using Tracker
17
- using Tracker : TrackedMatrix, track, @grad , data
43
# `WeightedMatrix` method: map over the wrapped `M.array`, then re-wrap the result
# together with the unchanged `M.weights` and `M.opt`.
mapcols(f::Function, M::WeightedMatrix, args...) =
    Weighted(mapcols(f, M.array, args...), M.weights, M.opt)
18
45
19
46
# `TrackedMatrix` method: hand the call to Tracker via `track`, so that the gradient
# rule registered for `mapcols` is used.
mapcols(f::Function, M::TrackedMatrix, args...) = track(mapcols, f, M, args...)
20
47
21
- @grad function mapcols (f:: Function , M:: TrackedMatrix , args... )
48
+ @gradadjoint function mapcols (f:: Function , M:: AbstractMatrix , args... )
22
49
res = [ Tracker. forward (x -> rvec (f (x, args... )), col) for col in eachcol (data (M)) ]
23
50
fwd = reduce (hcat, data .(first .(res)))
24
51
function back (Δ)
@@ -29,7 +56,7 @@ mapcols(f::Function, M::TrackedMatrix, args...) = track(mapcols, f, M, args...)
29
56
fwd, back
30
57
end
31
58
32
- using Zygote
59
+ # @gradadjoint not yet working
33
60
Zygote. @adjoint function mapcols (f:: Function , M:: Matrix , args... )
34
61
res = [ Zygote. forward (x -> rvec (f (x, args... )), col) for col in eachcol (data (M)) ]
35
62
fwd = reduce (hcat, data .(first .(res)))
@@ -41,6 +68,23 @@ Zygote.@adjoint function mapcols(f::Function, M::Matrix, args...)
41
68
fwd, back
42
69
end
43
70
71
"""
    maprows(f, M::AbstractMatrix, args...)

Apply `f` to each row of `M` and concatenate the results vertically.

Each result is passed through `tvec`, so it becomes one row of the output.
Further `args...` are passed unchanged to every call of `f`.
"""
maprows(f::Function, M::AbstractMatrix, args...) =
    reduce(vcat, [tvec(f(row, args...)) for row in eachrow(M)])
73
+
74
# `TrackedMatrix` method: hand the call to Tracker via `track`, so that the gradient
# rule registered for `maprows` is used.
maprows(f::Function, M::TrackedMatrix, args...) = track(maprows, f, M, args...)
75
+
76
# Gradient rule for `maprows`.
#
# Forward pass: run `Tracker.forward` on every row of `data(M)`, keeping each row's
# output and pullback, and stack the outputs vertically.
# Backward pass: feed each row of `Δ` to the matching pullback and stack the results.
# `f` and the extra `args` receive `nothing`, i.e. no gradient.
@gradadjoint function maprows(f::Function, M::AbstractMatrix, args...)
    res = [Tracker.forward(x -> tvec(f(x, args...)), row) for row in eachrow(data(M))]
    fwd = reduce(vcat, data.(first.(res)))
    function back(Δ)
        # The forward output of each slice is a 1×n transpose, so the cotangent is
        # passed in that same shape rather than as a plain vector.
        rows = [data(last(res[r])(transpose(Δrow))[1])
                for (r, Δrow) in enumerate(eachrow(data(Δ)))]
        # Each gradient has the shape of its input row (a vector); transpose before
        # `vcat` so that ∇M is a matrix shaped like `M`, not one long vector.
        ∇M = reduce(vcat, transpose.(rows))
        return (nothing, ∇M, map(_ -> nothing, args)...)
    end
    return fwd, back
end
86
+
87
+
44
88
#= ========= Forward, Static ==========#
45
89
46
90
using TensorCast, StaticArrays, WeightedArrays
@@ -80,6 +124,7 @@ rvec(x::Number) = [x] # to allow for f vector -> scalar, as mapslices does
80
124
# Convert to an ordinary `Array` first, to avoid creating a giant static array,
# as `reduce(hcat, ...)` would otherwise do.
rvec(x::StaticArray) = vec(Array(x))
81
125
# Fallback: flatten anything else with `vec`.
rvec(A) = vec(A) # LinearAlgebra.
82
126
127
# Row counterpart of `rvec`: the transpose of `rvec(x)`.
tvec(x) = transpose(rvec(x))
83
128
84
129
using ForwardDiff
85
130
@@ -111,6 +156,8 @@ MapCols{d}(f::Function, M::TrackedMatrix, args...) where {d} = track(MapCols, f,
111
156
Z, back
112
157
end
113
158
159
+ # TODO make a _MapCols which always takes Val(d), then unite these
160
+
114
161
Zygote. @adjoint function MapCols {d} (f:: Function , M:: Matrix , args... ) where {d} # no dval!
115
162
116
163
@cast A[c]{r: d} := M[r,c]
@@ -137,4 +184,94 @@ Zygote.@adjoint function MapCols{d}(f::Function, M::Matrix, args...) where {d} #
137
184
Z, back
138
185
end
139
186
187
+ #= ========= Gradient for eachslice / reduce ==========#
188
+
189
+ export gluecol, mapcols2, mapcols4, mapcols5, mapcols6, mapcols7
190
+
191
"""
    gluecol(V::AbstractVector{<:AbstractVector})

Concatenate the vectors in `V` horizontally into a matrix, i.e. `reduce(hcat, V)`.
"""
gluecol(V::AbstractVector{<:AbstractVector}) = reduce(hcat, V)
192
+
193
# Vector of tracked vectors: hand the call to Tracker via `track`.
gluecol(V::AbstractVector{<:TrackedVector}) = track(gluecol, V)
194
+
195
# Tracker gradient for `gluecol`: the forward value is `gluecol` of the untracked
# data; the pullback splits `ΔM` back into its columns.
# NOTE: the original author marked this rule as not working.
@grad function gluecol(V::AbstractVector)
    gluecol(data.(V)), ΔM -> (collect(eachcol(data(ΔM))),) # doesn't work
end
198
+
199
# Zygote adjoint for `gluecol`: the pullback splits `ΔM` back into its columns.
Zygote.@adjoint function gluecol(V::AbstractVector)
    gluecol(V), ΔM -> (collect(eachcol(ΔM)),) # does work!
end
202
+
203
"""
    mapcols2(f, A)

Variant of `mapcols` that copies each column out with `A[:, c]`, broadcasts `f`
over the list of columns, and joins the results with `gluecol`.
"""
function mapcols2(f, A)
    cols = [A[:, c] for c = 1:size(A, 2)]
    res = f.(cols)
    return gluecol(res)
end
208
+
209
+ # Apply that straight to reduce(hcat,...)
210
+
211
# Zygote adjoint for `reduce(hcat, V)` on a vector of vectors: the pullback splits
# `dV` back into its columns; `hcat` itself receives `nothing`.
Zygote.@adjoint function Base.reduce(::typeof(hcat), V::AbstractVector{<:AbstractVector})
    reduce(hcat, V), dV -> (nothing, collect(eachcol(dV)))
end
214
+
215
"""
    mapcols4(f, A)

Variant of `mapcols` that takes each column as a `view(A, :, c)`, applies `f` with
`map`, and joins the results with `reduce(hcat, ...)`.
"""
function mapcols4(f, A)
    cols = [view(A, :, c) for c = 1:size(A, 2)]
    res = map(f, cols)
    return reduce(hcat, res)
end
220
+
221
+ # Zygote doesn't understand views, but easy to fix:
222
+ # https://github.com/FluxML/Zygote.jl/issues/52
223
+ # now https://github.com/FluxML/Zygote.jl/pull/219
224
+
225
# Zygote adjoint for `view`: the pullback builds `dx = _zero(x)` and copies the
# incoming gradient `dy` into the same view of `dx`. The index arguments receive
# `nothing`.
# NOTE(review): `_zero` comes from Zygote; presumably it returns a zero array shaped
# like `x` — confirm against the Zygote version in use.
Zygote.@adjoint function view(x::AbstractArray, inds...; kwargs...)
    view(x, inds...; kwargs...), dy -> begin
        dx = _zero(x)
        copyto!(view(dx, inds...; kwargs...), dy)
        (dx, map(_ -> nothing, inds)...)
    end
end
232
+
233
+ # Surprisingly dy for eachcol seems to know the answer?
234
+ # typeof(dy) = NamedTuple{(:f, :iter),Tuple{NamedTuple{(:A,),Tuple{Array{Float64,2}}},Array{Nothing,1}}}
235
+ # dy = (f = (A = [47.9325 51.3781
236
+ # Which means this works... but uses as much memory as gradient of array of views:
237
+
238
# Zygote adjoint for `eachcol`: the pullback returns the field `dy.f.A` of the
# incoming gradient as the gradient of `x`.
#
# An explicit alternative that was tried, kept for reference:
#     @show typeof(dy) dy
#     dx = zero(x) .+ 0.0 # zeros(eltype(dy), size(x))
#     foreach(copyto!, eachcol(dx), dy)
#     (dx,)
Zygote.@adjoint function eachcol(x::AbstractMatrix)
    eachcol(x), dy -> (dy.f.A,)
end
246
+
247
+ # @adjoint eachcol(x) = eachcol(x), dy -> (dy.f.A,)
248
+
249
"""
    mapcols5(f, A)

Variant of `mapcols` that collects `eachcol(A)` into a vector, applies `f` with
`map`, and joins the results with `reduce(hcat, ...)`.
"""
function mapcols5(f, A)
    cols = collect(eachcol(A))
    res = map(f, cols)
    return reduce(hcat, res)
end
254
+
255
# Collect the columns of `x` into a vector; a named function so that a single
# adjoint can be attached to the combined operation.
collecteachcol(x) = collect(eachcol(x))
256
+
257
# Zygote adjoint for `collecteachcol`: the pullback builds `dx = _zero(x)` and
# copies each entry of `dy` into the corresponding column of `dx`.
# NOTE(review): `_zero` comes from Zygote; presumably it returns a zero array shaped
# like `x` — confirm against the Zygote version in use.
Zygote.@adjoint function collecteachcol(x)
    collecteachcol(x), dy -> begin
        dx = _zero(x)
        foreach(copyto!, collecteachcol(dx), dy)
        (dx,)
    end
end
264
+
265
"""
    mapcols6(f, A)

Variant of `mapcols` that obtains the columns from `collecteachcol(A)`, applies `f`
with `map`, and joins the results with `reduce(hcat, ...)`.
"""
function mapcols6(f, A)
    cols = collecteachcol(A)
    res = map(f, cols)
    return reduce(hcat, res)
end
270
+
271
+ # function mapcols7(f, A)
272
+ # cols = eachcol(A) # without collect. Zygote.gradient -> StackOverflowError
273
+ # res = map(f, cols)
274
+ # reduce(hcat, res)
275
+ # end
276
+
140
277
end # module
0 commit comments