|
| 1 | + |
| 2 | +module SliceMap |
| 3 | + |
| 4 | +export MapCols, mapcols |
| 5 | + |
#========== Reverse, Eachslice ==========#
| 7 | + |
"""
    mapcols(f, M::Matrix, args...)

Apply `f(col, args...)` to every column `col` of `M` and concatenate the results
horizontally, much like `reduce(hcat, f(col, args...) for col in eachcol(M))`.
Each result goes through `rvec` first, so `f` may return a scalar or any array.

When `M::TrackedMatrix`, it saves the backward function for each slice.
"""
function mapcols(f::Function, M::Matrix, args...)
    slices = map(col -> rvec(f(col, args...)), eachcol(M))
    return reduce(hcat, slices)
end
| 15 | + |
| 16 | +using Tracker |
| 17 | +using Tracker: TrackedMatrix, track, @grad, data |
| 18 | + |
# Tracked input: hand the call to Tracker, which uses the `@grad` rule below.
function mapcols(f::Function, M::TrackedMatrix, args...)
    return track(mapcols, f, M, args...)
end

# Tracker gradient rule for `mapcols`.
# Forward pass: run `Tracker.forward` on each column of the untracked data, keeping
# both the value and the backward function of every slice.
# Backward pass: feed each column of `Δ` to the matching backward function and glue
# the resulting column gradients together. No gradient is returned for `f` or `args`.
@grad function mapcols(f::Function, M::TrackedMatrix, args...)
    slicefun = x -> rvec(f(x, args...))
    res = map(col -> Tracker.forward(slicefun, col), eachcol(data(M)))
    fwd = reduce(hcat, map(r -> data(first(r)), res))

    function back(Δ)
        cols = map(enumerate(eachcol(data(Δ)))) do (c, Δcol)
            slice_back = last(res[c])
            data(slice_back(Δcol)[1]) # [1]: gradient with respect to the column itself
        end
        ∇M = reduce(hcat, cols)
        return (nothing, ∇M, map(_ -> nothing, args)...)
    end

    return fwd, back
end
| 31 | + |
| 32 | +using Zygote |
# Zygote adjoint for `mapcols`, mirroring the Tracker rule.
# Forward pass: run `Zygote.forward` on each column, keeping the value and the
# backward function of every slice.
# Backward pass: feed each column of `Δ` to the matching backward function and glue
# the resulting column gradients together. No gradient is returned for `f` or `args`.
Zygote.@adjoint function mapcols(f::Function, M::Matrix, args...)
    slicefun = x -> rvec(f(x, args...))
    res = map(col -> Zygote.forward(slicefun, col), eachcol(data(M)))
    fwd = reduce(hcat, map(r -> data(first(r)), res))

    function back(Δ)
        cols = map(enumerate(eachcol(data(Δ)))) do (c, Δcol)
            slice_back = last(res[c])
            data(slice_back(Δcol)[1]) # [1]: gradient with respect to the column itself
        end
        ∇M = reduce(hcat, cols)
        return (nothing, ∇M, map(_ -> nothing, args)...)
    end

    return fwd, back
end
| 43 | + |
#========== Forward, Static ==========#
| 45 | + |
| 46 | +using TensorCast, StaticArrays, WeightedArrays |
| 47 | + |
struct MapCols{d} end

"""
    MapCols{d}(f, M::Matrix, args...)

Map `f(::SVector{d}, args...)` over the columns of `M`, where `d = size(M,1)`.
`f` is not expected to return a static array, just an array.

When `M::TrackedMatrix`, `ForwardDiff` is used to calculate the gradient of each slice.
A second reason for keeping the type parameter `d` is that the dual numbers needed
depend on it.

    MapCols{d}(f, M::Weighted, args...)

Takes `M.weights` along for the ride.

    MapCols(f, M, args...)

Without an explicit parameter, `d = size(M,1)` is filled in from `M`.
"""
function MapCols(f::Function, M::WeightedArrays.MaybeWeightedMatrix, args...)
    nrows = size(M, 1)
    return MapCols{nrows}(f, M, args...)
end
| 64 | + |
# Weighted input: map over the wrapped array, then re-wrap the result with the
# original weights and options.
function MapCols{d}(f::Function, M::WeightedMatrix, args...) where {d}
    mapped = MapCols{d}(f, M.array, args...)
    return Weighted(mapped, M.weights, M.opt)
end
| 67 | + |
# Plain matrix: slice `M` into length-`d` columns with TensorCast, apply `f` to each
# one, and concatenate the (vectorised) results horizontally.
function MapCols{d}(f::Function, M::Matrix, args...) where {d}
    # `{r:d}` fixes the slice length to `d`.
    # NOTE(review): `assert` presumably makes TensorCast check that size — confirm.
    @cast A[c]{r:d} := M[r,c] assert
    slices = map(acol -> rvec(f(acol, args...)), A)
    return reduce(hcat, slices)

    # TODO: call some function which static-glues if possible, e.g.
    #   TensorCast.auto_glue(map(col -> rvec(f(col, args...)), A), (:,*))

    # TODO: can this be threaded? Is it even safe to do so?
    #   https://github.com/mohamed82008/KissThreading.jl
