Skip to content

Commit 662ce33

Browse files
author
Michael Abbott
committed
rm Weighted
1 parent ce17dd5 commit 662ce33

File tree

2 files changed

+5
-15
lines changed

2 files changed

+5
-15
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1111
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1212
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"
1313
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
14-
WeightedArrays = "379a43df-f81c-573e-83a6-069eb6c11a71"
1514

1615
[compat]
1716
julia = "1"

src/SliceMap.jl

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module SliceMap
33

44
export mapcols, MapCols, maprows, slicemap, tmapcols, ThreadMapCols
55

6-
using MacroTools, Requires, WeightedArrays, TensorCast, JuliennedArrays
6+
using MacroTools, Requires, TensorCast, JuliennedArrays
77

88
using Tracker
99
using Tracker: TrackedMatrix, track, @grad, data
@@ -25,9 +25,6 @@ They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
2525
mapcols(f::Function, M, args...) = _mapcols(map, f, M, args...)
2626
tmapcols(f::Function, M, args...) = _mapcols(threadmap, f, M, args...)
2727

28-
_mapcols(map::Function, f::Function, M::WeightedMatrix, args...) =
29-
Weighted(_mapcols(map, f, M.array, args...), M.weights, M.opt)
30-
3128
_mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
3229
reduce(hcat, map(col -> surevec(f(col, args...)), eachcol(M)))
3330

@@ -83,7 +80,7 @@ end
8380

8481
#========== Forward, Static ==========#
8582

86-
using StaticArrays, ForwardDiff, WeightedArrays
83+
using StaticArrays, ForwardDiff
8784

8885
struct MapCols{d} end
8986

@@ -95,12 +92,9 @@ Their length `d = size(M,1)` should ideally be provided for type-stability, but
9592
9693
The gradient for Tracker and Zygote uses `ForwardDiff` on each slice.
9794
"""
98-
MapCols(f::Function, M::AT, args...) where {AT<:WeightedArrays.MaybeWeightedMatrix} =
95+
MapCols(f::Function, M::AbstractMatrix, args...) =
9996
MapCols{size(M,1)}(f, M, args...)
10097

101-
MapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
102-
Weighted(MapCols{d}(f, M.array, args...), M.weights, M.opt)
103-
10498
MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
10599
_MapCols(map, f, M, Val(d), args...)
106100

@@ -220,7 +214,7 @@ end
220214
# What KissThreading does is much more complicated, perhaps worth investigating:
221215
# https://github.com/mohamed82008/KissThreading.jl/blob/master/src/KissThreading.jl
222216

223-
# BTW I do the first one because some diffeq maps are infer to ::Any
217+
# BTW I do the first one because some diffeq maps infer to ::Any,
224218
# else you could use Core.Compiler.return_type(f, Tuple{eltype(x)})
225219

226220
"""
@@ -260,12 +254,9 @@ struct ThreadMapCols{d} end
260254
261255
Like `MapCols` but with multi-threading!
262256
"""
263-
ThreadMapCols(f::Function, M::AT, args...) where {AT<:WeightedArrays.MaybeWeightedMatrix} =
257+
ThreadMapCols(f::Function, M::AbstractMatrix, args...) =
264258
ThreadMapCols{size(M,1)}(f, M, args...)
265259

266-
ThreadMapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
267-
Weighted(ThreadMapCols{d}(f, M.array, args...), M.weights, M.opt)
268-
269260
ThreadMapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
270261
_MapCols(threadmap, f, M, Val(d), args...)
271262

0 commit comments

Comments
 (0)