Skip to content

Commit d328e5e

Browse files
author
Michael Abbott
committed
maps + threads
1 parent 6926925 commit d328e5e

File tree

3 files changed

+76
-71
lines changed

3 files changed

+76
-71
lines changed

src/SliceMap.jl

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
module SliceMap
33

4-
export mapcols, MapCols, maprows, slicemap, ThreadMapCols
4+
export mapcols, MapCols, maprows, slicemap, tmapcols, ThreadMapCols
55

66
using MacroTools, Requires, WeightedArrays, TensorCast, JuliennedArrays
77

@@ -22,24 +22,27 @@ Any arguments after the matrix are passed to `f` as scalars, i.e.
2222
`mapcols(f, m, args...) = reduce(hcat, f(col, args...) for col in eachcol(m))`.
2323
They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
2424
"""
25-
mapcols(f::Function, M::AbstractMatrix, args...) =
26-
reduce(hcat, [ surevec(f(col, args...)) for col in eachcol(M) ])
25+
mapcols(f::Function, M, args...) = _mapcols(map, f, M, args...)
26+
tmapcols(f::Function, M, args...) = _mapcols(threadmap, f, M, args...)
2727

28-
mapcols(f::Function, M::WeightedMatrix, args...) =
29-
Weighted(mapcols(f, M.array, args...), M.weights, M.opt)
28+
_mapcols(map::Function, f::Function, M::WeightedMatrix, args...) =
29+
Weighted(_mapcols(map, f, M.array, args...), M.weights, M.opt)
30+
31+
_mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
32+
reduce(hcat, map(col -> surevec(f(col, args...)), eachcol(M)))
3033

3134
surevec(x::Number) = [x] # to allow f vector -> scalar, as mapslices does
3235
surevec(A) = vec(A) # to allow f vector -> matrix, by reshaping
3336

34-
mapcols(f::Function, M::TrackedMatrix, args...) = track(mapcols, f, M, args...)
37+
_mapcols(map::Function, f::Function, M::TrackedMatrix, args...) = track(_mapcols, map, f, M, args...)
3538

36-
@grad mapcols(f::Function, M::AbstractMatrix, args...) =
37-
∇mapcols([ Tracker.forward(x -> surevec(f(x, args...)), col) for col in eachcol(data(M)) ], args)
39+
@grad _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
40+
∇mapcols(map, map(col -> Tracker.forward(x -> surevec(f(x, args...)), col), eachcol(data(M))), args)
3841

39-
function ∇mapcols(forwards, args)
40-
reduce(hcat, data.(first.(forwards))), Δ -> begin
41-
cols = [ data(last(fwd)(Δcol)[1]) for (fwd, Δcol) in zip(forwards, eachcol(data(Δ))) ]
42-
(nothing, reduce(hcat, cols), map(_->nothing, args)...)
42+
function ∇mapcols(bigmap, forwards, args)
43+
reduce(hcat, map(datafirst, forwards)), Δ -> begin
44+
cols = bigmap((fwd, Δcol) -> data(last(fwd)(Δcol)[1]), forwards, eachcol(data(Δ)))
45+
(nothing, nothing, reduce(hcat, cols), map(_->nothing, args)...)
4346
end
4447
end
4548

@@ -49,16 +52,16 @@ end
4952
Like `mapcols()` but for rows.
5053
"""
5154
maprows(f::Function, M::AbstractMatrix, args...) =
52-
reduce(vcat, [ transpose(surevec(f(col, args...))) for col in eachrow(M) ])
55+
reduce(vcat, map(col -> transpose(surevec(f(col, args...))), eachrow(M)))
5356

5457
maprows(f::Function, M::TrackedMatrix, args...) = track(maprows, f, M, args...)
5558

5659
@grad maprows(f::Function, M::AbstractMatrix, args...) =
57-
∇maprows([ Tracker.forward(x -> surevec(f(x, args...)), row) for row in eachrow(data(M)) ], args)
60+
∇maprows(map(row -> Tracker.forward(x -> surevec(f(x, args...)), row), eachrow(data(M))), args)
5861

