Skip to content

Commit abce623

Browse files
author
Michael Abbott
committed
tidy, add slicemap
1 parent 002092a commit abce623

File tree

2 files changed

+63
-86
lines changed

2 files changed

+63
-86
lines changed

README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ fun(x) = 2 .+ x.^2
99
mapslices(fun, mat, dims=1)
1010

1111
using SliceMap
12-
1312
mapcols(fun, mat) # eachcol(m)
1413
MapCols{3}(fun, mat) # reinterpret(SArray,...)
1514

@@ -27,12 +26,15 @@ Zygote.gradient(m -> sum(sin, MapCols{3}(fun, m)), mat)[1]
2726
These are a bit faster than `mapslices` too:
2827

2928
```julia
29+
using BenchmarkTools
3030
mat1k = rand(3,1000);
3131

32-
@btime mapslices(fun, $mat1k, dims=1) # 1.017 ms
33-
@btime mapcols(fun, $mat1k) # 399.016 μs
34-
@btime MapCols{3}(fun, $mat1k) # 15.564 μs
35-
@btime MapCols(fun, $mat1k) # 16.774 μs without size
32+
@btime mapreduce(fun, hcat, eachcol($mat1k)) # 1.522 ms
33+
@btime mapslices(fun, $mat1k, dims=1) # 1.017 ms
34+
35+
@btime mapcols(fun, $mat1k) # 399.016 μs
36+
@btime MapCols{3}(fun, $mat1k) # 15.564 μs
37+
@btime MapCols(fun, $mat1k) # 16.774 μs without size
3638

3739
@btime ForwardDiff.gradient(m -> sum(sin, mapslices(fun, m, dims=1)), $mat1k); # 372.705 ms
3840
@btime Tracker.gradient(m -> sum(sin, mapcols(fun, m)), $mat1k); # 70.203 ms

src/SliceMap.jl

Lines changed: 56 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,25 @@
11

22
module SliceMap
33

4-
export MapCols, mapcols, maprows
4+
export mapcols, MapCols, maprows, slicemap
55

6-
7-
#========== Gradient Macro ==========#
8-
9-
using MacroTools, Tracker, Zygote
6+
using MacroTools, Tracker, Zygote, WeightedArrays
107
using Tracker: TrackedMatrix, track, @grad, data
118
using Zygote: @adjoint, _zero
129

13-
macro gradadjoint(ex)
14-
quote
15-
# $(Zygote.gradm(ex)) # this doesn't work
16-
$(trackergrad(ex))
17-
end
18-
end
19-
20-
# Copied from https://github.com/FluxML/Tracker.jl/blob/master/src/Tracker.jl#L55
21-
function trackergrad(ex)
22-
@capture(shortdef(ex), (name_(args__) = body_) |
23-
(name_(args__) where {T__} = body_)) || error("Need a function definition")
24-
T == nothing && (T = [])
25-
isexpr(name, :(::)) || (name = :(::typeof($name)))
26-
insert!(args, 1+isexpr(args[1], :parameters) , name)
27-
MacroTools.@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
28-
end
29-
30-
3110
#========== Reverse, Eachslice ==========#
3211

33-
using WeightedArrays
34-
3512
"""
36-
mapcols(f, m::Matrix, args...) = reduce(hcat, f(c, args...) for c in eachcol(M))
13+
mapcols(f, m) ≈ mapreduce(f, hcat, eachcol(m)) ≈ mapslices(f, m, dims=1)
3714
38-
When `m::TrackedMatrix`, it saves the backward function for each slice.
39-
All further arguments are scalar constants, i.e. they do not get sliced/iterated (unlike `map`)
40-
nor are their gradients tracked.
15+
This is a more efficient version of the functions on the right.
16+
For `f(x::Vector)::Matrix` it reshapes like `mapslices(vec∘f, m, dims=1)`.
17+
18+
It provides a gradient for Tracker and Zygote, saving the backward function for each slice.
19+
20+
Any arguments after the matrix are passed to `f` as scalars, i.e.
21+
`mapcols(f, m, args...) = reduce(hcat, f(col, args...) for col in eachcol(m))`.
22+
They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
4123
"""
4224
mapcols(f::Function, M::AbstractMatrix, args...) =
4325
reduce(hcat, [ surevec(f(col, args...)) for col in eachcol(M) ])
@@ -50,44 +32,42 @@ surevec(A) = vec(A) # to allow f vector -> matrix, by reshaping
5032

5133
mapcols(f::Function, M::TrackedMatrix, args...) = track(mapcols, f, M, args...)
5234

53-
@grad function mapcols(f::Function, M::AbstractMatrix, args...)
54-
res = [ Tracker.forward(x -> surevec(f(x, args...)), col) for col in eachcol(data(M)) ]
55-
fwd = reduce(hcat, data.(first.(res)))
56-
function back(Δ)
57-
cols = [ data((last(res[c]))(Δcol)[1]) for (c, Δcol) in enumerate(eachcol(data(Δ))) ]
58-
∇M = reduce(hcat, cols)
59-
(nothing, ∇M, map(_->nothing, args)...)
60-
end
61-
fwd, back
62-
end
35+
@grad mapcols(f::Function, M::AbstractMatrix, args...) =
36+
∇mapcols([ Tracker.forward(x -> surevec(f(x, args...)), col) for col in eachcol(data(M)) ], args)
6337

64-
@adjoint function mapcols(f::Function, M::Matrix, args...)
65-
res = [ Zygote.forward(x -> surevec(f(x, args...)), col) for col in eachcol(M) ]
66-
fwd = reduce(hcat, first.(res))
67-
function back(Δ)
68-
cols = [ (last(res[c]))(Δcol)[1] for (c, Δcol) in enumerate(eachcol(Δ)) ]
69-
∇M = reduce(hcat, cols)
70-
(nothing, ∇M, map(_->nothing, args)...)
38+
@adjoint mapcols(f::Function, M::AbstractMatrix, args...) =
39+
∇mapcols([ Zygote.forward(x -> surevec(f(x, args...)), col) for col in eachcol(M) ], args)
40+
41+
function ∇mapcols(forwards, args)
42+
reduce(hcat, data.(first.(forwards))), Δ -> begin
43+
cols = [ data(last(fwd)(Δcol)[1]) for (fwd, Δcol) in zip(forwards, eachcol(data(Δ))) ]
44+
(nothing, reduce(hcat, cols), map(_->nothing, args)...)
7145
end
72-
fwd, back
7346
end
7447

48+
"""
49+
maprows(f, M) ≈ mapslices(f, M, dims=2)
50+
51+
Like `mapcols()` but for rows.
52+
"""
7553
maprows(f::Function, M::AbstractMatrix, args...) =
7654
reduce(vcat, [ surerow(f(col, args...)) for col in eachrow(M) ])
7755

7856
surerow(x) = transpose(surevec(x))
7957

8058
maprows(f::Function, M::TrackedMatrix, args...) = track(maprows, f, M, args...)
8159

82-
@grad function maprows(f::Function, M::AbstractMatrix, args...)
83-
res = [ Tracker.forward(x -> surerow(f(x, args...)), row) for row in eachrow(data(M)) ]
84-
fwd = reduce(vcat, data.(first.(res)))
85-
function back(Δ)
86-
rows = [ data((last(res[r]))(Δrow)[1]) for (r, Δrow) in enumerate(eachrow(data(Δ))) ]
87-
∇M = reduce(vcat, rows)
88-
(nothing, ∇M, map(_->nothing, args)...)
60+
@grad maprows(f::Function, M::AbstractMatrix, args...) =
61+
∇maprows([ Tracker.forward(x -> surerow(f(x, args...)), row) for row in eachrow(data(M)) ], args)
62+
63+
@adjoint maprows(f::Function, M::AbstractMatrix, args...) =
64+
∇maprows([ Zygote.forward(x -> surerow(f(x, args...)), row) for row in eachrow(M) ], args)
65+
66+
function ∇maprows(forwards, args)
67+
reduce(vcat, data.(first.(forwards))), Δ -> begin
68+
rows = [ data(last(fwd)(Δrow)[1]) for (fwd, Δrow) in zip(forwards, eachrow(data(Δ))) ]
69+
(nothing, reduce(vcat, rows), map(_->nothing, args)...)
8970
end
90-
fwd, back
9171
end
9272

9373

@@ -100,16 +80,12 @@ struct MapCols{d} end
10080
"""
10181
MapCols{d}(f, m::Matrix, args...)
10282
103-
Expects `f(::SVector{d}, args...)` and maps this over the columns, `d = size(M,1)`.
104-
Doesn't expect `f` to return a staticarray, just an array.
105-
106-
When `m::TrackedMatrix`, it uses `ForwardDiff` to calculate the gradient of each slice.
107-
The second point of keeping one type parameter is that the dual numbers needed depend on this.
83+
Similar to `mapcols(f, m, args...)`, but slices `m` into `SVector{d}` columns.
84+
Their length `d = size(m,1)` should ideally be provided for type-stability, but is not required.
10885
109-
MapCols{d}(f, m::Weighted, args...)
110-
Takes `m.weights` along for the ride.
86+
The gradient for Tracker and Zygote uses `ForwardDiff` on each slice.
11187
"""
112-
MapCols(f::Function, M::WeightedArrays.MaybeWeightedMatrix, args...) =
88+
MapCols(f::Function, M::AT, args...) where {AT<:WeightedArrays.MaybeWeightedMatrix} =
11389
MapCols{size(M,1)}(f, M, args...)
11490

11591
MapCols{d}(f::Function, M::WeightedMatrix, args...) where {d} =
@@ -119,14 +95,15 @@ MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} = _MapCols(f, M, V
11995

12096
function _MapCols(f::Function, M::Matrix{T}, ::Val{d}, args...) where {T,d}
12197
d == size(M,1) || error("expected M with $d columns")
122-
# @cast A[c]{r:d} := M[r,c] assert
12398
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(M))
12499
B = map(col -> surevec(f(col, args...)), A)
125100
reduce(hcat, B)
126101
# maybestaticgluecols(B)
127102
end
128103

129104
# surevec(x::MArray) = Array(x) # avoid making a huge MArray, ad
105+
# surevecS(x::Number) = @SVector [x]
106+
# surevecS(A) = vec(A) # like surevec
130107

131108
function maybestaticgluecols(B)
132109
TB = eltype(B)
@@ -139,50 +116,33 @@ function maybestaticgluecols(B)
139116
end
140117
end
141118

142-
# surevecS(x::Number) = @SVector [x]
143-
# surevecS(A) = vec(A) # like surevec
144-
145119
_MapCols(f::Function, M::TrackedMatrix, dval, args...) = track(_MapCols, f, M, dval, args...)
146120

147121
@grad _MapCols(f::Function, M::TrackedMatrix, dval, args...) = ∇MapCols(f, M, dval, args...)
148122

149123
@adjoint _MapCols(f::Function, M::Matrix, dval, args...) = ∇MapCols(f, M, dval, args...)
150124

151125
function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) where {T,d}
152-
153126
d == size(M,1) || error("expected M with $d columns")
154-
# @cast A[c]{r:d} := data(M)[r,c]
155127
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(data(M)))
156128

157129
dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, dval)...), dval))
158-
159-
# C = [ surevec(f(col .+ dualcol, args...)) for col in A ]
160130
C = map(col -> surevec(f(col .+ dualcol, args...)), A)
161131

162-
# Z = reduce(hcat, [ ForwardDiff.value.(full) for full in C ])
163132
Z = reduce(hcat, map(col -> ForwardDiff.value.(col), C))
164133

165134
function back(ΔZ)
166-
# accum = zero(eltype(data(ΔZ)))
167-
# ∇M = similar(data(M)) .+ zero(first(data(ΔZ)))
168135
∇M = zeros(eltype(data(ΔZ)), size(M))
169136
@inbounds for c=1:size(M,2)
170137
part = ForwardDiff.partials.(C[c])
171138
for r=1:d
172-
# ∇M[r,c] = 0
173-
# accum = 0
174139
for i=1:size(ΔZ,1)
175140
∇M[r,c] += data(ΔZ)[i,c] * part[i].values[r]
176-
# parti = ForwardDiff.partials(C[c][i])
177-
# ∇M[r,c] += data(ΔZ)[i,c] * parti.values[r]
178-
# accum += data(ΔZ)[i,c] * part[i].values[r]
179141
end
180-
# ∇M[r,c] = accum
181142
end
182143
end
183144
(nothing, ∇M, nothing, map(_->nothing, args)...)
184145
end
185-
186146
Z, back
187147
end
188148

@@ -298,5 +258,20 @@ end
298258
TensorCast.red_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
299259
end
300260

261+
@adjoint function TensorCast.copy_glue(A::AbstractArray, code::Tuple)
262+
TensorCast.copy_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
263+
end
264+
265+
"""
266+
slicemap(f, A; dims) ≈ mapslices(f, A; dims)
267+
268+
Like `mapcols()`, but for any slice. Gradient is for Zygote only.
269+
"""
270+
function slicemap(f::Function, A::AbstractArray{T,N}, args...; dims) where {T,N}
271+
code = ntuple(d -> d in dims ? (:) : (*), N)
272+
B = TensorCast.sliceview(A, code)
273+
C = [ f(slice, args...) for slice in B ]
274+
TensorCast.glue(C, code)
275+
end
301276

302277
end # module

0 commit comments

Comments
 (0)