@@ -3,7 +3,7 @@ module SliceMap
3
3
4
4
export mapcols, MapCols, maprows, slicemap, tmapcols, ThreadMapCols
5
5
6
- using MacroTools, Requires, WeightedArrays, TensorCast, JuliennedArrays
6
+ using MacroTools, Requires, TensorCast, JuliennedArrays
7
7
8
8
using Tracker
9
9
using Tracker: TrackedMatrix, track, @grad , data
@@ -25,9 +25,6 @@ They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
25
25
mapcols (f:: Function , M, args... ) = _mapcols (map, f, M, args... )
26
26
tmapcols (f:: Function , M, args... ) = _mapcols (threadmap, f, M, args... )
27
27
28
- _mapcols (map:: Function , f:: Function , M:: WeightedMatrix , args... ) =
29
- Weighted (_mapcols (map, f, M. array, args... ), M. weights, M. opt)
30
-
31
28
_mapcols (map:: Function , f:: Function , M:: AbstractMatrix , args... ) =
32
29
reduce (hcat, map (col -> surevec (f (col, args... )), eachcol (M)))
33
30
83
80
84
81
#= ========= Forward, Static ==========#
85
82
86
- using StaticArrays, ForwardDiff, WeightedArrays
83
+ using StaticArrays, ForwardDiff
87
84
88
85
struct MapCols{d} end
89
86
@@ -95,12 +92,9 @@ Their length `d = size(M,1)` should ideally be provided for type-stability, but
95
92
96
93
The gradient for Tracker and Zygote uses `ForwardDiff` on each slice.
97
94
"""
98
- MapCols (f:: Function , M:: AT , args... ) where {AT <: WeightedArrays.MaybeWeightedMatrix } =
95
+ MapCols (f:: Function , M:: AbstractMatrix , args... ) =
99
96
MapCols {size(M,1)} (f, M, args... )
100
97
101
- MapCols {d} (f:: Function , M:: WeightedMatrix , args... ) where {d} =
102
- Weighted (MapCols {d} (f, M. array, args... ), M. weights, M. opt)
103
-
104
98
MapCols {d} (f:: Function , M:: AbstractMatrix , args... ) where {d} =
105
99
_MapCols (map, f, M, Val (d), args... )
106
100
220
214
# What KissThreading does is much more complicated, perhaps worth investigating:
221
215
# https://github.com/mohamed82008/KissThreading.jl/blob/master/src/KissThreading.jl
222
216
223
- # BTW I do the first one because some diffeq maps are infer to ::Any
217
+ # BTW I do the first one because some diffeq maps infer to ::Any,
224
218
# else you could use Core.Compiler.return_type(f, Tuple{eltype(x)})
225
219
226
220
"""
@@ -260,12 +254,9 @@ struct ThreadMapCols{d} end
260
254
261
255
Like `MapCols` but with multi-threading!
262
256
"""
263
- ThreadMapCols (f:: Function , M:: AT , args... ) where {AT <: WeightedArrays.MaybeWeightedMatrix } =
257
+ ThreadMapCols (f:: Function , M:: AbstractMatrix , args... ) =
264
258
ThreadMapCols {size(M,1)} (f, M, args... )
265
259
266
- ThreadMapCols {d} (f:: Function , M:: WeightedMatrix , args... ) where {d} =
267
- Weighted (ThreadMapCols {d} (f, M. array, args... ), M. weights, M. opt)
268
-
269
260
ThreadMapCols {d} (f:: Function , M:: AbstractMatrix , args... ) where {d} =
270
261
_MapCols (threadmap, f, M, Val (d), args... )
271
262
0 commit comments