Skip to content

Commit bb89ff4

Browse files
fix: add support for autodiff through VoA broadcast
1 parent cb9d12d commit bb89ff4

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ else
1111
end
1212

1313
# Define a new species of projection operator for this type:
14-
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
14+
# ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()
1515

1616
function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray},
1717
xs::AbstractVectorOfArray)
@@ -117,4 +117,53 @@ end
117117
A.x, literal_ArrayPartition_x_adjoint
118118
end
119119

120+
@adjoint function Array(VA::AbstractVectorOfArray)
121+
Array(VA),
122+
y -> (Array(y),)
120123
end
124+
125+
126+
ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
127+
128+
function (p::ChainRulesCore.ProjectTo{VectorOfArray})(x)
129+
arr = reshape(x, p.sz)
130+
return VectorOfArray([arr[:, i] for i in 1:p.sz[end]])
131+
end
132+
133+
function Zygote.unbroadcast(x::AbstractVectorOfArray, x̄)
134+
N = ndims(x̄)
135+
if length(x) == length(x̄)
136+
Zygote._project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
137+
else
138+
dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄))
139+
Zygote._project(x, Zygote.accum_sum(x̄; dims = dims))
140+
end
141+
end
142+
143+
@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b) where {F} = _broadcast_generic(__context__, f, a, b)
144+
@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b)
145+
@adjoint Broadcast.broadcasted(::Broadcast.AbstractArrayStyle, f::F, a::AbstractVectorOfArray, b::AbstractVectorOfArray) where {F} = _broadcast_generic(__context__, f, a, b)
146+
147+
@inline function _broadcast_generic(__context__, f::F, args...) where {F}
148+
T = Broadcast.combine_eltypes(f, args)
149+
# Avoid generic broadcasting in two easy cases:
150+
if T == Bool
151+
return (f.(args...), _ -> nothing)
152+
elseif T <: Union{Real, Complex} && isconcretetype(T) && Zygote._dual_purefun(F) && all(Zygote._dual_safearg, args) && !Zygote.isderiving()
153+
return Zygote.broadcast_forward(f, args...)
154+
end
155+
len = Zygote.inclen(args)
156+
y∂b = Zygote._broadcast((x...) -> Zygote._pullback(__context__, f, x...), args...)
157+
y = broadcast(first, y∂b)
158+
function ∇broadcasted(ȳ)
159+
y∂b = y∂b isa AbstractVectorOfArray ? Iterators.flatten(y∂b.u) : y∂b
160+
ȳ = ȳ isa AbstractVectorOfArray ? Iterators.flatten.u) : ȳ
161+
dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
162+
getters = ntuple(i -> Zygote.StaticGetter{i}(), len)
163+
dxs = map(g -> Zygote.collapse_nothings(map(g, dxs_zip)), getters)
164+
(nothing, Zygote.accum_sum(dxs[1]), map(Zygote.unbroadcast, args, Base.tail(dxs))...)
165+
end
166+
return y, ∇broadcasted
167+
end
168+
169+
end # module

0 commit comments

Comments
 (0)