1
1
2
2
module SliceMap
3
3
4
- export MapCols, mapcols
4
+ export MapCols, mapcols, maprows
5
+
5
6
6
7
#= ========= Gradient Macro ==========#
7
8
@@ -26,6 +27,7 @@ function trackergrad(ex)
26
27
MacroTools. @q (Tracker. _forward ($ (args... )) where $ (T... ) = $ body) |> esc
27
28
end
28
29
30
+
29
31
#= ========= Reverse, Eachslice ==========#
30
32
31
33
using WeightedArrays
@@ -38,15 +40,18 @@ All further arguments are scalar constants, i.e. they do not get sliced/iterated
38
40
nor are their gradients tracked.
39
41
"""
40
42
function mapcols(f::Function, M::AbstractMatrix, args...)
    # Apply `f` to every column, coerce each result to a vector, then glue
    # the resulting vectors side by side.
    slices = map(eachcol(M)) do col
        surevec(f(col, args...))
    end
    return reduce(hcat, slices)
end
42
44
43
45
# Weighted input: map over the wrapped array, then re-wrap the result with the
# original weights and options.
function mapcols(f::Function, M::WeightedMatrix, args...)
    mapped = mapcols(f, M.array, args...)
    return Weighted(mapped, M.weights, M.opt)
end
45
47
48
# Coerce whatever `f` returned for one slice into a plain vector.
function surevec(x::Number)
    # to allow f vector -> scalar, as mapslices does
    return [x]
end

function surevec(A)
    # to allow f vector -> matrix, by reshaping
    return vec(A)
end
50
+
46
51
# Tracked input: hand the call to `track`, forwarding all arguments unchanged.
function mapcols(f::Function, M::TrackedMatrix, args...)
    return track(mapcols, f, M, args...)
end
47
52
48
- @gradadjoint function mapcols (f:: Function , M:: AbstractMatrix , args... )
49
- res = [ Tracker. forward (x -> rvec (f (x, args... )), col) for col in eachcol (data (M)) ]
53
+ @grad function mapcols (f:: Function , M:: AbstractMatrix , args... )
54
+ res = [ Tracker. forward (x -> surevec (f (x, args... )), col) for col in eachcol (data (M)) ]
50
55
fwd = reduce (hcat, data .(first .(res)))
51
56
function back (Δ)
52
57
cols = [ data ((last (res[c]))(Δcol)[1 ]) for (c, Δcol) in enumerate (eachcol (data (Δ))) ]
@@ -56,25 +61,26 @@ mapcols(f::Function, M::TrackedMatrix, args...) = track(mapcols, f, M, args...)
56
61
fwd, back
57
62
end
58
63
59
- # @gradadjoint not yet working
60
# Zygote adjoint for `mapcols` on a plain `Matrix`: run `Zygote.forward` on each
# column separately and keep the per-column pullbacks for the backward pass.
@adjoint function mapcols(f::Function, M::Matrix, args...)
    # One (value, pullback) pair per column of M.
    fwdbacks = map(col -> Zygote.forward(x -> surevec(f(x, args...)), col), eachcol(M))
    fwd = reduce(hcat, map(first, fwdbacks))
    function back(Δ)
        # Feed each column of Δ to the pullback of the matching column of M.
        ∇cols = map(enumerate(eachcol(Δ))) do (c, Δcol)
            pullback = last(fwdbacks[c])
            pullback(Δcol)[1]
        end
        ∇M = reduce(hcat, ∇cols)
        # No gradient for `f`, nor for any of the trailing arguments.
        (nothing, ∇M, map(_ -> nothing, args)...)
    end
    fwd, back
end
70
74
71
75
function maprows(f::Function, M::AbstractMatrix, args...)
    # Apply `f` to every row, coerce each result to a row via `surerow`, then
    # stack the rows vertically.
    slices = map(eachrow(M)) do row
        surerow(f(row, args...))
    end
    return reduce(vcat, slices)
end
77
+
78
# Row counterpart of `surevec`: coerce to a vector, then transpose it.
function surerow(x)
    v = surevec(x)
    return transpose(v)
end
73
79
74
80
# Tracked input: hand the call to `track`, forwarding all arguments unchanged.
function maprows(f::Function, M::TrackedMatrix, args...)
    return track(maprows, f, M, args...)
end
75
81
76
- @gradadjoint function maprows (f:: Function , M:: AbstractMatrix , args... )
77
- res = [ Tracker. forward (x -> tvec (f (x, args... )), row) for row in eachrow (data (M)) ]
82
+ @grad function maprows (f:: Function , M:: AbstractMatrix , args... )
83
+ res = [ Tracker. forward (x -> surerow (f (x, args... )), row) for row in eachrow (data (M)) ]
78
84
fwd = reduce (vcat, data .(first .(res)))
79
85
function back (Δ)
80
86
rows = [ data ((last (res[r]))(Δrow)[1 ]) for (r, Δrow) in enumerate (eachrow (data (Δ))) ]
87
93
88
94
#= ========= Forward, Static ==========#
89
95
90
- using TensorCast, StaticArrays , WeightedArrays
96
+ using StaticArrays, ForwardDiff , WeightedArrays
91
97
92
98
# Field-less type with a single parameter `d`. The `MapCols(f, M, args...)`
# method in this file fills `d` with `size(M, 1)`.
struct MapCols{d} end
93
99
@@ -106,48 +112,72 @@ Takes `m.weights` along for the ride.
106
112
function MapCols(f::Function, M::WeightedArrays.MaybeWeightedMatrix, args...)
    # Lift the first dimension of M into the type parameter and re-dispatch.
    d = size(M, 1)
    return MapCols{d}(f, M, args...)
end
108
114
109
# Weighted input: map over the wrapped array, then re-wrap the result with the
# original weights and options.
function MapCols{d}(f::Function, M::WeightedMatrix, args...) where {d}
    mapped = MapCols{d}(f, M.array, args...)
    return Weighted(mapped, M.weights, M.opt)
end
111
117
112
- function MapCols {d} (f:: Function , M:: Matrix , args... ) where {d}
113
- @cast A[c]{r: d} := M[r,c] assert
114
- reduce (hcat, [ rvec (f (acol, args... )) for acol in A ])
118
# Generic entry point: pass `d` along as a `Val` so `_MapCols` can dispatch on it.
function MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d}
    return _MapCols(f, M, Val(d), args...)
end
119
+
120
# Apply `f` to each column of `M`, presenting every column as a static vector
# of length `d`, and glue the (vectorised) results side by side.
#
# Arguments:
#   f       - function applied to each column (plus `args...`)
#   M       - plain matrix whose first dimension must equal `d`
#   ::Val{d} - the column length, carried in the type domain
#   args... - extra arguments forwarded to `f` unchanged
#
# Throws an `ErrorException` when `size(M, 1) != d`.
function _MapCols(f::Function, M::Matrix{T}, ::Val{d}, args...) where {T,d}
    # `d` is compared against the FIRST dimension, i.e. the number of rows
    # (the length of each column), so the message must talk about rows.
    d == size(M, 1) || error("expected M with $d rows, got $(size(M, 1))")
    # Equivalent TensorCast form: @cast A[c]{r:d} := M[r,c] assert
    # Each element of A covers `d` consecutive entries of vec(M).
    A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(M))
    B = map(col -> surevec(f(col, args...)), A)
    # Alternative glue, currently unused here: maybestaticgluecols(B)
    return reduce(hcat, B)
end
115
128
116
- # TODO : call some function which static-glues if possible...
117
- # TensorCast.auto_glue(map(col -> rvec(f(col, args...)), A), (:,*))
129
+ # surevec(x::MArray) = Array(x) # avoid making a huge MArray, ad
118
130
119
- # TODO : can I thread this? Is it even safe to do so?
120
- # https://github.com/mohamed82008/KissThreading.jl
131
# Glue a collection of column-like results into one matrix, choosing the
# strategy from the element type of `B`.
function maybestaticgluecols(B)
    TB = eltype(B)
    # SArray elements: reinterpret the storage as scalars and reshape so there
    # is one column per element, then copy into a regular array.
    TB <: SArray && return collect(reshape(reinterpret(eltype(TB), B), :, length(B)))
    # MArray elements: convert each one to an `Array` before concatenating.
    TB <: MArray && return reduce(hcat, Array.(B))
    # Anything else: plain horizontal concatenation.
    return reduce(hcat, B)
end
122
141
123
- rvec (x:: Number ) = [x] # to allow for f vector -> scalar, as mapslices does
124
- rvec (x:: StaticArray ) = vec (Array (x)) # to avoid creating a giant staticarray, as reduce(hcat would otherwise do
125
- rvec (A) = vec (A) # LinearAlgebra.
142
+ # surevecS(x::Number) = @SVector [x]
143
+ # surevecS(A) = vec(A) # like surevec
144
+
145
# Tracked input: hand the call to `track`, forwarding all arguments unchanged.
function _MapCols(f::Function, M::TrackedMatrix, dval, args...)
    return track(_MapCols, f, M, dval, args...)
end
126
146
127
- tvec (x ) = transpose ( rvec (x) )
147
# Tracker gradient rule for `_MapCols`: delegate entirely to `∇MapCols`.
@grad function _MapCols(f::Function, M::TrackedMatrix, dval, args...)
    ∇MapCols(f, M, dval, args...)
end
128
148
129
- using ForwardDiff
149
# Zygote adjoint for `_MapCols` on a plain `Matrix`: delegate entirely to `∇MapCols`.
@adjoint function _MapCols(f::Function, M::Matrix, dval, args...)
    ∇MapCols(f, M, dval, args...)
end
130
150
131
- MapCols {d} (f:: Function , M:: TrackedMatrix , args ... ) where {d} = track (MapCols, f, M, Val (d) , args... )
151
+ function ∇ MapCols (f:: Function , M:: AbstractMatrix{T} , dval :: Val{d} , args... ) where {T,d}
132
152
133
- @grad function MapCols (f:: Function , M:: TrackedMatrix , dval:: Val{d} , args... ) where {d}
153
+ d == size (M,1 ) || error (" expected M with $d columns" )
154
+ # @cast A[c]{r:d} := data(M)[r,c]
155
+ A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (data (M)))
134
156
135
- @cast A[c]{r: d} := M. data[r,c]
136
157
dualcol = SVector (ntuple (j-> ForwardDiff. Dual (0 , ntuple (i-> i== j ? 1 : 0 , dval)... ), dval))
137
158
138
- C = [ rvec (f (acol .+ dualcol, args... )) for acol in A ]
159
+ # C = [ surevec(f(col .+ dualcol, args...)) for col in A ]
160
+ C = map (col -> surevec (f (col .+ dualcol, args... )), A)
139
161
140
- Z = reduce (hcat, [ ForwardDiff. value .(full) for full in C ]) # full is not an SVector here
162
+ # Z = reduce(hcat, [ ForwardDiff.value.(full) for full in C ])
163
+ Z = reduce (hcat, map (col -> ForwardDiff. value .(col), C))
141
164
142
165
function back (ΔZ)
143
- ∇M = similar (data (M)) .+ zero (first (data (ΔZ)))
166
+ # accum = zero(eltype(data(ΔZ)))
167
+ # ∇M = similar(data(M)) .+ zero(first(data(ΔZ)))
168
+ ∇M = zeros (eltype (data (ΔZ)), size (M))
144
169
@inbounds for c= 1 : size (M,2 )
145
170
part = ForwardDiff. partials .(C[c])
146
171
for r= 1 : d
147
- ∇M[r,c] = 0
172
+ # ∇M[r,c] = 0
173
+ # accum = 0
148
174
for i= 1 : size (ΔZ,1 )
149
175
∇M[r,c] += data (ΔZ)[i,c] * part[i]. values[r]
176
+ # parti = ForwardDiff.partials(C[c][i])
177
+ # ∇M[r,c] += data(ΔZ)[i,c] * parti.values[r]
178
+ # accum += data(ΔZ)[i,c] * part[i].values[r]
150
179
end
180
+ # ∇M[r,c] = accum
151
181
end
152
182
end
153
183
(nothing , ∇M, nothing , map (_-> nothing , args)... )
@@ -156,37 +186,11 @@ MapCols{d}(f::Function, M::TrackedMatrix, args...) where {d} = track(MapCols, f,
156
186
Z, back
157
187
end
158
188
159
- # TODO make a _MapCols which always takes Val(d), then unite these
160
-
161
- Zygote. @adjoint function MapCols {d} (f:: Function , M:: Matrix , args... ) where {d} # no dval!
162
-
163
- @cast A[c]{r: d} := M[r,c]
164
- dualcol = SVector (ntuple (j-> ForwardDiff. Dual (0 , ntuple (i-> i== j ? 1 : 0 , Val (d))... ), Val (d)))
165
-
166
- C = [ rvec (f (acol .+ dualcol, args... )) for acol in A ]
167
-
168
- Z = reduce (hcat, [ ForwardDiff. value .(full) for full in C ])
169
-
170
- function back (ΔZ)
171
- ∇M = similar (data (M)) .+ zero (first (data (ΔZ)))
172
- @inbounds for c= 1 : size (M,2 )
173
- part = ForwardDiff. partials .(C[c])
174
- for r= 1 : d
175
- ∇M[r,c] = 0
176
- for i= 1 : size (ΔZ,1 )
177
- ∇M[r,c] += data (ΔZ)[i,c] * part[i]. values[r]
178
- end
179
- end
180
- end
181
- (nothing , ∇M, map (_-> nothing , args)... ) # changed!
182
- end
183
-
184
- Z, back
185
- end
186
189
187
190
#= ========= Gradient for eachslice / reduce ==========#
188
191
189
- export gluecol, mapcols2, mapcols4, mapcols5, mapcols6, mapcols7
192
+ export gluecol, collecteachcol
193
+ export mapcols2, mapcols4, mapcols5, mapcols6, mapcols7
190
194
191
195
# Concatenate a vector of vectors horizontally, one column per inner vector.
function gluecol(V::AbstractVector{<:AbstractVector})
    return reduce(hcat, V)
end
192
196
@@ -235,14 +239,14 @@ end
235
239
# dy = (f = (A = [47.9325 51.3781
236
240
# Which means this works... but uses as much memory as gradient of array of views:
237
241
238
- Zygote. @adjoint function eachcol (x:: AbstractMatrix )
242
+ #= Zygote.@adjoint function eachcol(x::AbstractMatrix)
239
243
eachcol(x), dy -> (dy.f.A,) #= begin
240
244
@show typeof(dy) dy
241
245
dx = zero(x) .+ 0.0 # zeros(eltype(dy), size(x))
242
246
foreach(copyto!, eachcol(dx), dy)
243
247
(dx,)
244
248
end =#
245
- end
249
+ end=#
246
250
247
251
# @adjoint eachcol(x) = eachcol(x), dy -> (dy.f.A,)
248
252
254
258
255
259
# Materialise the column iterator of `x` into a vector of columns.
function collecteachcol(x)
    cols = eachcol(x)
    return collect(cols)
end
256
260
257
- Zygote . @adjoint function collecteachcol (x)
261
+ @adjoint function collecteachcol (x)
258
262
collecteachcol (x), dy -> begin
259
263
dx = _zero (x)
260
264
foreach (copyto!, collecteachcol (dx), dy)
274
278
# reduce(hcat, res)
275
279
# end
276
280
281
+ # Following a suggestion? Doesn't help.
282
+ # @adjoint Base.collect(x) = collect(x), Δ -> (Δ,)
283
+
284
+
285
+ #= ========= Gradients for TensorCast's functions ==========#
286
+
287
+ using TensorCast
288
+
289
# Zygote adjoint for `TensorCast.sliceview`: the backward pass copies each
# incoming slice gradient into the matching slice of a zeroed copy of `A`.
@adjoint function TensorCast.sliceview(A::AbstractArray, code::Tuple)
    slices = TensorCast.sliceview(A, code)
    function back(Δ)
        dA = _zero(A)
        # Slices of dA are written in place, pairing them with Δ in order.
        for (dslice, δ) in zip(TensorCast.sliceview(dA, code), Δ)
            copyto!(dslice, δ)
        end
        # No gradient with respect to `code`.
        return (dA, nothing)
    end
    slices, back
end
296
+
297
# Zygote adjoint for `TensorCast.red_glue`: the backward pass slices the
# incoming gradient with the same `code`; `code` itself gets no gradient.
@adjoint function TensorCast.red_glue(A::AbstractArray, code::Tuple)
    glued = TensorCast.red_glue(A, code)
    unglue(Δ) = (TensorCast.sliceview(Δ, code), nothing)
    glued, unglue
end
300
+
301
+
277
302
end # module
0 commit comments