Skip to content

Commit 6457440

Browse files
author
Michael Abbott
committed
partial fix for issue 4
1 parent 1f6f35b commit 6457440

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

src/SliceMap.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ export mapcols, MapCols, maprows, slicemap, tmapcols, ThreadMapCols
66
using JuliennedArrays
77

88
using Tracker
9-
using Tracker: TrackedMatrix, track, @grad, data
9+
using Tracker: TrackedReal, TrackedMatrix, track, @grad, data
1010

1111
using ZygoteRules
1212
using ZygoteRules: pullback, @adjoint
@@ -29,18 +29,15 @@ mapcols(f::Function, M, args...) = _mapcols(map, f, M, args...)
2929
tmapcols(f::Function, M, args...) = _mapcols(threadmap, f, M, args...)
3030

3131
_mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
32-
reduce(hcat, map(col -> surevec(f(col, args...)), eachcol(M)))
33-
34-
surevec(x::Number) = [x] # to allow f vector -> scalar, as mapslices does
35-
surevec(A) = vec(A) # to allow f vector -> matrix, by reshaping
32+
reduce(hcat, map(col -> surevec(f(col, args...), M), eachcol(M)))
3633

3734
_mapcols(map::Function, f::Function, M::TrackedMatrix, args...) = track(_mapcols, map, f, M, args...)
3835

3936
@grad _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
40-
∇mapcols(map, map(col -> Tracker.forward(x -> surevec(f(x, args...)), col), eachcol(data(M))), args...)
37+
∇mapcols(map, map(col -> Tracker.forward(x -> surevec(f(x, args...), M), col), eachcol(data(M))), args...)
4138

4239
@adjoint _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
43-
∇mapcols(map, map(col -> ZygoteRules.pullback(x -> surevec(f(x, args...)), col), eachcol(M)), args)
40+
∇mapcols(map, map(col -> ZygoteRules.pullback(x -> surevec(f(x, args...), M), col), eachcol(M)), args)
4441

4542
function ∇mapcols(bigmap, forwards, args...)
4643
reduce(hcat, map(datafirst, forwards)), Δ -> begin
@@ -49,21 +46,29 @@ function ∇mapcols(bigmap, forwards, args...)
4946
end
5047
end
5148

49+
surevec(A::AbstractArray, M) = vec(A) # to allow f vector -> matrix, by reshaping
50+
surevec(x::Number, M) = _veclike(x, M) # to allow f vector -> scalar, as mapslices does
51+
52+
_veclike(x::T, M) where {T} = fill!(similar(M, T, 1), x) # use similar to preserve CuArrays, #4
53+
_veclike(x::TrackedReal, M) = track(_veclike, x, M)
54+
@grad _veclike(x, M) = _veclike(data(x), M), Δ -> (first(Δ), nothing)
55+
@adjoint _veclike(x, M) = _veclike(x, M), Δ -> (first(Δ), nothing)
56+
5257
"""
5358
maprows(f, M) ≈ mapslices(f, M, dims=2)
5459
5560
Like `mapcols()` but for rows.
5661
"""
5762
maprows(f::Function, M::AbstractMatrix, args...) =
58-
reduce(vcat, map(col -> transpose(surevec(f(col, args...))), eachrow(M)))
63+
reduce(vcat, map(col -> transpose(surevec(f(col, args...), M)), eachrow(M)))
5964

6065
maprows(f::Function, M::TrackedMatrix, args...) = track(maprows, f, M, args...)
6166

6267
@grad maprows(f::Function, M::AbstractMatrix, args...) =
63-
∇maprows(map(row -> Tracker.forward(x -> surevec(f(x, args...)), row), eachrow(data(M))), args)
68+
∇maprows(map(row -> Tracker.forward(x -> surevec(f(x, args...), M), row), eachrow(data(M))), args)
6469

6570
@adjoint maprows(f::Function, M::AbstractMatrix, args...) =
66-
∇maprows(map(row -> ZygoteRules.pullback(x -> surevec(f(x, args...)), row), eachrow(M)), args)
71+
∇maprows(map(row -> ZygoteRules.pullback(x -> surevec(f(x, args...), M), row), eachrow(M)), args)
6772

6873
function ∇maprows(forwards, args)
6974
reduce(vcat, map(transposedatafirst, forwards)), Δ -> begin
@@ -110,7 +115,7 @@ MapCols{d}(f::Function, M::AbstractMatrix, args...) where {d} =
110115
function _MapCols(map::Function, f::Function, M::Matrix{T}, ::Val{d}, args...) where {T,d}
111116
d == size(M,1) || error("expected M with $d rows")
112117
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(M))
113-
B = map(col -> surevec(f(col, args...)), A)
118+
B = map(col -> surevec(f(col, args...), M), A)
114119
reduce(hcat, B)
115120
end
116121

@@ -130,7 +135,7 @@ function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::V
130135
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(data(M)))
131136

132137
dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, dval)...), dval))
133-
C = bigmap(col -> surevec(f(col + dualcol, args...)), A)
138+
C = bigmap(col -> surevec(f(col + dualcol, args...), M), A)
134139

135140
Z = reduce(hcat, map(col -> ForwardDiff.value.(col), C))
136141

0 commit comments

Comments
 (0)