Commit b634651

tweaks

Michael Abbott committed
1 parent 662ce33 commit b634651

2 files changed: +29 -24 lines changed
docs/intro.md

Lines changed: 24 additions & 22 deletions
@@ -30,20 +30,22 @@ which is what `mapcols` does, has some overhead:
 using BenchmarkTools
 mat1k = rand(3,1000);

-@btime mapreduce(fun, hcat, eachcol($mat1k)) # 1.522 ms
-@btime mapslices(fun, $mat1k, dims=1) # 1.017 ms
-
-@btime mapcols(fun, $mat1k) # 399.016 μs
-@btime MapCols{3}(fun, $mat1k) # 15.564 μs
-@btime MapCols(fun, $mat1k) # 16.774 μs without size
-
-@btime ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), $mat1k); # 372.705 ms
-@btime Tracker.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 70.203 ms
-@btime Tracker.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 146.561 μs, 330.51 KiB
-@btime Zygote.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 20.018 ms, 3.82 MiB
-@btime Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), $mat1k); # 245.550 μs
+@btime mapreduce(fun, hcat, eachcol($mat1k)) # 1.522 ms, 11.80 MiB
+@btime mapslices(fun, $mat1k, dims=1) # 1.017 ms, 329.92 KiB
+
+@btime mapcols(fun, $mat1k) # 399.016 μs, 219.02 KiB
+@btime MapCols{3}(fun, $mat1k) # 15.564 μs, 47.16 KiB
+@btime MapCols(fun, $mat1k) # 16.774 μs (without slice size)
+
+@btime ForwardDiff.gradient(m -> sum(mapslices(fun, m, dims=1)), $mat1k); # 329.305 ms
+@btime Tracker.gradient(m -> sum(mapcols(fun, m)), $mat1k); # 70.203 ms
+@btime Tracker.gradient(m -> sum(MapCols{3}(fun, m)), $mat1k); # 51.129 μs, 282.92 KiB
+@btime Zygote.gradient(m -> sum(mapcols(fun, m)), $mat1k); # 20.454 ms, 3.52 MiB
+@btime Zygote.gradient(m -> sum(MapCols{3}(fun, m)), $mat1k); # 28.229 μs, 164.63 KiB
 ```

+For such a simple function, timing `sum(sin, MapCols{3}(fun, m))` takes 3 to 10 times longer!
+
 ## Other packages

 This package also provides Zygote gradients for the slice/glue functions in
@@ -53,13 +55,13 @@ which can be used to write many mapslices-like operations.

 ```julia
 using TensorCast
-@cast [i,j] := fun(mat[:,j])[i] # same as mapcols
+@cast [i,j] := fun(mat[:,j])[i] # same as mapcols

 tcm(mat) = @cast out[i,j] := fun(mat[:,j])[i]
 Zygote.gradient(m -> sum(sin, tcm(m)), mat)[1]

-@btime tcm($mat1k) # 407.176 μs
-@btime Zygote.gradient(m -> sum(sin, tcm(m)), $mat1k); # 19.086 ms
+@btime tcm($mat1k) # 427.907 μs
+@btime Zygote.gradient(m -> sum(tcm(m)), $mat1k); # 18.358 ms
 ```

 Similar gradients work for the Slice/Align functions in
@@ -69,11 +71,11 @@ so it defines these too:
 ```julia
 using JuliennedArrays
 jumap(f,m) = Align(map(f, Slices(m, True(), False())), True(), False())
-jumap(fun, mat) # same as mapcols
+jumap(fun, mat) # same as mapcols
 Zygote.gradient(m -> sum(sin, jumap(fun, m)), mat)[1]

-@btime jumap(fun, $mat1k); # 408.259 μs
-@btime Zygote.gradient(m -> sum(sin, jumap(fun, m)), $mat1k); # 18.638 ms
+@btime jumap(fun, $mat1k); # 421.061 μs
+@btime Zygote.gradient(m -> sum(jumap(fun, m)), $mat1k); # 18.383 ms
 ```

 That's a 2-line gradient definition, so borrowing it may be easier than depending on this package.
@@ -102,11 +104,11 @@ Tracker.gradient(k -> sum(sin, MapCols{2}(g, k, 1:5)), kay)[1]
 This is quite efficient, and seems to go well with multi-threading:

 ```julia
-@btime MapCols{2}(g, $kay, 1:5) # 1.423 ms
-@btime ThreadMapCols{2}(g, $kay, 1:5) # 713.748 μs
+@btime MapCols{2}(g, $kay, 1:5) # 1.394 ms
+@btime ThreadMapCols{2}(g, $kay, 1:5) # 697.333 μs

-@btime Tracker.gradient(k -> sum(sin, MapCols{2}(g, k, 1:5)), $kay)[1] # 2.535 ms
-@btime Tracker.gradient(k -> sum(sin, ThreadMapCols{2}(g, k, 1:5)), $kay)[1] # 1.333 ms
+@btime Tracker.gradient(k -> sum(MapCols{2}(g, k, 1:5)), $kay)[1] # 2.561 ms
+@btime Tracker.gradient(k -> sum(ThreadMapCols{2}(g, k, 1:5)), $kay)[1] # 1.344 ms

 Threads.nthreads() == 4 # on my 2/4-core laptop
 ```

src/SliceMap.jl

Lines changed: 5 additions & 2 deletions
@@ -113,6 +113,8 @@ _MapCols(map::Function, f::Function, M::TrackedMatrix, dval, args...) =

 function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) where {T,d}
     d == size(M,1) || error("expected M with $d columns")
+    k = size(M,2)
+
     A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(data(M)))

     dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, dval)...), dval))
@@ -121,8 +123,9 @@ function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::V
     Z = reduce(hcat, map(col -> ForwardDiff.value.(col), C))

     function back(ΔZ)
-        ∇M = zeros(eltype(data(ΔZ)), size(M))
-        @inbounds for c=1:size(M,2)
+        S = promote_type(T, eltype(data(ΔZ)))
+        ∇M = zeros(S, size(M))
+        @inbounds for c=1:k
             part = ForwardDiff.partials.(C[c])
             for r=1:d
                 for i=1:size(ΔZ,1)
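
The `promote_type` change in `back` can be looked at in isolation. The commit message ("tweaks") does not state the motivation, so what follows is only one plausible reading: a buffer typed after `ΔZ` alone may be too narrow to accumulate partials whose element type comes from `M`. The sketch below is a standalone illustration with made-up sizes and types; it does not call `data` or anything else from the package, and the variable names `old_buf`, `new_buf` and `old_failed` are hypothetical.

```julia
# Hypothetical stand-ins for the types in the diff: T plays the role of the
# element type of M, and ΔZ is an incoming cotangent that happens to hold Ints.
T = Float64
ΔZ = ones(Int, 2, 3)

# Old allocation: the buffer takes its element type from ΔZ alone.
old_buf = zeros(eltype(ΔZ), 2, 3)      # Matrix{Int}

# Accumulating a non-integer value into an Int buffer throws InexactError.
old_failed = try
    old_buf[1, 1] += 0.5 * ΔZ[1, 1]
    false
catch err
    err isa InexactError
end

# New allocation: promote with T so the buffer can hold values of either type.
S = promote_type(T, eltype(ΔZ))        # Float64
new_buf = zeros(S, 2, 3)
new_buf[1, 1] += 0.5 * ΔZ[1, 1]        # succeeds
```

Under these assumptions the promoted buffer accepts the accumulation that the narrower one rejects. Whether this exact failure is what prompted the change is not stated in the commit.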
