Skip to content

Commit 754d99b

Browse files
author
Michael Abbott
committed
zygote optional
1 parent a3a5bd3 commit 754d99b

File tree

3 files changed

+93
-94
lines changed

3 files changed

+93
-94
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@ version = "0.1.0"
66
[deps]
77
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
88
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
9+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
910
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1011
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"
1112
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1213
WeightedArrays = "379a43df-f81c-573e-83a6-069eb6c11a71"
13-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1414

1515
[compat]
1616
julia = "1"
1717

1818
[extras]
1919
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
20+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2021

2122
[targets]
22-
test = ["Test"]
23+
test = ["Test", "Zygote"]

src/SliceMap.jl

Lines changed: 27 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ module SliceMap
33

44
export mapcols, MapCols, maprows, slicemap
55

6-
using MacroTools, Tracker, Zygote, WeightedArrays
6+
using MacroTools, Requires, WeightedArrays, TensorCast, Tracker
7+
78
using Tracker: TrackedMatrix, track, @grad, data
8-
using Zygote: @adjoint, _zero
99

1010
#========== Reverse, Eachslice ==========#
1111

@@ -35,9 +35,6 @@ mapcols(f::Function, M::TrackedMatrix, args...) = track(mapcols, f, M, args...)
3535
@grad mapcols(f::Function, M::AbstractMatrix, args...) =
3636
∇mapcols([ Tracker.forward(x -> surevec(f(x, args...)), col) for col in eachcol(data(M)) ], args)
3737

38-
@adjoint mapcols(f::Function, M::AbstractMatrix, args...) =
39-
∇mapcols([ Zygote.forward(x -> surevec(f(x, args...)), col) for col in eachcol(M) ], args)
40-
4138
function ∇mapcols(forwards, args)
4239
reduce(hcat, data.(first.(forwards))), Δ -> begin
4340
cols = [ data(last(fwd)(Δcol)[1]) for (fwd, Δcol) in zip(forwards, eachcol(data(Δ))) ]
@@ -58,16 +55,27 @@ maprows(f::Function, M::TrackedMatrix, args...) = track(maprows, f, M, args...)
5855
@grad maprows(f::Function, M::AbstractMatrix, args...) =
5956
∇maprows([ Tracker.forward(x -> surevec(f(x, args...)), row) for row in eachrow(data(M)) ], args)
6057

61-
@adjoint maprows(f::Function, M::AbstractMatrix, args...) =
62-
∇maprows([ Zygote.forward(x -> surevec(f(x, args...)), row) for row in eachrow(M) ], args)
63-
6458
function ∇maprows(forwards, args)
6559
reduce(vcat, map(transposedatafirst, forwards)), Δ -> begin
6660
rows = [ data(last(fwd)(Δrow)[1]) for (fwd, Δrow) in zip(forwards, eachrow(data(Δ))) ]
6761
(nothing, reduce(vcat, transpose.(rows)), map(_->nothing, args)...)
6862
end
6963
end
7064

65+
"""
66+
slicemap(f, A; dims) ≈ mapslices(f, A; dims)
67+
68+
Like `mapcols()`, but for any slice. The function `f` must preserve shape,
69+
e.g. if `dims=(2,4)` then `f` must map matrices to matrices.
70+
71+
The gradient is for Zygote only, and is defined only once Zygote has been loaded.
72+
"""
73+
function slicemap(f::Function, A::AbstractArray{T,N}, args...; dims) where {T,N}
74+
code = ntuple(d -> d in dims ? (:) : (*), N)
75+
B = TensorCast.sliceview(A, code)
76+
C = [ f(slice, args...) for slice in B ]
77+
TensorCast.glue(C, code)
78+
end
7179

7280
#========== Forward, Static ==========#
7381

@@ -96,36 +104,18 @@ function _MapCols(f::Function, M::Matrix{T}, ::Val{d}, args...) where {T,d}
96104
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(M))
97105
B = map(col -> surevec(f(col, args...)), A)
98106
reduce(hcat, B)
99-
# maybestaticgluecols(B)
100-
end
101-
102-
# surevec(x::MArray) = Array(x) # avoid making a huge MArray, ad
103-
# surevecS(x::Number) = @SVector [x]
104-
# surevecS(A) = vec(A) # like surevec
105-
106-
function maybestaticgluecols(B)
107-
TB = eltype(B)
108-
if TB <: SArray
109-
C = collect(reshape(reinterpret(eltype(TB), B),:,length(B)))
110-
elseif TB <: MArray
111-
C = reduce(hcat, Array.(B))
112-
else
113-
C = reduce(hcat, B)
114-
end
115107
end
116108

