Skip to content

Commit ff296ce

Browse files
authored
Merge pull request #2 from mcabbott/zygoterules
ZygoteRules
2 parents 47e474c + bb217cd commit ff296ce

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SliceMap"
22
uuid = "82cb661a-3f19-5665-9e27-df437c7e54c8"
33
authors = ["Michael Abbott"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -11,6 +11,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1111
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1212
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"
1313
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
14+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1415

1516
[compat]
1617
julia = "1"

src/SliceMap.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ using MacroTools, Requires, TensorCast, JuliennedArrays
88
using Tracker
99
using Tracker: TrackedMatrix, track, @grad, data
1010

11+
using ZygoteRules
12+
using ZygoteRules: pullback, @adjoint
13+
1114
#========== Reverse, Eachslice ==========#
1215

1316
"""
@@ -36,6 +39,9 @@ _mapcols(map::Function, f::Function, M::TrackedMatrix, args...) = track(_mapcols
3639
@grad _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
3740
∇mapcols(map, map(col -> Tracker.forward(x -> surevec(f(x, args...)), col), eachcol(data(M))), args...)
3841

42+
@adjoint _mapcols(map::Function, f::Function, M::AbstractMatrix, args...) =
43+
∇mapcols(map, map(col -> ZygoteRules.pullback(x -> surevec(f(x, args...)), col), eachcol(M)), args)
44+
3945
function ∇mapcols(bigmap, forwards, args...)
4046
reduce(hcat, map(datafirst, forwards)), Δ -> begin
4147
cols = bigmap((fwd, Δcol) -> data(last(fwd)(Δcol)[1]), forwards, eachcol(data(Δ)))
@@ -56,6 +62,9 @@ maprows(f::Function, M::TrackedMatrix, args...) = track(maprows, f, M, args...)
5662
@grad maprows(f::Function, M::AbstractMatrix, args...) =
5763
∇maprows(map(row -> Tracker.forward(x -> surevec(f(x, args...)), row), eachrow(data(M))), args)
5864

65+
@adjoint maprows(f::Function, M::AbstractMatrix, args...) =
66+
∇maprows(map(row -> ZygoteRules.pullback(x -> surevec(f(x, args...)), row), eachrow(M)), args)
67+
5968
function ∇maprows(forwards, args)
6069
reduce(vcat, map(transposedatafirst, forwards)), Δ -> begin
6170
rows = map((fwd, Δrow) -> data(last(fwd)(Δrow)[1]), forwards, eachrow(data(Δ)))
@@ -77,6 +86,7 @@ function slicemap(f::Function, A::AbstractArray{T,N}, args...; dims) where {T,N}
7786
C = [ f(slice, args...) for slice in B ]
7887
TensorCast.glue(C, code)
7988
end
89+
# TODO switch to JuliennedArrays, then rm TensorCast dep
8090

8191
#========== Forward, Static ==========#
8292

@@ -111,6 +121,9 @@ _MapCols(map::Function, f::Function, M::TrackedMatrix, dval, args...) =
111121
@grad _MapCols(map::Function, f::Function, M::TrackedMatrix, dval, args...) =
112122
∇MapCols(map, f, M, dval, args...)
113123

124+
@adjoint _MapCols(map::Function, f::Function, M::Matrix, dval, args...) =
125+
∇MapCols(map, f, M, dval, args...)
126+
114127
function ∇MapCols(bigmap::Function, f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) where {T,d}
115128
d == size(M,1) || error("expected M with $d rows")
116129
k = size(M,2)
@@ -142,7 +155,34 @@ end
142155

143156
# @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
144157

145-
@init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("zygote.jl")
158+
# @init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("zygote.jl")
159+
# Now using ZygoteRules instead, mapcols etc above.
160+
161+
#= TensorCast =#
162+
# These could move there, TODO
163+
164+
@adjoint TensorCast.sliceview(A::AbstractArray, code::Tuple) =
165+
TensorCast.sliceview(A, code), Δ -> (TensorCast.glue(Δ, code), nothing)
166+
167+
@adjoint TensorCast.red_glue(A::AbstractArray, code::Tuple) =
168+
TensorCast.red_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
169+
170+
@adjoint TensorCast.copy_glue(A::AbstractArray, code::Tuple) =
171+
TensorCast.copy_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
172+
173+
#= JuliennedArrays =#
174+
175+
@adjoint JuliennedArrays.Slices(whole, along...) =
176+
Slices(whole, along...), Δ -> (Align(Δ, along...), map(_->nothing, along)...)
177+
178+
@adjoint JuliennedArrays.Align(whole, along...) =
179+
Align(whole, along...), Δ -> (Slices(Δ, along...), map(_->nothing, along)...)
180+
181+
#= Base =#
182+
183+
@adjoint Base.reduce(::typeof(hcat), V::AbstractVector{<:AbstractVector}) =
184+
reduce(hcat, V), dV -> (nothing, collect(eachcol(dV)),)
185+
146186

147187
#========== Experimenting with gradients for for eachslice / reduce ==========#
148188

@@ -159,6 +199,9 @@ gluecol(V::AbstractVector{<:TrackedVector}) = track(gluecol, V)
159199
end
160200
=#
161201

202+
@adjoint gluecol(V::AbstractVector) =
203+
gluecol(V), ΔM -> (collect(eachcol(ΔM)),) # does work!
204+
162205
function mapcols2(f, A)
163206
cols = [A[:,c] for c=1:size(A,2)]
164207
res = f.(cols)
@@ -197,6 +240,14 @@ end
197240

198241
collecteachcol(x) = collect(eachcol(x))
199242

243+
@adjoint function collecteachcol(x)
244+
collecteachcol(x), dy -> begin
245+
dx = _zero(x) # _zero is not in ZygoteRules, TODO
246+
foreach(copyto!, collecteachcol(dx), dy)
247+
(dx,)
248+
end
249+
end
250+
200251
function mapcols6(f, A)
201252
cols = collecteachcol(A)
202253
res = map(f, cols)

src/zygote.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11

2+
# Note to self -- ZygoteRules doesn't have forward, only @adjoint, so it's no help.
3+
# Later: it now has pullback, so that should work? Once registered... all copied in.
4+
25
using .Zygote
36
using .Zygote: @adjoint, _zero, forward
47

0 commit comments

Comments
 (0)