Skip to content

Commit 1f9b577

Browse files
feat: add and test adjoints for broadcast arithmetic on VoA
1 parent 1ecc966 commit 1f9b577

File tree

2 files changed

+132
-3
lines changed

2 files changed

+132
-3
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ end
100100
end
101101
end
102102

103+
@adjoint function Base.copy(u::VectorOfArray)
104+
copy(u),
105+
y -> (copy(y),)
106+
end
107+
103108
@adjoint function DiffEqArray(u, t)
104109
DiffEqArray(u, t),
105110
y -> begin
@@ -117,19 +122,122 @@ end
117122
A.x, literal_ArrayPartition_x_adjoint
118123
end
119124

120-
@adjoint function Array(VA::AbstractVectorOfArray)
125+
@adjoint function Base.Array(VA::AbstractVectorOfArray)
121126
Array(VA),
122127
y -> (Array(y),)
123128
end
124129

130+
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
131+
view(A, I...),
132+
y -> (view(y, I...), ntuple(_ -> nothing, length(I))...)
133+
end
125134

126135
ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
127136

128-
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x)
137+
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x::Union{AbstractArray,AbstractVectorOfArray})
129138
arr = reshape(x, p.sz)
130139
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
131140
end
132141

142+
@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray, y::Union{Zygote.Numeric, AbstractVectorOfArray})
143+
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
144+
end
145+
@adjoint function Broadcast.broadcasted(::typeof(+), x::Zygote.Numeric, y::AbstractVectorOfArray)
146+
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
147+
end
148+
149+
_minus(Δ) = .-Δ
150+
_minus(::Nothing) = nothing
151+
152+
@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})
153+
x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
154+
end
155+
@adjoint function Broadcast.broadcasted(::typeof(*), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})
156+
(
157+
x.*y,
158+
Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x)))
159+
)
160+
end
161+
@adjoint function Broadcast.broadcasted(::typeof(/), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric})
162+
res = x ./ y
163+
res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)))
164+
end
165+
@adjoint function Broadcast.broadcasted(::typeof(-), x::Zygote.Numeric, y::AbstractVectorOfArray)
166+
x .- y, Δ -> (nothing, Zygote.unbroadcast(x, Δ), _minus(Zygote.unbroadcast(y, Δ)))
167+
end
168+
@adjoint function Broadcast.broadcasted(::typeof(*), x::Zygote.Numeric, y::AbstractVectorOfArray)
169+
(
170+
x.*y,
171+
Δ -> (nothing, Zygote.unbroadcast(x, Δ .* conj.(y)), Zygote.unbroadcast(y, Δ .* conj.(x)))
172+
)
173+
end
174+
@adjoint function Broadcast.broadcasted(::typeof(/), x::Zygote.Numeric, y::AbstractVectorOfArray)
175+
res = x ./ y
176+
res, Δ -> (nothing, Zygote.unbroadcast(x, Δ ./ conj.(y)), Zygote.unbroadcast(y, .-Δ .* conj.(res ./ y)))
177+
end
178+
@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray)
179+
.-x, Δ -> (nothing, _minus(Δ))
180+
end
181+
182+
@adjoint function Broadcast.broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::AbstractVectorOfArray, exp::Val{p}) where p
183+
y = Base.literal_pow.(^, x, exp)
184+
y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
185+
end
186+
187+
@adjoint Broadcast.broadcasted(::typeof(identity), x::AbstractVectorOfArray) = x, Δ -> (nothing, Δ)
188+
189+
@adjoint function Broadcast.broadcasted(::typeof(tanh), x::AbstractVectorOfArray)
190+
y = tanh.(x)
191+
y, ȳ -> (nothing, ȳ .* conj.(1 .- y.^2))
192+
end
193+
194+
@adjoint Broadcast.broadcasted(::typeof(conj), x::AbstractVectorOfArray) =
195+
conj.(x), z̄ -> (nothing, conj.(z̄))
196+
197+
@adjoint Broadcast.broadcasted(::typeof(real), x::AbstractVectorOfArray) =
198+
real.(x), z̄ -> (nothing, real.(z̄))
199+
200+
@adjoint Broadcast.broadcasted(::typeof(imag), x::AbstractVectorOfArray) =
201+
imag.(x), z̄ -> (nothing, im .* real.(z̄))
202+
203+
@adjoint Broadcast.broadcasted(::typeof(abs2), x::AbstractVectorOfArray) =
204+
abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x)
205+
206+
@adjoint function Broadcast.broadcasted(::typeof(+), a::AbstractVectorOfArray{<:Number}, b::Bool)
207+
y = b === false ? a : a .+ b
208+
y, Δ -> (nothing, Δ, nothing)
209+
end
210+
@adjoint function Broadcast.broadcasted(::typeof(+), b::Bool, a::AbstractVectorOfArray{<:Number})
211+
y = b === false ? a : b .+ a
212+
y, Δ -> (nothing, nothing, Δ)
213+
end
214+
215+
@adjoint function Broadcast.broadcasted(::typeof(-), a::AbstractVectorOfArray{<:Number}, b::Bool)
216+
y = b === false ? a : a .- b
217+
y, Δ -> (nothing, Δ, nothing)
218+
end
219+
@adjoint function Broadcast.broadcasted(::typeof(-), b::Bool, a::AbstractVectorOfArray{<:Number})
220+
b .- a, Δ -> (nothing, nothing, .-Δ)
221+
end
222+
223+
@adjoint function Broadcast.broadcasted(::typeof(*), a::AbstractVectorOfArray{<:Number}, b::Bool)
224+
if b === false
225+
zero(a), Δ -> (nothing, zero(Δ), nothing)
226+
else
227+
a, Δ -> (nothing, Δ, nothing)
228+
end
229+
end
230+
@adjoint function Broadcast.broadcasted(::typeof(*), b::Bool, a::AbstractVectorOfArray{<:Number})
231+
if b === false
232+
zero(a), Δ -> (nothing, nothing, zero(Δ))
233+
else
234+
a, Δ -> (nothing, nothing, Δ)
235+
end
236+
end
237+
238+
@adjoint Broadcast.broadcasted(::Type{T}, x::AbstractVectorOfArray) where {T<:Number} =
239+
T.(x), ȳ -> (nothing, Zygote._project(x, ȳ),)
240+
133241
function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
134242
N = ndims(x̄)
135243
if length(x) == length(x̄)

test/adjoints.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,27 @@ end
3939

4040
function loss7(x)
4141
_x = VectorOfArray([x .* i for i in 1:5])
42-
return sum(abs2, x .- 1)
42+
return sum(abs2, _x .- 1)
43+
end
44+
45+
# use a bunch of broadcasts to test all the adjoints
46+
function loss8(x)
47+
_x = VectorOfArray([x .* i for i in 1:5])
48+
res = copy(_x)
49+
res = res .+ _x
50+
res = res .+ 1
51+
res = res .* _x
52+
res = res .* 2.0
53+
res = res .* res
54+
res = res ./ 2.0
55+
res = res ./ _x
56+
res = 3.0 .- res
57+
res = .-res
58+
res = identity.(Base.literal_pow.(^, res, Val(2)))
59+
res = tanh.(res)
60+
res = res .+ im .* res
61+
res = conj.(res) .+ real.(res) .+ imag.(res) .+ abs2.(res)
62+
return sum(abs2, res)
4363
end
4464

4565
x = float.(6:10)
@@ -51,3 +71,4 @@ loss(x)
5171
@test Zygote.gradient(loss5, x)[1] == ForwardDiff.gradient(loss5, x)
5272
@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
5373
@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)
74+
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)

0 commit comments

Comments
 (0)