end
| 78 | + |
# `rvec` converts whatever `f` returned into a vector, ready for `reduce(hcat, ...)`.

# A scalar becomes a one-element vector, allowing `f` to map vector -> scalar,
# as `mapslices` does.
function rvec(x::Number)
    return [x]
end

# A static array is copied into an ordinary `Array` first, to avoid creating a
# giant static array, as `reduce(hcat, ...)` would otherwise do.
function rvec(x::StaticArray)
    return vec(Array(x))
end

# Anything else is simply flattened with `vec`.
function rvec(A)
    return vec(A)
end
| 82 | + |
| 83 | + |
| 84 | +using ForwardDiff |
| 85 | + |
# Tracked input: hand the call to Tracker, passing `d` along as `Val(d)` so that the
# `@grad` rule below can recover it as a type parameter.
MapCols{d}(f::Function, M::TrackedMatrix, args...) where {d} =
    track(MapCols, f, M, Val(d), args...)

# Tracker gradient rule for `MapCols`, using forward-mode dual numbers per slice.
#
# Forward pass: every length-`d` column of the untracked data is shifted by `dualcol`
# and passed through `f`; the values of the results form `Z`, while their partials
# are kept in `C` for the backward pass.
#
# Backward pass: ∇M[r,c] = Σ_i ΔZ[i,c] * (r-th partial of the i-th output of column c).
# No gradient is returned for `f`, for the `Val`, or for `args`.
@grad function MapCols(f::Function, M::TrackedMatrix, dval::Val{d}, args...) where {d}

    @cast A[c]{r:d} := M.data[r,c]

    # Zero-valued duals; entry j carries a 1 in its j-th partial and 0 elsewhere, so
    # adding this to a column leaves the values unchanged.
    dualcol = SVector(ntuple(
        j -> ForwardDiff.Dual(0, ntuple(i -> i==j ? 1 : 0, dval)...), dval))

    C = [ rvec(f(acol .+ dualcol, args...)) for acol in A ]

    Z = reduce(hcat, [ ForwardDiff.value.(full) for full in C ]) # full is not an SVector here

    function back(ΔZ)
        Δ = data(ΔZ)
        # Allocate an explicitly zeroed buffer of the promoted element type, rather
        # than doing arithmetic on uninitialised memory and peeking at `first(Δ)`,
        # which throws when `Δ` is empty.
        ∇M = zeros(promote_type(eltype(data(M)), eltype(Δ)), size(M))
        # No `@inbounds`: the shape of `Δ` comes from the caller and is not validated.
        for c in 1:size(M,2)
            out = C[c]
            for r in 1:d, i in 1:size(Δ,1)
                # `partials(x, r)` is the public accessor; it also returns zero when
                # `x` is not a dual number, i.e. when that output ignores the input.
                ∇M[r,c] += Δ[i,c] * ForwardDiff.partials(out[i], r)
            end
        end
        return (nothing, ∇M, nothing, map(_ -> nothing, args)...)
    end

    return Z, back
end
| 113 | + |
# Zygote adjoint for `MapCols{d}`, using forward-mode dual numbers per slice.
# Unlike the Tracker rule there is no `Val(d)` argument, so the returned tuple has
# one entry fewer.
#
# Forward pass: every length-`d` column of `M` is shifted by `dualcol` and passed
# through `f`; the values of the results form `Z`, while their partials are kept in
# `C` for the backward pass.
#
# Backward pass: ∇M[r,c] = Σ_i ΔZ[i,c] * (r-th partial of the i-th output of column c).
# No gradient is returned for `f` or for `args`.
Zygote.@adjoint function MapCols{d}(f::Function, M::Matrix, args...) where {d}

    @cast A[c]{r:d} := M[r,c]

    # Zero-valued duals; entry j carries a 1 in its j-th partial and 0 elsewhere, so
    # adding this to a column leaves the values unchanged.
    dualcol = SVector(ntuple(
        j -> ForwardDiff.Dual(0, ntuple(i -> i==j ? 1 : 0, Val(d))...), Val(d)))

    C = [ rvec(f(acol .+ dualcol, args...)) for acol in A ]

    Z = reduce(hcat, [ ForwardDiff.value.(full) for full in C ])

    function back(ΔZ)
        Δ = data(ΔZ)
        # Allocate an explicitly zeroed buffer of the promoted element type, rather
        # than doing arithmetic on uninitialised memory and peeking at `first(Δ)`,
        # which throws when `Δ` is empty.
        ∇M = zeros(promote_type(eltype(data(M)), eltype(Δ)), size(M))
        # No `@inbounds`: the shape of `Δ` comes from the caller and is not validated.
        for c in 1:size(M,2)
            out = C[c]
            for r in 1:d, i in 1:size(Δ,1)
                # `partials(x, r)` is the public accessor; it also returns zero when
                # `x` is not a dual number, i.e. when that output ignores the input.
                ∇M[r,c] += Δ[i,c] * ForwardDiff.partials(out[i], r)
            end
        end
        return (nothing, ∇M, map(_ -> nothing, args)...)
    end

    return Z, back
end
| 139 | + |
| 140 | +end # module |
0 commit comments