1
1
2
2
module SliceMap
3
3
4
- export mapcols, MapCols, maprows, slicemap, ThreadMapCols
4
+ export mapcols, MapCols, maprows, slicemap, tmapcols, ThreadMapCols
5
5
6
6
using MacroTools, Requires, WeightedArrays, TensorCast, JuliennedArrays
7
7
@@ -22,24 +22,27 @@ Any arguments after the matrix are passed to `f` as scalars, i.e.
22
22
`mapcols(f, m, args...) = reduce(hcat, f(col, args...) for col in eeachcol(m))`.
23
23
They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
24
24
"""
25
- mapcols (f:: Function , M:: AbstractMatrix , args... ) =
26
- reduce (hcat, [ surevec ( f (col , args... )) for col in eachcol (M) ] )
25
+ mapcols (f:: Function , M, args... ) = _mapcols (map, f, M, args ... )
26
+ tmapcols (f :: Function , M , args... ) = _mapcols (threadmap, f, M, args ... )
27
27
28
- mapcols (f:: Function , M:: WeightedMatrix , args... ) =
29
- Weighted (mapcols (f, M. array, args... ), M. weights, M. opt)
28
+ _mapcols (map:: Function , f:: Function , M:: WeightedMatrix , args... ) =
29
+ Weighted (_mapcols (map, f, M. array, args... ), M. weights, M. opt)
30
+
31
+ _mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
32
+ reduce (hcat, map (col -> surevec (f (col, args... )), eachcol (M)))
30
33
31
34
surevec (x:: Number ) = [x] # to allow f vector -> scalar, as mapslices does
32
35
surevec (A) = vec (A) # to allow f vector -> matrix, by reshaping
33
36
34
- mapcols ( f:: Function , M:: TrackedMatrix , args... ) = track (mapcols , f, M, args... )
37
+ _mapcols (map :: Function , f:: Function , M:: TrackedMatrix , args... ) = track (_mapcols, map , f, M, args... )
35
38
36
- @grad mapcols ( f:: Function , M:: AbstractMatrix , args... ) =
37
- ∇mapcols ([ Tracker. forward (x -> surevec (f (x, args... )), col) for col in eachcol (data (M)) ] , args)
39
+ @grad _mapcols (map :: Function , f:: Function , M:: AbstractMatrix , args... ) =
40
+ ∇mapcols (map, map (col -> Tracker. forward (x -> surevec (f (x, args... )), col), eachcol (data (M))) , args)
38
41
39
- function ∇mapcols (forwards, args)
40
- reduce (hcat, data .( first .( forwards) )), Δ -> begin
41
- cols = [ data (last (fwd)(Δcol)[1 ]) for (fwd, Δcol) in zip ( forwards, eachcol (data (Δ))) ]
42
- (nothing , reduce (hcat, cols), map (_-> nothing , args)... )
42
+ function ∇mapcols (bigmap, forwards, args)
43
+ reduce (hcat, map (data ∘ first, forwards)), Δ -> begin
44
+ cols = bigmap ((fwd, Δcol) -> data (last (fwd)(Δcol)[1 ]), forwards, eachcol (data (Δ)))
45
+ (nothing , nothing , reduce (hcat, cols), map (_-> nothing , args)... )
43
46
end
44
47
end
45
48
49
52
Like `mapcols()` but for rows.
50
53
"""
51
54
maprows (f:: Function , M:: AbstractMatrix , args... ) =
52
- reduce (vcat, [ transpose (surevec (f (col, args... ))) for col in eachrow (M) ] )
55
+ reduce (vcat, map (col -> transpose (surevec (f (col, args... ))), eachrow (M)) )
53
56
54
57
maprows (f:: Function , M:: TrackedMatrix , args... ) = track (maprows, f, M, args... )
55
58
56
59
@grad maprows (f:: Function , M:: AbstractMatrix , args... ) =
57
- ∇maprows ([ Tracker. forward (x -> surevec (f (x, args... )), row) for row in eachrow (data (M)) ] , args)
60
+ ∇maprows (map (row -> Tracker. forward (x -> surevec (f (x, args... )), row), eachrow (data (M))) , args)
58
61
59
62
function ∇maprows (forwards, args)
60
63
reduce (vcat, map (transpose∘ data∘ first, forwards)), Δ -> begin
61
- rows = [ data (last (fwd)(Δrow)[1 ]) for (fwd, Δrow) in zip ( forwards, eachrow (data (Δ))) ]
64
+ rows = map ((fwd, Δrow) -> data (last (fwd)(Δrow)[1 ]), forwards, eachrow (data (Δ)))
62
65
(nothing , reduce (vcat, transpose .(rows)), map (_-> nothing , args)... )
63
66
end
64
67
end
67
70
slicemap(f, A; dims) ≈ mapslices(f, A; dims)
68
71
69
72
Like `mapcols()`, but for any slice. The function `f` must preserve shape,
70
- e.g. `dims=(2,4)` then `f` must map matrices to matrices.
73
+ e.g. if `dims=(2,4)` then `f` must map matrices to matrices.
71
74
72
75
The gradient is for Zygote only.
73
76
"""
@@ -99,28 +102,27 @@ MapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
99
102
Weighted (MapCols {d} (f, M. array, args... ), M. weights, M. opt)
100
103
101
104
MapCols {d} (f:: Function , M:: AbstractMatrix , args... ) where {d} =
102
- _MapCols (f, M, Val (d), Val ( false ), args... )
105
+ _MapCols (map, f, M, Val (d), args... )
103
106
104
- function _MapCols (f:: Function , M:: Matrix{T} , :: Val{d} , tval :: Val , args... ) where {T,d}
107
+ function _MapCols (map :: Function , f:: Function , M:: Matrix{T} , :: Val{d} , args... ) where {T,d}
105
108
d == size (M,1 ) || error (" expected M with $d columns" )
106
109
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (M))
107
- B = maybethreadmap (col -> surevec (f (col, args... )), A, tval )
110
+ B = map (col -> surevec (f (col, args... )), A)
108
111
reduce (hcat, B)
109
112
end
110
113
111
- _MapCols (f:: Function , M:: TrackedMatrix , dval, tval, args... ) =
112
- track (_MapCols, f, M, dval, tval, args... )
113
-
114
- @grad _MapCols (f:: Function , M:: TrackedMatrix , dval, tval, args... ) =
115
- ∇MapCols (f, M, dval, tval, args... )
114
+ _MapCols (map:: Function , f:: Function , M:: TrackedMatrix , dval, args... ) =
115
+ track (_MapCols, map, f, M, dval, args... )
116
116
117
- function ∇MapCols (f:: Function , M:: AbstractMatrix{T} , dval:: Val{d} , tval:: Val , args... ) where {T,d}
117
+ @grad _MapCols (map:: Function , f:: Function , M:: TrackedMatrix , dval, args... ) =
118
+ ∇MapCols (map, f, M, dval, args... )
118
119
120
+ function ∇MapCols (bigmap:: Function , f:: Function , M:: AbstractMatrix{T} , dval:: Val{d} , args... ) where {T,d}
119
121
d == size (M,1 ) || error (" expected M with $d columns" )
120
122
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (data (M)))
121
123
122
124
dualcol = SVector (ntuple (j-> ForwardDiff. Dual (0 , ntuple (i-> i== j ? 1 : 0 , dval)... ), dval))
123
- C = maybethreadmap (col -> surevec (f (col + dualcol, args... )), A, tval )
125
+ C = bigmap (col -> surevec (f (col + dualcol, args... )), A)
124
126
125
127
Z = reduce (hcat, map (col -> ForwardDiff. value .(col), C))
126
128
@@ -134,15 +136,14 @@ function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, tval::Val,
134
136
end
135
137
end
136
138
end
137
- (nothing , ∇M, nothing , nothing , map (_-> nothing , args)... )
139
+ (nothing , nothing , ∇M , nothing , map (_-> nothing , args)... )
138
140
end
139
141
Z, back
140
142
end
141
143
142
144
#= ========= Gradients for Zygote ==========#
143
145
144
- # @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
145
- # end
146
+ # @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
146
147
147
148
@init @require Zygote = " e88e6eb3-aa80-5325-afca-941959d7151f" include (" zygote.jl" )
148
149
@@ -219,24 +220,35 @@ end
219
220
# What KissThreading does is much more complicated, perhaps worth investigating:
220
221
# https://github.com/mohamed82008/KissThreading.jl/blob/master/src/KissThreading.jl
221
222
222
- function threadmap (f:: Function , v:: AbstractVector )
223
- length (v)== 0 && error (" can't map over empty vector, sorry" )
224
- out1 = f (first (v))
225
- _threadmap (out1, f, v)
223
+ """
224
+ threadmap(f, A)
225
+ threadmap(f, A, B)
226
+
227
+ Simple version of `map` using a `Threads.@threads` loop;
228
+ only for vectors & only two of them, of nonzero length,
229
+ with all outputs having the same type.
230
+ """
231
+ function threadmap (f:: Function , vw:: AbstractVector... )
232
+ length (first (vw))== 0 && error (" can't map over empty vector, sorry" )
233
+ length (vw)== 2 && (isequal (length .(vw)... ) || error (" lengths must be equal" ))
234
+ out1 = f (first .(vw)... )
235
+ _threadmap (out1, f, vw... )
226
236
end
227
237
# NB barrier
228
- function _threadmap (out1, f, v )
229
- out = Vector {typeof(out1)} (undef, length (v ))
238
+ function _threadmap (out1, f, vw ... )
239
+ out = Vector {typeof(out1)} (undef, length (first (vw) ))
230
240
out[1 ] = out1
231
- Threads. @threads for i= 2 : length (v )
232
- @inbounds out[i] = f (v[i] )
241
+ Threads. @threads for i= 2 : length (first (vw) )
242
+ @inbounds out[i] = f (getindex .(vw, i) ... )
233
243
end
234
244
out
235
245
end
236
246
237
- # This switch is fast inside ∇MapCols, after many attempts!
238
- maybethreadmap (f, v, :: Val{true} ) = threadmap (f, v)
239
- maybethreadmap (f, v, :: Val{false} ) = map (f, v)
247
+ # Collect generators to allow indexing
248
+ threadmap (f:: Function , v) = threadmap (f, collect (v))
249
+ threadmap (f:: Function , v, w) = threadmap (f, collect (v), collect (w))
250
+ threadmap (f:: Function , v, w:: AbstractVector ) = threadmap (f, collect (v), w)
251
+ threadmap (f:: Function , v:: AbstractVector , w) = threadmap (f, v, collect (w))
240
252
241
253
struct ThreadMapCols{d} end
242
254
@@ -252,7 +264,7 @@ ThreadMapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
252
264
Weighted (ThreadMapCols {d} (f, M. array, args... ), M. weights, M. opt)
253
265
254
266
ThreadMapCols {d} (f:: Function , M:: AbstractMatrix , args... ) where {d} =
255
- _MapCols (f, M, Val (d), Val ( true ), args... )
267
+ _MapCols (threadmap, f, M, Val (d), args... )
256
268
257
269
258
270
end # module
0 commit comments