1
1
2
2
module SliceMap
3
3
4
- export MapCols, mapcols , maprows
4
+ export mapcols, MapCols , maprows, slicemap
5
5
6
-
7
- #= ========= Gradient Macro ==========#
8
-
9
- using MacroTools, Tracker, Zygote
6
+ using MacroTools, Tracker, Zygote, WeightedArrays
10
7
using Tracker: TrackedMatrix, track, @grad , data
11
8
using Zygote: @adjoint , _zero
12
9
13
- macro gradadjoint (ex)
14
- quote
15
- # $(Zygote.gradm(ex)) # this doesn't work
16
- $ (trackergrad (ex))
17
- end
18
- end
19
-
20
- # Copied from https://github.com/FluxML/Tracker.jl/blob/master/src/Tracker.jl#L55
21
- function trackergrad (ex)
22
- @capture (shortdef (ex), (name_ (args__) = body_) |
23
- (name_ (args__) where {T__} = body_)) || error (" Need a function definition" )
24
- T == nothing && (T = [])
25
- isexpr (name, :(:: )) || (name = :(:: typeof ($ name)))
26
- insert! (args, 1 + isexpr (args[1 ], :parameters ) , name)
27
- MacroTools. @q (Tracker. _forward ($ (args... )) where $ (T... ) = $ body) |> esc
28
- end
29
-
30
-
31
10
#= ========= Reverse, Eachslice ==========#
32
11
33
- using WeightedArrays
34
-
35
12
"""
36
- mapcols(f, m::Matrix, args...) = reduce(hcat, f(c, args...) for c in eachcol(M) )
13
+ mapcols(f, m) ≈ mapreduce(f, hcat, eachcol(m)) ≈ mapslices(f, m, dims=1 )
37
14
38
- When `m::TrackedMatrix`, it saves the backward function for each slice.
39
- All further arguments are scalar constants, i.e. they do not get sliced/iterated (unlike `map`)
40
- nor are their gradients tracked.
15
+ This is a more efficient version of the functions on the right.
16
+ For `f(x::Vector)::Matrix` it reshapes like `mapslices(vec∘f, m, dims=1)`.
17
+
18
+ It provides a gradient for Tracker and Zygote, saving the backward function for each slice.
19
+
20
+ Any arguments after the matrix are passed to `f` as scalars, i.e.
21
+ `mapcols(f, m, args...) = reduce(hcat, f(col, args...) for col in eachcol(m))`.
22
+ They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
41
23
"""
42
24
mapcols (f:: Function , M:: AbstractMatrix , args... ) =
43
25
reduce (hcat, [ surevec (f (col, args... )) for col in eachcol (M) ])
@@ -50,44 +32,42 @@ surevec(A) = vec(A) # to allow f vector -> matrix, by reshaping
50
32
51
33
mapcols (f:: Function , M:: TrackedMatrix , args... ) = track (mapcols, f, M, args... )
52
34
53
- @grad function mapcols (f:: Function , M:: AbstractMatrix , args... )
54
- res = [ Tracker. forward (x -> surevec (f (x, args... )), col) for col in eachcol (data (M)) ]
55
- fwd = reduce (hcat, data .(first .(res)))
56
- function back (Δ)
57
- cols = [ data ((last (res[c]))(Δcol)[1 ]) for (c, Δcol) in enumerate (eachcol (data (Δ))) ]
58
- ∇M = reduce (hcat, cols)
59
- (nothing , ∇M, map (_-> nothing , args)... )
60
- end
61
- fwd, back
62
- end
35
+ @grad mapcols (f:: Function , M:: AbstractMatrix , args... ) =
36
+ ∇mapcols ([ Tracker. forward (x -> surevec (f (x, args... )), col) for col in eachcol (data (M)) ], args)
63
37
64
- @adjoint function mapcols (f:: Function , M:: Matrix , args... )
65
- res = [ Zygote. forward (x -> surevec (f (x, args... )), col) for col in eachcol (M) ]
66
- fwd = reduce (hcat, first .(res))
67
- function back (Δ )
68
- cols = [ ( last (res[c]))(Δcol)[ 1 ] for (c, Δcol) in enumerate ( eachcol (Δ)) ]
69
- ∇M = reduce (hcat, cols)
70
- (nothing , ∇M , map (_-> nothing , args)... )
38
+ @adjoint mapcols (f:: Function , M:: AbstractMatrix , args... ) =
39
+ ∇mapcols ( [ Zygote. forward (x -> surevec (f (x, args... )), col) for col in eachcol (M) ], args)
40
+
41
+ function ∇mapcols (forwards, args )
42
+ reduce (hcat, data .( first .(forwards))), Δ -> begin
43
+ cols = [ data ( last (fwd)(Δcol)[ 1 ]) for (fwd, Δcol) in zip (forwards, eachcol ( data (Δ))) ]
44
+ (nothing , reduce (hcat, cols) , map (_-> nothing , args)... )
71
45
end
72
- fwd, back
73
46
end
74
47
48
+ """
49
+ maprows(f, M) ≈ mapslices(f, M, dims=2)
50
+
51
+ Like `mapcols()` but for rows.
52
+ """
75
53
maprows (f:: Function , M:: AbstractMatrix , args... ) =
76
54
reduce (vcat, [ surerow (f (col, args... )) for col in eachrow (M) ])
77
55
78
56
surerow (x) = transpose (surevec (x))
79
57
80
58
maprows (f:: Function , M:: TrackedMatrix , args... ) = track (maprows, f, M, args... )
81
59
82
- @grad function maprows (f:: Function , M:: AbstractMatrix , args... )
83
- res = [ Tracker. forward (x -> surerow (f (x, args... )), row) for row in eachrow (data (M)) ]
84
- fwd = reduce (vcat, data .(first .(res)))
85
- function back (Δ)
86
- rows = [ data ((last (res[r]))(Δrow)[1 ]) for (r, Δrow) in enumerate (eachrow (data (Δ))) ]
87
- ∇M = reduce (vcat, rows)
88
- (nothing , ∇M, map (_-> nothing , args)... )
60
+ @grad maprows (f:: Function , M:: AbstractMatrix , args... ) =
61
+ ∇maprows ([ Tracker. forward (x -> surerow (f (x, args... )), row) for row in eachrow (data (M)) ], args)
62
+
63
+ @adjoint maprows (f:: Function , M:: AbstractMatrix , args... ) =
64
+ ∇maprows ([ Zygote. forward (x -> surerow (f (x, args... )), row) for row in eachrow (M) ], args)
65
+
66
+ function ∇maprows (forwards, args)
67
+ reduce (vcat, data .(first .(forwards))), Δ -> begin
68
+ rows = [ data (last (fwd)(Δrow)[1 ]) for (fwd, Δrow) in zip (forwards, eachrow (data (Δ))) ]
69
+ (nothing , reduce (vcat, rows), map (_-> nothing , args)... )
89
70
end
90
- fwd, back
91
71
end
92
72
93
73
@@ -100,16 +80,12 @@ struct MapCols{d} end
100
80
"""
101
81
MapCols{d}(f, m::Matrix, args...)
102
82
103
- Expects `f(::SVector{d}, args...)` and maps this over the columns, `d = size(M,1)`.
104
- Doesn't expect `f` to return a staticarray, just an array.
105
-
106
- When `m::TrackedMatrix`, it uses `ForwardDiff` to calculate the gradient of each slice.
107
- The second point of keeping one type parameter is that the dual numbers needed depend on this.
83
+ Similar to `mapcols(f, m, args...)`, but slices `m` into `SVector{d}` columns.
84
+ Their length `d = size(m,1)` should ideally be provided for type-stability, but is not required.
108
85
109
- MapCols{d}(f, m::Weighted, args...)
110
- Takes `m.weights` along for the ride.
86
+ The gradient for Tracker and Zygote uses `ForwardDiff` on each slice.
111
87
"""
112
- MapCols (f:: Function , M:: WeightedArrays.MaybeWeightedMatrix , args... ) =
88
+ MapCols (f:: Function , M:: AT , args... ) where {AT <: WeightedArrays.MaybeWeightedMatrix } =
113
89
MapCols {size(M,1)} (f, M, args... )
114
90
115
91
MapCols {d} (f:: Function , M:: WeightedMatrix , args... ) where {d} =
@@ -119,14 +95,15 @@ MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} = _MapCols(f, M, V
119
95
120
96
function _MapCols (f:: Function , M:: Matrix{T} , :: Val{d} , args... ) where {T,d}
121
97
d == size (M,1 ) || error (" expected M with $d columns" )
122
- # @cast A[c]{r:d} := M[r,c] assert
123
98
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (M))
124
99
B = map (col -> surevec (f (col, args... )), A)
125
100
reduce (hcat, B)
126
101
# maybestaticgluecols(B)
127
102
end
128
103
129
104
# surevec(x::MArray) = Array(x) # avoid making a huge MArray, ad
105
+ # surevecS(x::Number) = @SVector [x]
106
+ # surevecS(A) = vec(A) # like surevec
130
107
131
108
function maybestaticgluecols (B)
132
109
TB = eltype (B)
@@ -139,50 +116,33 @@ function maybestaticgluecols(B)
139
116
end
140
117
end
141
118
142
- # surevecS(x::Number) = @SVector [x]
143
- # surevecS(A) = vec(A) # like surevec
144
-
145
119
_MapCols (f:: Function , M:: TrackedMatrix , dval, args... ) = track (_MapCols, f, M, dval, args... )
146
120
147
121
@grad _MapCols (f:: Function , M:: TrackedMatrix , dval, args... ) = ∇MapCols (f, M, dval, args... )
148
122
149
123
@adjoint _MapCols (f:: Function , M:: Matrix , dval, args... ) = ∇MapCols (f, M, dval, args... )
150
124
151
125
function ∇MapCols (f:: Function , M:: AbstractMatrix{T} , dval:: Val{d} , args... ) where {T,d}
152
-
153
126
d == size (M,1 ) || error (" expected M with $d columns" )
154
- # @cast A[c]{r:d} := data(M)[r,c]
155
127
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (data (M)))
156
128
157
129
dualcol = SVector (ntuple (j-> ForwardDiff. Dual (0 , ntuple (i-> i== j ? 1 : 0 , dval)... ), dval))
158
-
159
- # C = [ surevec(f(col .+ dualcol, args...)) for col in A ]
160
130
C = map (col -> surevec (f (col .+ dualcol, args... )), A)
161
131
162
- # Z = reduce(hcat, [ ForwardDiff.value.(full) for full in C ])
163
132
Z = reduce (hcat, map (col -> ForwardDiff. value .(col), C))
164
133
165
134
function back (ΔZ)
166
- # accum = zero(eltype(data(ΔZ)))
167
- # ∇M = similar(data(M)) .+ zero(first(data(ΔZ)))
168
135
∇M = zeros (eltype (data (ΔZ)), size (M))
169
136
@inbounds for c= 1 : size (M,2 )
170
137
part = ForwardDiff. partials .(C[c])
171
138
for r= 1 : d
172
- # ∇M[r,c] = 0
173
- # accum = 0
174
139
for i= 1 : size (ΔZ,1 )
175
140
∇M[r,c] += data (ΔZ)[i,c] * part[i]. values[r]
176
- # parti = ForwardDiff.partials(C[c][i])
177
- # ∇M[r,c] += data(ΔZ)[i,c] * parti.values[r]
178
- # accum += data(ΔZ)[i,c] * part[i].values[r]
179
141
end
180
- # ∇M[r,c] = accum
181
142
end
182
143
end
183
144
(nothing , ∇M, nothing , map (_-> nothing , args)... )
184
145
end
185
-
186
146
Z, back
187
147
end
188
148
298
258
TensorCast. red_glue (A, code), Δ -> (TensorCast. sliceview (Δ, code), nothing )
299
259
end
300
260
261
+ @adjoint function TensorCast. copy_glue (A:: AbstractArray , code:: Tuple )
262
+ TensorCast. copy_glue (A, code), Δ -> (TensorCast. sliceview (Δ, code), nothing )
263
+ end
264
+
265
+ """
266
+ slicemap(f, A; dims) ≈ mapslices(f, A; dims)
267
+
268
+ Like `mapcols()`, but for any slice. Gradient is for Zygote only.
269
+ """
270
+ function slicemap (f:: Function , A:: AbstractArray{T,N} , args... ; dims) where {T,N}
271
+ code = ntuple (d -> d in dims ? (:) : (* ), N)
272
+ B = TensorCast. sliceview (A, code)
273
+ C = [ f (slice, args... ) for slice in B ]
274
+ TensorCast. glue (C, code)
275
+ end
301
276
302
277
end # module
0 commit comments