Skip to content

Commit 102df32

Browse files
author
Michael Abbott
committed
better treatment for scalar functions
1 parent 6457440 commit 102df32

File tree

1 file changed

+32
-23
lines changed

1 file changed

+32
-23
lines changed

src/SliceMap.jl

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ using ZygoteRules: pullback, @adjoint
1818
1919
This is a more efficient version of the functions on the right.
2020
For `f(x::Vector)::Matrix` it reshapes like `mapslices(vec∘f, m, dims=1)`.
21+
For `f(x::Vector)::Number` it skips the reduction, just `reshape(map(f, eachcol(m)),1,:)`.
2122
2223
It provides a gradient for Tracker and Zygote, saving the backward function for each slice.
2324
@@ -28,53 +29,58 @@ They do not get sliced/iterated (unlike `map`), nor are their gradients tracked.
2829
mapcols(f::Function, M, args...) = _mapcols(map, f, M, args...)
2930
tmapcols(f::Function, M, args...) = _mapcols(threadmap, f, M, args...)
3031

31-
_mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
32-
reduce(hcat, map(col -> surevec(f(col, args...), M), eachcol(M)))
32+
function _mapcols(map::Function, f::Function, M::AbstractMatrix, args...)
33+
res = map(col -> _vec(f(col, args...)), eachcol(M))
34+
eltype(res) <: AbstractVector ? reduce(hcat, res) : reshape(res,1,:)
35+
end
36+
37+
_vec(x) = x
38+
_vec(A::AbstractArray) = vec(A) # to allow f vector -> matrix, by reshaping
3339

3440
_mapcols(map::Function, f::Function, M::TrackedMatrix, args...) = track(_mapcols, map, f, M, args...)
3541

3642
@grad _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
37-
∇mapcols(map, map(col -> Tracker.forward(x -> surevec(f(x, args...), M), col), eachcol(data(M))), args...)
43+
∇mapcols(map, map(col -> Tracker.forward(x -> _vec(f(x, args...)), col), eachcol(data(M))), args...)
3844

3945
@adjoint _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
40-
∇mapcols(map, map(col -> ZygoteRules.pullback(x -> surevec(f(x, args...), M), col), eachcol(M)), args)
46+
∇mapcols(map, map(col -> ZygoteRules.pullback(x -> _vec(f(x, args...)), col), eachcol(M)), args)
4147

4248
function ∇mapcols(bigmap, forwards, args...)
43-
reduce(hcat, map(datafirst, forwards)), Δ -> begin
44-
cols = bigmap((fwd, Δcol) -> data(last(fwd)(Δcol)[1]), forwards, eachcol(data(Δ)))
49+
res = map(datafirst, forwards)
50+
function back(Δ)
51+
Δcols = eltype(res) <: AbstractVector ? eachcol(data(Δ)) : vec(data(Δ))
52+
cols = bigmap((fwd, Δcol) -> data(last(fwd)(Δcol)[1]), forwards, Δcols)
4553
(nothing, nothing, reduce(hcat, cols), map(_->nothing, args)...)
4654
end
55+
eltype(res) <: AbstractVector ? reduce(hcat, res) : reshape(res,1,:), back
4756
end
4857

49-
surevec(A::AbstractArray, M) = vec(A) # to allow f vector -> matrix, by reshaping
50-
surevec(x::Number, M) = _veclike(x, M) # to allow f vector -> scalar, as mapslices does
51-
52-
_veclike(x::T, M) where {T} = fill!(similar(M, T, 1), x) # use similar to preserve CuArrays, #4
53-
_veclike(x::TrackedReal, M) = track(_veclike, x, M)
54-
@grad _veclike(x, M) = _veclike(data(x), M), Δ -> (first(Δ), nothing)
55-
@adjoint _veclike(x, M) = _veclike(x, M), Δ -> (first(Δ), nothing)
56-
5758
"""
5859
maprows(f, M) ≈ mapslices(f, M, dims=2)
5960
6061
Like `mapcols()` but for rows.
6162
"""
62-
maprows(f::Function, M::AbstractMatrix, args...) =
63-
reduce(vcat, map(col -> transpose(surevec(f(col, args...), M)), eachrow(M)))
63+
function maprows(f::Function, M::AbstractMatrix, args...)
64+
res = map(col -> transpose(_vec(f(col, args...))), eachrow(M))
65+
eltype(res) <: AbstractArray ? reduce(vcat, res) : reshape(res,:,1)
66+
end
6467

6568
maprows(f::Function, M::TrackedMatrix, args...) = track(maprows, f, M, args...)
6669

6770
@grad maprows(f::Function, M::AbstractMatrix, args...) =
68-
∇maprows(map(row -> Tracker.forward(x -> surevec(f(x, args...), M), row), eachrow(data(M))), args)
71+
∇maprows(map(row -> Tracker.forward(x -> _vec(f(x, args...)), row), eachrow(data(M))), args)
6972

7073
@adjoint maprows(f::Function, M::AbstractMatrix, args...) =
71-
∇maprows(map(row -> ZygoteRules.pullback(x -> surevec(f(x, args...), M), row), eachrow(M)), args)
74+
∇maprows(map(row -> ZygoteRules.pullback(x -> _vec(f(x, args...)), row), eachrow(M)), args)
7275

7376
function ∇maprows(forwards, args)
74-
reduce(vcat, map(transposedatafirst, forwards)), Δ -> begin
75-
rows = map((fwd, Δrow) -> data(last(fwd)(Δrow)[1]), forwards, eachrow(data(Δ)))
76-
(nothing, reduce(vcat, transpose.(rows)), map(_->nothing, args)...)
77+
res = map(transposedatafirst, forwards)
78+
function back(Δ)
79+
Δrows = eltype(res) <: AbstractArray ? eachrow(data(Δ)) : vec(data(Δ))
80+
rows = map((fwd, Δrow) -> transpose(data(last(fwd)(Δrow)[1])), forwards, Δrows)
81+
(nothing, reduce(vcat, rows), map(_->nothing, args)...)
7782
end
83+
eltype(res) <: AbstractArray ? reduce(vcat, res) : reshape(res,:,1), back
7884
end
7985

8086
"""
@@ -115,10 +121,13 @@ MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
115121
function _MapCols(map::Function, f::Function, M::Matrix{T}, ::Val{d}, args...) where {T,d}
116122
d == size(M,1) || error("expected M with $d rows")
117123
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(M))
118-
B = map(col -> surevec(f(col, args...), M), A)
124+
B = map(col -> surevec(f(col, args...)), A)
119125
reduce(hcat, B)
120126
end
121127

128+
surevec(A::AbstractArray) = vec(A)
129+
surevec(x::Number) = SVector(x) # simple way to deal with f vector -> scalar
130+
122131
_MapCols(map::Function, f::Function, M::TrackedMatrix, dval, args...) =
123132
track(_MapCols, map, f, M, dval, args...)
124133

@@ -135,7 +144,7 @@ function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::V
135144
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(data(M)))
136145

137146
dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, dval)...), dval))
138-
C = bigmap(col -> surevec(f(col + dualcol, args...), M), A)
147+
C = bigmap(col -> surevec(f(col + dualcol, args...)), A)
139148

140149
Z = reduce(hcat, map(col -> ForwardDiff.value.(col), C))
141150

0 commit comments

Comments
 (0)