100100 end
101101end
102102
103+ @adjoint function Base. copy (u:: VectorOfArray )
104+ copy (u),
105+ y -> (copy (y),)
106+ end
107+
103108@adjoint function DiffEqArray (u, t)
104109 DiffEqArray (u, t),
105110 y -> begin
@@ -117,19 +122,122 @@ end
117122 A. x, literal_ArrayPartition_x_adjoint
118123end
119124
120- @adjoint function Array (VA:: AbstractVectorOfArray )
125+ @adjoint function Base . Array (VA:: AbstractVectorOfArray )
121126 Array (VA),
122127 y -> (Array (y),)
123128end
124129
130+ @adjoint function Base. view (A:: AbstractVectorOfArray , I... )
131+ view (A, I... ),
132+ y -> (view (y, I... ), ntuple (_ -> nothing , length (I))... )
133+ end
125134
126135ChainRulesCore. ProjectTo (a:: AbstractVectorOfArray ) = ChainRulesCore. ProjectTo {VectorOfArray} ((sz = size (a)))
127136
128- function (p:: ChainRulesCore.ProjectTo{VectorOfArray} )(x)
137+ function (p:: ChainRulesCore.ProjectTo{VectorOfArray} )(x:: Union{AbstractArray,AbstractVectorOfArray} )
129138 arr = reshape (x, p. sz)
130139 return VectorOfArray ([arr[:, i] for i in 1 : p. sz[end ]])
131140end
132141
142+ @adjoint function Broadcast. broadcasted (:: typeof (+ ), x:: AbstractVectorOfArray , y:: Union{Zygote.Numeric, AbstractVectorOfArray} )
143+ broadcast (+ , x, y), ȳ -> (nothing , map (x -> Zygote. unbroadcast (x, ȳ), (x, y))... )
144+ end
145+ @adjoint function Broadcast. broadcasted (:: typeof (+ ), x:: Zygote.Numeric , y:: AbstractVectorOfArray )
146+ broadcast (+ , x, y), ȳ -> (nothing , map (x -> Zygote. unbroadcast (x, ȳ), (x, y))... )
147+ end
148+
149+ _minus (Δ) = .- Δ
150+ _minus (:: Nothing ) = nothing
151+
152+ @adjoint function Broadcast. broadcasted (:: typeof (- ), x:: AbstractVectorOfArray , y:: Union{AbstractVectorOfArray, Zygote.Numeric} )
153+ x .- y, Δ -> (nothing , Zygote. unbroadcast (x, Δ), _minus (Zygote. unbroadcast (y, Δ)))
154+ end
155+ @adjoint function Broadcast. broadcasted (:: typeof (* ), x:: AbstractVectorOfArray , y:: Union{AbstractVectorOfArray, Zygote.Numeric} )
156+ (
157+ x.* y,
158+ Δ -> (nothing , Zygote. unbroadcast (x, Δ .* conj .(y)), Zygote. unbroadcast (y, Δ .* conj .(x)))
159+ )
160+ end
161+ @adjoint function Broadcast. broadcasted (:: typeof (/ ), x:: AbstractVectorOfArray , y:: Union{AbstractVectorOfArray, Zygote.Numeric} )
162+ res = x ./ y
163+ res, Δ -> (nothing , Zygote. unbroadcast (x, Δ ./ conj .(y)), Zygote. unbroadcast (y, .- Δ .* conj .(res ./ y)))
164+ end
165+ @adjoint function Broadcast. broadcasted (:: typeof (- ), x:: Zygote.Numeric , y:: AbstractVectorOfArray )
166+ x .- y, Δ -> (nothing , Zygote. unbroadcast (x, Δ), _minus (Zygote. unbroadcast (y, Δ)))
167+ end
168+ @adjoint function Broadcast. broadcasted (:: typeof (* ), x:: Zygote.Numeric , y:: AbstractVectorOfArray )
169+ (
170+ x.* y,
171+ Δ -> (nothing , Zygote. unbroadcast (x, Δ .* conj .(y)), Zygote. unbroadcast (y, Δ .* conj .(x)))
172+ )
173+ end
174+ @adjoint function Broadcast. broadcasted (:: typeof (/ ), x:: Zygote.Numeric , y:: AbstractVectorOfArray )
175+ res = x ./ y
176+ res, Δ -> (nothing , Zygote. unbroadcast (x, Δ ./ conj .(y)), Zygote. unbroadcast (y, .- Δ .* conj .(res ./ y)))
177+ end
178+ @adjoint function Broadcast. broadcasted (:: typeof (- ), x:: AbstractVectorOfArray )
179+ .- x, Δ -> (nothing , _minus (Δ))
180+ end
181+
182+ @adjoint function Broadcast. broadcasted (:: typeof (Base. literal_pow), :: typeof (^ ), x:: AbstractVectorOfArray , exp:: Val{p} ) where p
183+ y = Base. literal_pow .(^ , x, exp)
184+ y, ȳ -> (nothing , nothing , ȳ .* p .* conj .(x .^ (p - 1 )), nothing )
185+ end
186+
187+ @adjoint Broadcast. broadcasted (:: typeof (identity), x:: AbstractVectorOfArray ) = x, Δ -> (nothing , Δ)
188+
189+ @adjoint function Broadcast. broadcasted (:: typeof (tanh), x:: AbstractVectorOfArray )
190+ y = tanh .(x)
191+ y, ȳ -> (nothing , ȳ .* conj .(1 .- y.^ 2 ))
192+ end
193+
194+ @adjoint Broadcast. broadcasted (:: typeof (conj), x:: AbstractVectorOfArray ) =
195+ conj .(x), z̄ -> (nothing , conj .(z̄))
196+
197+ @adjoint Broadcast. broadcasted (:: typeof (real), x:: AbstractVectorOfArray ) =
198+ real .(x), z̄ -> (nothing , real .(z̄))
199+
200+ @adjoint Broadcast. broadcasted (:: typeof (imag), x:: AbstractVectorOfArray ) =
201+ imag .(x), z̄ -> (nothing , im .* real .(z̄))
202+
203+ @adjoint Broadcast. broadcasted (:: typeof (abs2), x:: AbstractVectorOfArray ) =
204+ abs2 .(x), z̄ -> (nothing , 2 .* real .(z̄) .* x)
205+
206+ @adjoint function Broadcast. broadcasted (:: typeof (+ ), a:: AbstractVectorOfArray{<:Number} , b:: Bool )
207+ y = b === false ? a : a .+ b
208+ y, Δ -> (nothing , Δ, nothing )
209+ end
210+ @adjoint function Broadcast. broadcasted (:: typeof (+ ), b:: Bool , a:: AbstractVectorOfArray{<:Number} )
211+ y = b === false ? a : b .+ a
212+ y, Δ -> (nothing , nothing , Δ)
213+ end
214+
215+ @adjoint function Broadcast. broadcasted (:: typeof (- ), a:: AbstractVectorOfArray{<:Number} , b:: Bool )
216+ y = b === false ? a : a .- b
217+ y, Δ -> (nothing , Δ, nothing )
218+ end
219+ @adjoint function Broadcast. broadcasted (:: typeof (- ), b:: Bool , a:: AbstractVectorOfArray{<:Number} )
220+ b .- a, Δ -> (nothing , nothing , .- Δ)
221+ end
222+
223+ @adjoint function Broadcast. broadcasted (:: typeof (* ), a:: AbstractVectorOfArray{<:Number} , b:: Bool )
224+ if b === false
225+ zero (a), Δ -> (nothing , zero (Δ), nothing )
226+ else
227+ a, Δ -> (nothing , Δ, nothing )
228+ end
229+ end
230+ @adjoint function Broadcast. broadcasted (:: typeof (* ), b:: Bool , a:: AbstractVectorOfArray{<:Number} )
231+ if b === false
232+ zero (a), Δ -> (nothing , nothing , zero (Δ))
233+ else
234+ a, Δ -> (nothing , nothing , Δ)
235+ end
236+ end
237+
238+ @adjoint Broadcast. broadcasted (:: Type{T} , x:: AbstractVectorOfArray ) where {T<: Number } =
239+ T .(x), ȳ -> (nothing , Zygote. _project (x, ȳ),)
240+
133241function Zygote. unbroadcast (x:: AbstractVectorOfArray , x̄)
134242 N = ndims (x̄)
135243 if length (x) == length (x̄)
0 commit comments