117109
_MapCols(f::Function, M::TrackedMatrix, dval, args...) = track(_MapCols, f, M, dval, args...)
118110

119111
@grad _MapCols(f::Function, M::TrackedMatrix, dval, args...) = ∇MapCols(f, M, dval, args...)
120112

121-
@adjoint _MapCols(f::Function, M::Matrix, dval, args...) = ∇MapCols(f, M, dval, args...)
122-
123113
function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) where {T,d}
124114
d == size(M,1) || error("expected M with $d columns")
125115
A = reinterpret(SArray{Tuple{d}, T, 1, d}, vec(data(M)))
126116

127117
dualcol = SVector(ntuple(j->ForwardDiff.Dual(0, ntuple(i->i==j ? 1 : 0, dval)...), dval))
128-
C = map(col -> surevec(f(col .+ dualcol, args...)), A)
118+
C = map(col -> surevec(f(col + dualcol, args...)), A)
129119

130120
Z = reduce(hcat, map(col -> ForwardDiff.value.(col), C))
131121

@@ -144,23 +134,27 @@ function ∇MapCols(f::Function, M::AbstractMatrix{T}, dval::Val{d}, args...) wh
144134
Z, back
145135
end
146136

137+
#========== Gradients for Zygote ==========#
138+
139+
# @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
140+
# end
147141

148-
#========== Gradient for eachslice / reduce ==========#
142+
@init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("zygote.jl")
143+
144+
#========== Experimenting with gradients for eachslice / reduce ==========#
149145

150146
export gluecol, collecteachcol
151147
export mapcols2, mapcols4, mapcols5, mapcols6, mapcols7
152148

153149
# Glue a vector of column vectors side by side into a single matrix.
function gluecol(V::AbstractVector{<:AbstractVector})
    return reduce(hcat, V)
end
154150

151+
#=
155152
gluecol(V::AbstractVector{<:TrackedVector}) = track(gluecol, V)
156153
157154
@grad function gluecol(V::AbstractVector)
158155
gluecol(data.(V)), ΔM -> (collect(eachcol(data(ΔM))),) # doesn't work
159156
end
160-
161-
Zygote.@adjoint function gluecol(V::AbstractVector)
162-
gluecol(V), ΔM -> (collect(eachcol(ΔM)),) # does work!
163-
end
157+
=#
164158

165159
function mapcols2(f, A)
166160
cols = [A[:,c] for c=1:size(A,2)]
@@ -170,34 +164,18 @@ end
170164

171165
# Apply that straight to reduce(hcat,...)
172166

173-
Zygote.@adjoint function Base.reduce(::typeof(hcat), V::AbstractVector{<:AbstractVector})
174-
reduce(hcat, V), dV -> (nothing, collect(eachcol(dV)),)
175-
end
176-
177167
function mapcols4(f, A)
178168
cols = [view(A,:,c) for c=1:size(A,2)]
179169
res = map(f, cols)
180170
reduce(hcat, res)
181171
end
182172

183-
# Zygote doesn't understand views, but easy to fix:
184-
# https://github.com/FluxML/Zygote.jl/issues/52
185-
# now https://github.com/FluxML/Zygote.jl/pull/219
186-
187-
Zygote.@adjoint function view(x::AbstractArray, inds...; kwargs...)
188-
view(x, inds...; kwargs...), dy -> begin
189-
dx = _zero(x)
190-
copyto!(view(dx, inds...; kwargs...), dy)
191-
(dx, map(_->nothing, inds)...)
192-
end
193-
end
194-
195173
# Surprisingly dy for eachcol seems to know the answer?
196174
# typeof(dy) = NamedTuple{(:f, :iter),Tuple{NamedTuple{(:A,),Tuple{Array{Float64,2}}},Array{Nothing,1}}}
197175
# dy = (f = (A = [47.9325 51.3781
198176
# Which means this works... but uses as much memory as gradient of array of views:
199177

200-
#=Zygote.@adjoint function eachcol(x::AbstractMatrix)
178+
#=@adjoint function eachcol(x::AbstractMatrix)
201179
eachcol(x), dy -> (dy.f.A,) #= begin
202180
@show typeof(dy) dy
203181
dx = zero(x) .+ 0.0 # zeros(eltype(dy), size(x))
@@ -216,14 +194,6 @@ end
216194

217195
# Materialise the column iterator of `x` as a vector of column slices.
function collecteachcol(x)
    columns = eachcol(x)
    return collect(columns)
end
218196

