Skip to content

Commit 6926925

Browse files
author
Michael Abbott
committed
threads
1 parent 477c388 commit 6926925

File tree

3 files changed

+110
-16
lines changed

3 files changed

+110
-16
lines changed

README.md

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,21 @@ with gradients for [Flux](https://github.com/FluxML/Flux.jl) and [Zygote](https:
77

88
```julia
99
mapcols(f, M) mapreduce(f, hcat, eachcol(M))
10-
MapCols{d}(f, M) # where d=size(M,1), for StaticArrays
10+
MapCols{d}(f, M) # where d=size(M,1), for SVector slices
11+
ThreadMapCols{d}(f, M) # using Threads.@threads
1112

12-
maprows(f, M) mapreduce(f, vcat, eachrow(M))
13+
maprows(f, M) mapslices(f, M, dims=2)
1314

14-
slicemap(f, A; dims) mapslices(f, A, dims)
15+
slicemap(f, A; dims) mapslices(f, A, dims=dims) # only Zygote
1516
```
1617

18+
<!--
19+
It also defines Zygote gradients for the Slice/Align functions in
20+
[JuliennedArrays](https://github.com/bramtayl/JuliennedArrays.jl),
21+
and the slice/glue functions in [TensorCast](https://github.com/mcabbott/TensorCast.jl),
22+
both of which are good ways to roll-your-own `mapslices`-like behaviour.
23+
-->
24+
1725
### Simple example
1826

1927
```julia
@@ -25,7 +33,7 @@ using SliceMap
2533
mapcols(fun, mat) # eachcol(m)
2634
MapCols{3}(fun, mat) # reinterpret(SArray,...)
2735

28-
using Tracker, Zygote, ForwardDiff
36+
using ForwardDiff, Tracker, Zygote
2937
ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)
3038

3139
Tracker.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1] # Tracker.forward per slice
@@ -88,20 +96,59 @@ Zygote.gradient(m -> sum(sin, jumap(fun, m)), mat)[1]
8896
@btime Zygote.gradient(m -> sum(sin, jumap(fun, m)), $mat1k); # 18.638 ms
8997
```
9098

99+
That's a 2-line gradient definition, so borrowing it may be easier than depending on this package.
100+
101+
The original purpose of `MapCols`, with ForwardDiff on slices, was that this works well when
102+
the function being mapped integrates some differential equation.
103+
104+
```julia
105+
using DifferentialEquations, ParameterizedFunctions
106+
ode = @ode_def begin
107+
du = ( - k2 * u )/(k1 + u) # an equation with 2 parameters
108+
end k1 k2
109+
110+
function g(k::AbstractVector{T}, times) where T
111+
u0 = T[ 1.0 ] # NB convert initial values to eltype(k)
112+
prob = ODEProblem(ode, u0, (0.0, 0.0+maximum(times)), k)
113+
Array(solve(prob, saveat=times))::Matrix{T}
114+
end
115+
116+
kay = rand(2,50);
117+
MapCols{2}(g, kay, 1:5) # 5 time steps, for each col of parameters
118+
119+
Tracker.gradient(k -> sum(sin, MapCols{2}(g, k, 1:5)), kay)[1]
120+
```
121+
122+
This is both quite efficient, and seems to go well with multi-threading:
123+
124+
```julia
125+
@btime MapCols{2}(g, $kay, 1:5) # 1.369 ms
126+
@btime ThreadMapCols{2}(g, $kay, 1:5) # 670.384 μs
127+
128+
@btime Tracker.gradient(k -> sum(sin, MapCols{2}(g, k, 1:5)), $kay)[1] # 2.438 ms
129+
@btime Tracker.gradient(k -> sum(sin, ThreadMapCols{2}(g, k, 1:5)), $kay)[1] # 1.229 ms
130+
131+
Threads.nthreads() == 4
132+
```
133+
91134
### Elsewhere
92135

93-
About mapslices:
136+
Issues about mapslices:
94137
* https://github.com/FluxML/Zygote.jl/issues/92
95138
* https://github.com/FluxML/Flux.jl/issues/741
96139
* https://github.com/JuliaLang/julia/issues/29146
97140

141+
Differential equations:
142+
* https://arxiv.org/abs/1812.01892 "DSAAD"
143+
* http://docs.juliadiffeq.org/latest/analysis/sensitivity.html
144+
98145
Other packages which define gradients of possible interest:
99146
* https://github.com/GiggleLiu/LinalgBackwards.jl
100147
* https://github.com/mcabbott/ArrayAllez.jl
101148

102-
AD packages this could perhaps support, quite the zoo:
103-
* https://github.com/invenia/Nabla.jl
149+
Differentiation packages this could perhaps support, quite the zoo:
104150
* https://github.com/dfdx/Yota.jl
151+
* https://github.com/invenia/Nabla.jl
105152
* https://github.com/denizyuret/AutoGrad.jl
106153
* https://github.com/Roger-luo/YAAD.jl
107154
* And perhaps one day, just https://github.com/JuliaDiff/ChainRules.jl

src/SliceMap.jl

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
module SliceMap
33

4-
export mapcols, MapCols, maprows, slicemap
4+
export mapcols, MapCols, maprows, slicemap, ThreadMapCols
55

66
using MacroTools, Requires, WeightedArrays, TensorCast, JuliennedArrays
77

@@ -98,25 +98,29 @@ MapCols(f::Function, M::AT, args...) where {AT<:WeightedArrays.MaybeWeightedMatr
9898
MapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
9999
Weighted(MapCols{d}(f, M.array, args...), M.weights, M.opt)
100100

101-
MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} = _MapCols(f, M, Val(d), args...)
101+
MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
102+
_MapCols(f, M, Val(d), Val(false), args...)
102103

103-
function _MapCols(f::Function, M::Matrix{T}, ::Val{d}, args...) where {T,d}
104+
function _MapCols(f::Function, M::Matrix{T}, ::Val{d}, tval::Val, args...) where {T,d}
104105
d == size(M,1) || error("expected M with $d columns")
105106
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(M))
106-
B = map(col -> surevec(f(col, args...)), A)
107+
B = maybethreadmap(col -> surevec(f(col, args...)), A, tval)
107108
reduce(hcat, B)
108109
end
109110

110-
_MapCols(f::Function, M::TrackedMatrix, dval, args...) = track(_MapCols, f, M, dval, args...)
111+
_MapCols(f::Function, M::TrackedMatrix, dval, tval, args...) =
112+
track(_MapCols, f, M, dval, tval, args...)
111113

112-
@grad _MapCols(f::Function, M::TrackedMatrix, dval, args...) = ∇MapCols(f, M, dval, args...)
114+
@grad _MapCols(f::Function, M::TrackedMatrix, dval, tval, args...) =
115+
∇MapCols(f, M, dval, tval, args...)
116+
117+
function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, tval::Val, args...) where {T,d}
113118

114-
function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) where {T,d}
115119
d == size(M,1) || error("expected M with $d columns")
116120
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(data(M)))
117121