5962
function ∇maprows(forwards, args)
6063
reduce(vcat, map(transposedatafirst, forwards)), Δ -> begin
61-
rows = [ data(last(fwd)(Δrow)[1]) for (fwd, Δrow) in zip(forwards, eachrow(data(Δ))) ]
64+
rows = map((fwd, Δrow) -> data(last(fwd)(Δrow)[1]), forwards, eachrow(data(Δ)))
6265
(nothing, reduce(vcat, transpose.(rows)), map(_->nothing, args)...)
6366
end
6467
end
@@ -67,7 +70,7 @@ end
6770
slicemap(f, A; dims) ≈ mapslices(f, A; dims)
6871
6972
Like `mapcols()`, but for any slice. The function `f` must preserve shape,
70-
e.g. `dims=(2,4)` then `f` must map matrices to matrices.
73+
e.g. if `dims=(2,4)` then `f` must map matrices to matrices.
7174
7275
The gradient is for Zygote only.
7376
"""
@@ -99,28 +102,27 @@ MapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
99102
Weighted(MapCols{d}(f, M.array, args...), M.weights, M.opt)
100103

101104
MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
102-
_MapCols(f, M, Val(d), Val(false), args...)
105+
_MapCols(map, f, M, Val(d), args...)
103106

104-
function _MapCols(f::Function, M::Matrix{T}, ::Val{d}, tval::Val, args...) where {T,d}
107+
function _MapCols(map::Function, f::Function, M::Matrix{T}, ::Val{d}, args...) where {T,d}
105108
d == size(M,1) || error("expected M with $d columns")
106109
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(M))
107-
B = maybethreadmap(col -> surevec(f(col, args...)), A, tval)
110+
B = map(col -> surevec(f(col, args...)), A)
108111
reduce(hcat, B)
109112
end
110113

111-
_MapCols(f::Function, M::TrackedMatrix, dval, tval, args...) =
112-
track(_MapCols, f, M, dval, tval, args...)
113-
114-
@grad _MapCols(f::Function, M::TrackedMatrix, dval, tval, args...) =
115-
∇MapCols(f, M, dval, tval, args...)
114+
_MapCols(map::Function, f::Function, M::TrackedMatrix, dval, args...) =
115+
track(_MapCols, map, f, M, dval, args...)
116116

117-
function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, tval::Val, args...) where {T,d}
117+
@grad _MapCols(map::Function, f::Function, M::TrackedMatrix, dval, args...) =
118+
∇MapCols(map, f, M, dval, args...)
118119

120+
function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) where {T,d}
119121
d == size(M,1) || error("expected M with $d columns")
120122
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(data(M)))
121123

122124
dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, dval)...), dval))
123-
C = maybethreadmap(col -> surevec(f(col + dualcol, args...)), A, tval)
125+
C = bigmap(col -> surevec(f(col + dualcol, args...)), A)
124126

125127
Z = reduce(hcat, map(col -> ForwardDiff.value.(col), C))
126128

@@ -134,15 +136,14 @@ function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, tval::Val,
134136
end
135137
end
136138
end
137-
(nothing, ∇M, nothing, nothing, map(_->nothing, args)...)
139+
(nothing, nothing, ∇M, nothing, map(_->nothing, args)...)
138140
end
139141
Z, back
140142
end
141143

142144
#========== Gradients for Zygote ==========#
143145

144-
# @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
145-
# end
146+
# @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
146147

147148
@init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("zygote.jl")
148149

@@ -219,24 +220,35 @@ end
219220
# What KissThreading does is much more complicated, perhaps worth investigating:
220221
# https://github.com/mohamed82008/KissThreading.jl/blob/master/src/KissThreading.jl
221222

222-
function threadmap(f::Function, v::AbstractVector)
223-
length(v)==0 && error("can't map over empty vector, sorry")
224-
out1 = f(first(v))
225-
_threadmap(out1, f, v)
223+
"""
224+
threadmap(f, A)
225+
threadmap(f, A, B)
226+
227+
Simple version of `map` using a `Threads.@threads` loop;
228+
only for vectors & at most two of them, of nonzero length,
229+
with all outputs having the same type.
230+
"""
231+
function threadmap(f::Function, vw::AbstractVector...)
232+
length(first(vw))==0 && error("can't map over empty vector, sorry")
233+
length(vw)==2 && (isequal(length.(vw)...) || error("lengths must be equal"))
234+
out1 = f(first.(vw)...)
235+
_threadmap(out1, f, vw...)
226236
end
227237
# NB function barrier: `_threadmap` receives `out1`, so `typeof(out1)` fixes the output eltype
228-
function _threadmap(out1, f, v)
229-
out = Vector{typeof(out1)}(undef, length(v))
238+
function _threadmap(out1, f, vw...)
239+
out = Vector{typeof(out1)}(undef, length(first(vw)))
230240
out[1] = out1
231-
Threads.@threads for i=2:length(v)
232-
@inbounds out[i] = f(v[i])
241+
Threads.@threads for i=2:length(first(vw))
242+
@inbounds out[i] = f(getindex.(vw, i)...)
233243
end
234244
out
235245
end
236246

237-
# This switch is fast inside ∇MapCols, after many attempts!
238-
maybethreadmap(f, v, ::Val{true}) = threadmap(f, v)
239-
maybethreadmap(f, v, ::Val{false}) = map(f, v)
247+
# Collect generators to allow indexing
248+
threadmap(f::Function, v) = threadmap(f, collect(v))
249+
threadmap(f::Function, v, w) = threadmap(f, collect(v), collect(w))
250+
threadmap(f::Function, v, w::AbstractVector) = threadmap(f, collect(v), w)
251+
threadmap(f::Function, v::AbstractVector, w) = threadmap(f, v, collect(w))
240252

241253
struct ThreadMapCols{d} end
242254

@@ -252,7 +264,7 @@ ThreadMapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
252264
Weighted(ThreadMapCols{d}(f, M.array, args...), M.weights, M.opt)
253265

254266
ThreadMapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
255-
_MapCols(f, M, Val(d), Val(true), args...)
267+
_MapCols(threadmap, f, M, Val(d), args...)
256268

257269

258270
end # module

src/zygote.jl

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,56 +4,40 @@ using .Zygote: @adjoint, _zero, forward
44

55
#===== mapcols, maprows, MapCols =====#
66

7-
@adjoint mapcols(f::Function, M::AbstractMatrix, args...) =
8-
∇mapcols([ forward(x -> surevec(f(x, args...)), col) for col in eachcol(M) ], args)
7+
@adjoint _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
8+
∇mapcols(map, map(col -> forward(x -> surevec(f(x, args...)), col), eachcol(M)), args)
99

1010
@adjoint maprows(f::Function, M::AbstractMatrix, args...) =
11-
∇maprows([ forward(x -> surevec(f(x, args...)), row) for row in eachrow(M) ], args)
11+
∇maprows(map(row -> forward(x -> surevec(f(x, args...)), row), eachrow(M)), args)
1212

13-
@adjoint _MapCols(f::Function, M::Matrix, dval, args...) = ∇MapCols(f, M, dval, args...)
13+
@adjoint _MapCols(map::Function, f::Function, M::Matrix, dval, args...) =
14+
∇MapCols(map, f, M, dval, args...)
1415

1516
#===== TensorCast =====#
1617

17-
# @adjoint function TensorCast.sliceview(A::AbstractArray, code::Tuple)
18-
# TensorCast.sliceview(A, code), Δ -> begin
19-
# dA = _zero(A)
20-
# foreach(copyto!, TensorCast.sliceview(dA, code), Δ)
21-
# (dA, nothing)
22-
# end
23-
# end
24-
25-
@adjoint function TensorCast.sliceview(A::AbstractArray, code::Tuple)
18+
@adjoint TensorCast.sliceview(A::AbstractArray, code::Tuple) =
2619
TensorCast.sliceview(A, code), Δ -> (TensorCast.glue(Δ, code), nothing)
27-
end
2820

29-
@adjoint function TensorCast.red_glue(A::AbstractArray, code::Tuple)
21+
@adjoint TensorCast.red_glue(A::AbstractArray, code::Tuple) =
3022
TensorCast.red_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
31-
end
3223

33-
@adjoint function TensorCast.copy_glue(A::AbstractArray, code::Tuple)
24+
@adjoint TensorCast.copy_glue(A::AbstractArray, code::Tuple) =
3425
TensorCast.copy_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
35-
end
3626

3727
#===== JuliennedArrays =====#
3828

39-
@adjoint function Slices(whole, along...)
29+
@adjoint JuliennedArrays.Slices(whole, along...) =
4030
Slices(whole, along...), Δ -> (Align(Δ, along...), map(_->nothing, along)...)
41-
end
4231

43-
@adjoint function Align(whole, along...)
32+
@adjoint JuliennedArrays.Align(whole, along...) =
4433
Align(whole, along...), Δ -> (Slices(Δ, along...), map(_->nothing, along)...)
45-
end
4634

4735
#===== Misc Base =====#
4836

49-
@adjoint function Base.reduce(::typeof(hcat), V::AbstractVector{<:AbstractVector})
37+
@adjoint Base.reduce(::typeof(hcat), V::AbstractVector{<:AbstractVector}) =
5038
reduce(hcat, V), dV -> (nothing, collect(eachcol(dV)),)
51-
end
52-
53-
# Zygote doesn't understand views, but easy to fix:
54-
# https://github.com/FluxML/Zygote.jl/issues/52
55-
# now https://github.com/FluxML/Zygote.jl/pull/219
5639

40+
# https://github.com/FluxML/Zygote.jl/pull/219
5741
@adjoint function view(x::AbstractArray, inds...; kwargs...)
5842
view(x, inds...; kwargs...), dy -> begin
5943
dx = _zero(x)
@@ -64,9 +48,8 @@ end
6448

6549
#===== Misc experiments =====#
6650

67-
@adjoint function gluecol(V::AbstractVector)
51+
@adjoint gluecol(V::AbstractVector) =
6852
gluecol(V), ΔM -> (collect(eachcol(ΔM)),) # does work!
69-
end
7053

7154
@adjoint function collecteachcol(x)
7255
collecteachcol(x), dy -> begin

test/runtests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,26 @@ Zygote.refresh()
1313

1414
@test res ≈ mapcols(fun, mat)
1515
@test res ≈ MapCols{3}(fun, mat)
16+
@test res ≈ MapCols(fun, mat)
17+
18+
@test res ≈ tmapcols(fun, mat)
1619
@test res ≈ ThreadMapCols{3}(fun, mat)
20+
@test res ≈ ThreadMapCols(fun, mat)
1721

1822
grad = ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)
1923

2024
@test grad ≈ Tracker.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
2125
@test grad ≈ Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
26+
@test grad ≈ Tracker.gradient(m -> sum(sin, MapCols(fun, m)), mat)[1]
27+
28+
@test grad ≈ Tracker.gradient(m -> sum(sin, tmapcols(fun, m)), mat)[1]
2229
@test grad ≈ Tracker.gradient(m -> sum(sin, ThreadMapCols{3}(fun, m)), mat)[1]
30+
@test grad ≈ Tracker.gradient(m -> sum(sin, ThreadMapCols(fun, m)), mat)[1]
2331

2432
@test grad ≈ Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
2533
@test grad ≈ Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
34+
35+
@test grad ≈ Zygote.gradient(m -> sum(sin, tmapcols(fun, m)), mat)[1]
2636
@test grad ≈ Zygote.gradient(m -> sum(sin, ThreadMapCols{3}(fun, m)), mat)[1]
2737

2838
tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i]

0 commit comments

Comments
 (0)