@@ -3,9 +3,9 @@ module SliceMap
3
3
4
4
export mapcols, MapCols, maprows, slicemap
5
5
6
- using MacroTools, Tracker, Zygote, WeightedArrays
6
+ using MacroTools, Requires, WeightedArrays, TensorCast, Tracker
7
+
7
8
using Tracker: TrackedMatrix, track, @grad , data
8
- using Zygote: @adjoint , _zero
9
9
10
10
#= ========= Reverse, Eachslice ==========#
11
11
@@ -35,9 +35,6 @@ mapcols(f::Function, M::TrackedMatrix, args...) = track(mapcols, f, M, args...)
35
35
@grad mapcols (f:: Function , M:: AbstractMatrix , args... ) =
36
36
∇mapcols ([ Tracker. forward (x -> surevec (f (x, args... )), col) for col in eachcol (data (M)) ], args)
37
37
38
- @adjoint mapcols (f:: Function , M:: AbstractMatrix , args... ) =
39
- ∇mapcols ([ Zygote. forward (x -> surevec (f (x, args... )), col) for col in eachcol (M) ], args)
40
-
41
38
function ∇mapcols (forwards, args)
42
39
reduce (hcat, data .(first .(forwards))), Δ -> begin
43
40
cols = [ data (last (fwd)(Δcol)[1 ]) for (fwd, Δcol) in zip (forwards, eachcol (data (Δ))) ]
@@ -58,16 +55,27 @@ maprows(f::Function, M::TrackedMatrix, args...) = track(maprows, f, M, args...)
58
55
@grad maprows (f:: Function , M:: AbstractMatrix , args... ) =
59
56
∇maprows ([ Tracker. forward (x -> surevec (f (x, args... )), row) for row in eachrow (data (M)) ], args)
60
57
61
- @adjoint maprows (f:: Function , M:: AbstractMatrix , args... ) =
62
- ∇maprows ([ Zygote. forward (x -> surevec (f (x, args... )), row) for row in eachrow (M) ], args)
63
-
64
58
function ∇maprows (forwards, args)
65
59
reduce (vcat, map (transpose∘ data∘ first, forwards)), Δ -> begin
66
60
rows = [ data (last (fwd)(Δrow)[1 ]) for (fwd, Δrow) in zip (forwards, eachrow (data (Δ))) ]
67
61
(nothing , reduce (vcat, transpose .(rows)), map (_-> nothing , args)... )
68
62
end
69
63
end
70
64
65
+ """
66
+ slicemap(f, A; dims) ≈ mapslices(f, A; dims)
67
+
68
+ Like `mapcols()`, but for any slice. The function `f` must preserve shape,
69
+ e.g. `dims=(2,4)` then `f` must map matrices to matrices.
70
+
71
+ The gradient is for Zygote only.
72
+ """
73
+ function slicemap (f:: Function , A:: AbstractArray{T,N} , args... ; dims) where {T,N}
74
+ code = ntuple (d -> d in dims ? (:) : (* ), N)
75
+ B = TensorCast. sliceview (A, code)
76
+ C = [ f (slice, args... ) for slice in B ]
77
+ TensorCast. glue (C, code)
78
+ end
71
79
72
80
#= ========= Forward, Static ==========#
73
81
@@ -96,36 +104,18 @@ function _MapCols(f::Function, M::Matrix{T}, ::Val{d}, args...) where {T,d}
96
104
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (M))
97
105
B = map (col -> surevec (f (col, args... )), A)
98
106
reduce (hcat, B)
99
- # maybestaticgluecols(B)
100
- end
101
-
102
- # surevec(x::MArray) = Array(x) # avoid making a huge MArray, ad
103
- # surevecS(x::Number) = @SVector [x]
104
- # surevecS(A) = vec(A) # like surevec
105
-
106
- function maybestaticgluecols (B)
107
- TB = eltype (B)
108
- if TB <: SArray
109
- C = collect (reshape (reinterpret (eltype (TB), B),:,length (B)))
110
- elseif TB <: MArray
111
- C = reduce (hcat, Array .(B))
112
- else
113
- C = reduce (hcat, B)
114
- end
115
107
end
116
108
117
109
_MapCols (f:: Function , M:: TrackedMatrix , dval, args... ) = track (_MapCols, f, M, dval, args... )
118
110
119
111
@grad _MapCols (f:: Function , M:: TrackedMatrix , dval, args... ) = ∇MapCols (f, M, dval, args... )
120
112
121
- @adjoint _MapCols (f:: Function , M:: Matrix , dval, args... ) = ∇MapCols (f, M, dval, args... )
122
-
123
113
function ∇MapCols (f:: Function , M:: AbstractMatrix{T} , dval:: Val{d} , args... ) where {T,d}
124
114
d == size (M,1 ) || error (" expected M with $d columns" )
125
115
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (data (M)))
126
116
127
117
dualcol = SVector (ntuple (j-> ForwardDiff. Dual (0 , ntuple (i-> i== j ? 1 : 0 , dval)... ), dval))
128
- C = map (col -> surevec (f (col . + dualcol, args... )), A)
118
+ C = map (col -> surevec (f (col + dualcol, args... )), A)
129
119
130
120
Z = reduce (hcat, map (col -> ForwardDiff. value .(col), C))
131
121
@@ -144,23 +134,27 @@ function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) wh
144
134
Z, back
145
135
end
146
136
137
+ #= ========= Gradients for Zygote ==========#
138
+
139
+ # @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
140
+ # end
147
141
148
- #= ========= Gradient for eachslice / reduce ==========#
142
+ @init @require Zygote = " e88e6eb3-aa80-5325-afca-941959d7151f" include (" zygote.jl" )
143
+
144
+ #= ========= Experimenting with gradients for for eachslice / reduce ==========#
149
145
150
146
export gluecol, collecteachcol
151
147
export mapcols2, mapcols4, mapcols5, mapcols6, mapcols7
152
148
153
149
gluecol (V:: AbstractVector{<:AbstractVector} ) = reduce (hcat, V)
154
150
151
+ #=
155
152
gluecol(V::AbstractVector{<:TrackedVector}) = track(gluecol, V)
156
153
157
154
@grad function gluecol(V::AbstractVector)
158
155
gluecol(data.(V)), ΔM -> (collect(eachcol(data(ΔM))),) # doesn't work
159
156
end
160
-
161
- Zygote. @adjoint function gluecol (V:: AbstractVector )
162
- gluecol (V), ΔM -> (collect (eachcol (ΔM)),) # does work!
163
- end
157
+ =#
164
158
165
159
function mapcols2 (f, A)
166
160
cols = [A[:,c] for c= 1 : size (A,2 )]
@@ -170,34 +164,18 @@ end
170
164
171
165
# Apply that straight to reduce(hcat,...)
172
166
173
- Zygote. @adjoint function Base. reduce (:: typeof (hcat), V:: AbstractVector{<:AbstractVector} )
174
- reduce (hcat, V), dV -> (nothing , collect (eachcol (dV)),)
175
- end
176
-
177
167
function mapcols4 (f, A)
178
168
cols = [view (A,:,c) for c= 1 : size (A,2 )]
179
169
res = map (f, cols)
180
170
reduce (hcat, res)
181
171
end
182
172
183
- # Zygote doesn't understand views, but easy to fix:
184
- # https://github.com/FluxML/Zygote.jl/issues/52
185
- # now https://github.com/FluxML/Zygote.jl/pull/219
186
-
187
- Zygote. @adjoint function view (x:: AbstractArray , inds... ; kwargs... )
188
- view (x, inds... ; kwargs... ), dy -> begin
189
- dx = _zero (x)
190
- copyto! (view (dx, inds... ; kwargs... ), dy)
191
- (dx, map (_-> nothing , inds)... )
192
- end
193
- end
194
-
195
173
# Surprisingly dy for eachcol seems to know the answer?
196
174
# typeof(dy) = NamedTuple{(:f, :iter),Tuple{NamedTuple{(:A,),Tuple{Array{Float64,2}}},Array{Nothing,1}}}
197
175
# dy = (f = (A = [47.9325 51.3781
198
176
# Which means this works... but uses as much memory as gradient of array of views:
199
177
200
- #= Zygote. @adjoint function eachcol(x::AbstractMatrix)
178
+ #= @adjoint function eachcol(x::AbstractMatrix)
201
179
eachcol(x), dy -> (dy.f.A,) #= begin
202
180
@show typeof(dy) dy
203
181
dx = zero(x) .+ 0.0 # zeros(eltype(dy), size(x))
216
194
217
195
collecteachcol (x) = collect (eachcol (x))
218
196
219
- @adjoint function collecteachcol (x)
220
- collecteachcol (x), dy -> begin
221
- dx = _zero (x)
222
- foreach (copyto!, collecteachcol (dx), dy)
223
- (dx,)
224
- end
225
- end
226
-
227
197
function mapcols6 (f, A)
228
198
cols = collecteachcol (A)
229
199
res = map (f, cols)
240
210
# @adjoint Base.collect(x) = collect(x), Δ -> (Δ,)
241
211
242
212
243
- #= ========= Gradients for TensorCast's functions ==========#
244
-
245
- using TensorCast
246
-
247
- @adjoint function TensorCast. sliceview (A:: AbstractArray , code:: Tuple )
248
- TensorCast. sliceview (A, code), Δ -> begin
249
- dA = _zero (A)
250
- foreach (copyto!, TensorCast. sliceview (dA, code), Δ)
251
- (dA, nothing )
252
- end
253
- end
254
-
255
- @adjoint function TensorCast. red_glue (A:: AbstractArray , code:: Tuple )
256
- TensorCast. red_glue (A, code), Δ -> (TensorCast. sliceview (Δ, code), nothing )
257
- end
258
-
259
- @adjoint function TensorCast. copy_glue (A:: AbstractArray , code:: Tuple )
260
- TensorCast. copy_glue (A, code), Δ -> (TensorCast. sliceview (Δ, code), nothing )
261
- end
262
-
263
- """
264
- slicemap(f, A; dims) ≈ mapslices(f, A; dims)
265
-
266
- Like `mapcols()`, but for any slice. The function `f` must preserve shape,
267
- e.g. `dims=(2,4)` then `f` must map matrices to matrices.
268
-
269
- The gradient is for Zygote only.
270
- """
271
- function slicemap (f:: Function , A:: AbstractArray{T,N} , args... ; dims) where {T,N}
272
- code = ntuple (d -> d in dims ? (:) : (* ), N)
273
- B = TensorCast. sliceview (A, code)
274
- C = [ f (slice, args... ) for slice in B ]
275
- TensorCast. glue (C, code)
276
- end
277
-
278
213
end # module
0 commit comments