118122
dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, dval)...), dval))
119-
C = map(col -> surevec(f(col + dualcol, args...)), A)
123+
C = maybethreadmap(col -> surevec(f(col + dualcol, args...)), A, tval)
120124

121125
Z = reduce(hcat, map(col -> ForwardDiff.value.(col), C))
122126

@@ -130,7 +134,7 @@ function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) wh
130134
end
131135
end
132136
end
133-
(nothing, ∇M, nothing, map(_->nothing, args)...)
137+
(nothing, ∇M, nothing, nothing, map(_->nothing, args)...)
134138
end
135139
Z, back
136140
end
@@ -210,5 +214,45 @@ end
210214
# Following a suggestion? Doesn't help.
211215
# @adjoint Base.collect(x) = collect(x), Δ -> (Δ,)
212216

217+
#========== Threaded Map ==========#
218+
219+
# What KissThreading does is much more complicated, perhaps worth investigating:
220+
# https://github.com/mohamed82008/KissThreading.jl/blob/master/src/KissThreading.jl
221+
222+
function threadmap(f::Function, v::AbstractVector)
223+
length(v)==0 && error("can't map over empty vector, sorry")
224+
out1 = f(first(v))
225+
_threadmap(out1, f, v)
226+
end
227+
# NB barrier
228+
function _threadmap(out1, f, v)
229+
out = Vector{typeof(out1)}(undef, length(v))
230+
out[1] = out1
231+
Threads.@threads for i=2:length(v)
232+
@inbounds out[i] = f(v[i])
233+
end
234+
out
235+
end
236+
237+
# This switch is fast inside ∇MapCols, after many attempts!
238+
maybethreadmap(f, v, ::Val{true}) = threadmap(f, v)
239+
maybethreadmap(f, v, ::Val{false}) = map(f, v)
240+
241+
struct ThreadMapCols{d} end
242+
243+
"""
244+
ThreadMapCols{d}(f, m::Matrix, args...)
245+
246+
Like `MapCols` but with multi-threading!
247+
"""
248+
ThreadMapCols(f::Function, M::AT, args...) where {AT<:WeightedArrays.MaybeWeightedMatrix} =
249+
ThreadMapCols{size(M,1)}(f, M, args...)
250+
251+
ThreadMapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
252+
Weighted(ThreadMapCols{d}(f, M.array, args...), M.weights, M.opt)
253+
254+
ThreadMapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
255+
_MapCols(f, M, Val(d), Val(true), args...)
256+
213257

214258
end # module

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@ Zygote.refresh()
1313

1414
@test res mapcols(fun, mat)
1515
@test res MapCols{3}(fun, mat)
16+
@test res ThreadMapCols{3}(fun, mat)
1617

1718
grad = ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), mat)
1819

1920
@test grad Tracker.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
2021
@test grad Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
22+
@test grad Tracker.gradient(m -> sum(sin, ThreadMapCols{3}(fun, m)), mat)[1]
2123

2224
@test grad Zygote.gradient(m -> sum(sin, mapcols(fun, m)), mat)[1]
2325
@test grad Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
26+
@test grad Zygote.gradient(m -> sum(sin, ThreadMapCols{3}(fun, m)), mat)[1]
2427

2528
tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i]
2629
@test res tcm(mat)

0 commit comments

Comments
 (0)