1
1
2
2
module SliceMap
3
3
4
- export mapcols, MapCols, maprows, slicemap
4
+ export mapcols, MapCols, maprows, slicemap, ThreadMapCols
5
5
6
6
using MacroTools, Requires, WeightedArrays, TensorCast, JuliennedArrays
7
7
@@ -98,25 +98,29 @@ MapCols(f::Function, M::AT, args...) where {AT<:WeightedArrays.MaybeWeightedMatr
98
98
MapCols {d} (f:: Function , M:: WeightedMatrix , args... ) where {d} =
99
99
Weighted (MapCols {d} (f, M. array, args... ), M. weights, M. opt)
100
100
101
- MapCols {d} (f:: Function , M:: AbstractMatrix , args... ) where {d} = _MapCols (f, M, Val (d), args... )
101
+ MapCols {d} (f:: Function , M:: AbstractMatrix , args... ) where {d} =
102
+ _MapCols (f, M, Val (d), Val (false ), args... )
102
103
103
- function _MapCols (f:: Function , M:: Matrix{T} , :: Val{d} , args... ) where {T,d}
104
+ function _MapCols (f:: Function , M:: Matrix{T} , :: Val{d} , tval :: Val , args... ) where {T,d}
104
105
d == size (M,1 ) || error (" expected M with $d columns" )
105
106
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (M))
106
- B = map (col -> surevec (f (col, args... )), A)
107
+ B = maybethreadmap (col -> surevec (f (col, args... )), A, tval )
107
108
reduce (hcat, B)
108
109
end
109
110
110
- _MapCols (f:: Function , M:: TrackedMatrix , dval, args... ) = track (_MapCols, f, M, dval, args... )
111
+ _MapCols (f:: Function , M:: TrackedMatrix , dval, tval, args... ) =
112
+ track (_MapCols, f, M, dval, tval, args... )
111
113
112
- @grad _MapCols (f:: Function , M:: TrackedMatrix , dval, args... ) = ∇MapCols (f, M, dval, args... )
114
+ @grad _MapCols (f:: Function , M:: TrackedMatrix , dval, tval, args... ) =
115
+ ∇MapCols (f, M, dval, tval, args... )
116
+
117
+ function ∇MapCols (f:: Function , M:: AbstractMatrix{T} , dval:: Val{d} , tval:: Val , args... ) where {T,d}
113
118
114
- function ∇MapCols (f:: Function , M:: AbstractMatrix{T} , dval:: Val{d} , args... ) where {T,d}
115
119
d == size (M,1 ) || error (" expected M with $d columns" )
116
120
A = reinterpret (SArray{Tuple{d}, T, 1 , d}, vec (data (M)))
117
121
118
122
dualcol = SVector (ntuple (j-> ForwardDiff. Dual (0 , ntuple (i-> i== j ? 1 : 0 , dval)... ), dval))
119
- C = map (col -> surevec (f (col + dualcol, args... )), A)
123
+ C = maybethreadmap (col -> surevec (f (col + dualcol, args... )), A, tval )
120
124
121
125
Z = reduce (hcat, map (col -> ForwardDiff. value .(col), C))
122
126
@@ -130,7 +134,7 @@ function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) wh
130
134
end
131
135
end
132
136
end
133
- (nothing , ∇M, nothing , map (_-> nothing , args)... )
137
+ (nothing , ∇M, nothing , nothing , map (_-> nothing , args)... )
134
138
end
135
139
Z, back
136
140
end
210
214
# Following a suggestion? Doesn't help.
211
215
# @adjoint Base.collect(x) = collect(x), Δ -> (Δ,)
212
216
217
+ #= ========= Threaded Map ==========#
218
+
219
+ # What KissThreading does is much more complicated, perhaps worth investigating:
220
+ # https://github.com/mohamed82008/KissThreading.jl/blob/master/src/KissThreading.jl
221
+
222
+ function threadmap (f:: Function , v:: AbstractVector )
223
+ length (v)== 0 && error (" can't map over empty vector, sorry" )
224
+ out1 = f (first (v))
225
+ _threadmap (out1, f, v)
226
+ end
227
+ # NB barrier
228
+ function _threadmap (out1, f, v)
229
+ out = Vector {typeof(out1)} (undef, length (v))
230
+ out[1 ] = out1
231
+ Threads. @threads for i= 2 : length (v)
232
+ @inbounds out[i] = f (v[i])
233
+ end
234
+ out
235
+ end
236
+
237
+ # This switch is fast inside ∇MapCols, after many attempts!
238
+ maybethreadmap (f, v, :: Val{true} ) = threadmap (f, v)
239
+ maybethreadmap (f, v, :: Val{false} ) = map (f, v)
240
+
241
+ struct ThreadMapCols{d} end
242
+
243
+ """
244
+ ThreadMapCols{d}(f, m::Matrix, args...)
245
+
246
+ Like `MapCols` but with multi-threading!
247
+ """
248
+ ThreadMapCols (f:: Function , M:: AT , args... ) where {AT<: WeightedArrays.MaybeWeightedMatrix } =
249
+ ThreadMapCols {size(M,1)} (f, M, args... )
250
+
251
+ ThreadMapCols {d} (f:: Function , M:: WeightedMatrix , args... ) where {d} =
252
+ Weighted (ThreadMapCols {d} (f, M. array, args... ), M. weights, M. opt)
253
+
254
+ ThreadMapCols {d} (f:: Function , M:: AbstractMatrix , args... ) where {d} =
255
+ _MapCols (f, M, Val (d), Val (true ), args... )
256
+
213
257
214
258
end # module
0 commit comments