@@ -8,6 +8,9 @@ using MacroTools, Requires, TensorCast, JuliennedArrays
8
8
using Tracker
9
9
using Tracker: TrackedMatrix, track, @grad , data
10
10
11
+ using ZygoteRules
12
+ using ZygoteRules: pullback, @adjoint
13
+
11
14
#= ========= Reverse, Eachslice ==========#
12
15
13
16
"""
@@ -36,6 +39,9 @@ _mapcols(map::Function, f::Function, M::TrackedMatrix, args...) = track(_mapcols
36
39
@grad _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
37
40
∇mapcols (map, map (col -> Tracker. forward (x -> surevec (f (x, args... )), col), eachcol (data (M))), args... )
38
41
42
+ @adjoint _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
43
+ ∇mapcols (map, map (col -> ZygoteRules. pullback (x -> surevec (f (x, args... )), col), eachcol (M)), args)
44
+
39
45
function ∇mapcols (bigmap, forwards, args... )
40
46
reduce (hcat, map (data∘ first, forwards)), Δ -> begin
41
47
cols = bigmap ((fwd, Δcol) -> data (last (fwd)(Δcol)[1 ]), forwards, eachcol (data (Δ)))
@@ -56,6 +62,9 @@ maprows(f::Function, M::TrackedMatrix, args...) = track(maprows, f, M, args...)
56
62
@grad maprows (f:: Function , M:: AbstractMatrix , args... ) =
57
63
∇maprows (map (row -> Tracker. forward (x -> surevec (f (x, args... )), row), eachrow (data (M))), args)
58
64
65
+ @adjoint maprows (f:: Function , M:: AbstractMatrix , args... ) =
66
+ ∇maprows (map (row -> ZygoteRules. pullback (x -> surevec (f (x, args... )), row), eachrow (M)), args)
67
+
59
68
function ∇maprows (forwards, args)
60
69
reduce (vcat, map (transpose∘ data∘ first, forwards)), Δ -> begin
61
70
rows = map ((fwd, Δrow) -> data (last (fwd)(Δrow)[1 ]), forwards, eachrow (data (Δ)))
@@ -77,6 +86,7 @@ function slicemap(f::Function, A::AbstractArray{T,N}, args...; dims) where {T,N}
77
86
C = [ f (slice, args... ) for slice in B ]
78
87
TensorCast. glue (C, code)
79
88
end
89
+ # TODO switch to JuliennedArrays, then rm TensorCast dep
80
90
81
91
#= ========= Forward, Static ==========#
82
92
@@ -111,6 +121,9 @@ _MapCols(map::Function, f::Function, M::TrackedMatrix, dval, args...) =
111
121
@grad _MapCols (map:: Function , f:: Function , M:: TrackedMatrix , dval, args... ) =
112
122
∇MapCols (map, f, M, dval, args... )
113
123
124
+ @adjoint _MapCols (map:: Function , f:: Function , M:: Matrix , dval, args... ) =
125
+ ∇MapCols (map, f, M, dval, args... )
126
+
114
127
function ∇MapCols (bigmap:: Function , f:: Function , M:: AbstractMatrix{T} , dval:: Val{d} , args... ) where {T,d}
115
128
d == size (M,1 ) || error (" expected M with $d rows" )
116
129
k = size (M,2 )
142
155
143
156
# @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
144
157
145
- @init @require Zygote = " e88e6eb3-aa80-5325-afca-941959d7151f" include (" zygote.jl" )
158
+ # @init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("zygote.jl")
159
+ # Now using ZygoteRules instead, mapcols etc above.
160
+
161
+ #= TensorCast =#
162
+ # These could move there, TODO
163
+
164
+ @adjoint TensorCast. sliceview (A:: AbstractArray , code:: Tuple ) =
165
+ TensorCast. sliceview (A, code), Δ -> (TensorCast. glue (Δ, code), nothing )
166
+
167
+ @adjoint TensorCast. red_glue (A:: AbstractArray , code:: Tuple ) =
168
+ TensorCast. red_glue (A, code), Δ -> (TensorCast. sliceview (Δ, code), nothing )
169
+
170
+ @adjoint TensorCast. copy_glue (A:: AbstractArray , code:: Tuple ) =
171
+ TensorCast. copy_glue (A, code), Δ -> (TensorCast. sliceview (Δ, code), nothing )
172
+
173
+ #= JuliennedArrays =#
174
+
175
+ @adjoint JuliennedArrays. Slices (whole, along... ) =
176
+ Slices (whole, along... ), Δ -> (Align (Δ, along... ), map (_-> nothing , along)... )
177
+
178
+ @adjoint JuliennedArrays. Align (whole, along... ) =
179
+ Align (whole, along... ), Δ -> (Slices (Δ, along... ), map (_-> nothing , along)... )
180
+
181
+ #= Base =#
182
+
183
+ @adjoint Base. reduce (:: typeof (hcat), V:: AbstractVector{<:AbstractVector} ) =
184
+ reduce (hcat, V), dV -> (nothing , collect (eachcol (dV)),)
185
+
146
186
147
187
#= ========= Experimenting with gradients for for eachslice / reduce ==========#
148
188
@@ -159,6 +199,9 @@ gluecol(V::AbstractVector{<:TrackedVector}) = track(gluecol, V)
159
199
end
160
200
=#
161
201
202
+ @adjoint gluecol (V:: AbstractVector ) =
203
+ gluecol (V), ΔM -> (collect (eachcol (ΔM)),) # does work!
204
+
162
205
function mapcols2 (f, A)
163
206
cols = [A[:,c] for c= 1 : size (A,2 )]
164
207
res = f .(cols)
197
240
198
241
collecteachcol (x) = collect (eachcol (x))
199
242
243
+ @adjoint function collecteachcol (x)
244
+ collecteachcol (x), dy -> begin
245
+ dx = _zero (x) # _zero is not in ZygoteRules, TODO
246
+ foreach (copyto!, collecteachcol (dx), dy)
247
+ (dx,)
248
+ end
249
+ end
250
+
200
251
function mapcols6 (f, A)
201
252
cols = collecteachcol (A)
202
253
res = map (f, cols)
0 commit comments