219-
@adjoint function collecteachcol(x)
220-
collecteachcol(x), dy -> begin
221-
dx = _zero(x)
222-
foreach(copyto!, collecteachcol(dx), dy)
223-
(dx,)
224-
end
225-
end
226-
227197
function mapcols6(f, A)
228198
cols = collecteachcol(A)
229199
res = map(f, cols)
@@ -240,39 +210,4 @@ end
240210
# @adjoint Base.collect(x) = collect(x), Δ -> (Δ,)
241211

242212

243-
#========== Gradients for TensorCast's functions ==========#
244-
245-
using TensorCast
246-
247-
@adjoint function TensorCast.sliceview(A::AbstractArray, code::Tuple)
248-
TensorCast.sliceview(A, code), Δ -> begin
249-
dA = _zero(A)
250-
foreach(copyto!, TensorCast.sliceview(dA, code), Δ)
251-
(dA, nothing)
252-
end
253-
end
254-
255-
@adjoint function TensorCast.red_glue(A::AbstractArray, code::Tuple)
256-
TensorCast.red_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
257-
end
258-
259-
@adjoint function TensorCast.copy_glue(A::AbstractArray, code::Tuple)
260-
TensorCast.copy_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
261-
end
262-
263-
"""
264-
slicemap(f, A; dims) ≈ mapslices(f, A; dims)
265-
266-
Like `mapcols()`, but for any slice. The function `f` must preserve shape,
267-
e.g. `dims=(2,4)` then `f` must map matrices to matrices.
268-
269-
The gradient is for Zygote only.
270-
"""
271-
function slicemap(f::Function, A::AbstractArray{T,N}, args...; dims) where {T,N}
272-
code = ntuple(d -> d in dims ? (:) : (*), N)
273-
B = TensorCast.sliceview(A, code)
274-
C = [ f(slice, args...) for slice in B ]
275-
TensorCast.glue(C, code)
276-
end
277-
278213
end # module

src/zygote.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
2+
using .Zygote
3+
using .Zygote: @adjoint, _zero, forward
4+
5+
#===== mapcols, maprows, MapCols =====#
6+
7+
@adjoint mapcols(f::Function, M::AbstractMatrix, args...) =
8+
∇mapcols([ forward(x -> surevec(f(x, args...)), col) for col in eachcol(M) ], args)
9+
10+
@adjoint maprows(f::Function, M::AbstractMatrix, args...) =
11+
∇maprows([ forward(x -> surevec(f(x, args...)), row) for row in eachrow(M) ], args)
12+
13+
@adjoint _MapCols(f::Function, M::Matrix, dval, args...) = ∇MapCols(f, M, dval, args...)
14+
15+
#===== TensorCast =====#
16+
17+
@adjoint function TensorCast.sliceview(A::AbstractArray, code::Tuple)
18+
TensorCast.sliceview(A, code), Δ -> begin
19+
dA = _zero(A)
20+
foreach(copyto!, TensorCast.sliceview(dA, code), Δ)
21+
(dA, nothing)
22+
end
23+
end
24+
25+
@adjoint function TensorCast.red_glue(A::AbstractArray, code::Tuple)
26+
TensorCast.red_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
27+
end
28+
29+
@adjoint function TensorCast.copy_glue(A::AbstractArray, code::Tuple)
30+
TensorCast.copy_glue(A, code), Δ -> (TensorCast.sliceview(Δ, code), nothing)
31+
end
32+
33+
#===== Misc Base =====#
34+
35+
@adjoint function Base.reduce(::typeof(hcat), V::AbstractVector{<:AbstractVector})
36+
reduce(hcat, V), dV -> (nothing, collect(eachcol(dV)),)
37+
end
38+
39+
# Zygote doesn't understand views, but easy to fix:
40+
# https://github.com/FluxML/Zygote.jl/issues/52
41+
# now https://github.com/FluxML/Zygote.jl/pull/219
42+
43+
@adjoint function view(x::AbstractArray, inds...; kwargs...)
44+
view(x, inds...; kwargs...), dy -> begin
45+
dx = _zero(x)
46+
copyto!(view(dx, inds...; kwargs...), dy)
47+
(dx, map(_->nothing, inds)...)
48+
end
49+
end
50+
51+
#===== Misc experiments =====#
52+
53+
@adjoint function gluecol(V::AbstractVector)
54+
gluecol(V), ΔM -> (collect(eachcol(ΔM)),) # does work!
55+
end
56+
57+
@adjoint function collecteachcol(x)
58+
collecteachcol(x), dy -> begin
59+
dx = _zero(x)
60+
foreach(copyto!, collecteachcol(dx), dy)
61+
(dx,)
62+
end
63+
end

0 commit comments

Comments
 (0)