@@ -42,3 +42,45 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol
4242 end
4343 A. x,literal_ArrayPartition_x_adjoint
4444end
45+
46+ ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i)
47+ function AbstractVectorOfArray_getindex_adjoint (Δ)
48+ Δ′ = [ (i == j ? Δ : zero (x)) for (x,j) in zip (VA. u, 1 : length (VA))]
49+ (Δ′,nothing )
50+ end
51+ VA[i],AbstractVectorOfArray_getindex_adjoint
52+ end
53+
54+ ZygoteRules. @adjoint function getindex (VA:: AbstractVectorOfArray , i, j... )
55+ function AbstractVectorOfArray_getindex_adjoint (Δ)
56+ Δ′ = zero (VA)
57+ Δ′[i,j... ] = Δ
58+ (Δ′, i,map (_ -> nothing , j)... )
59+ end
60+ VA[i,j... ],AbstractVectorOfArray_getindex_adjoint
61+ end
62+
63+ ZygoteRules. @adjoint function ArrayPartition (x:: S , :: Type{Val{copy_x}} = Val{false }) where {S<: Tuple ,copy_x}
64+ function ArrayPartition_adjoint (_y)
65+ y = Array (_y)
66+ starts = vcat (0 ,cumsum (reduce (vcat,length .(x))))
67+ ntuple (i -> reshape (y[starts[i]+ 1 : starts[i+ 1 ]], size (x[i])), length (x)), nothing
68+ end
69+
70+ ArrayPartition (x, Val{copy_x}), ArrayPartition_adjoint
71+ end
72+
73+ ZygoteRules. @adjoint function VectorOfArray (u)
74+ VectorOfArray (u),y -> ([y[ntuple (x-> Colon (),ndims (y)- 1 )... ,i] for i in 1 : size (y)[end ]],)
75+ end
76+
77+ ZygoteRules. @adjoint function DiffEqArray (u,t)
78+ DiffEqArray (u,t),y -> ([y[ntuple (x-> Colon (),ndims (y)- 1 )... ,i] for i in 1 : size (y)[end ]],nothing )
79+ end
80+
81+ ZygoteRules. @adjoint function ZygoteRules. literal_getproperty (A:: ArrayPartition , :: Val{:x} )
82+ function literal_ArrayPartition_x_adjoint (d)
83+ (ArrayPartition ((isnothing (d[i]) ? zero (A. x[i]) : d[i] for i in 1 : length (d)). .. ),)
84+ end
85+ A. x,literal_ArrayPartition_x_adjoint
86+ end
0 commit comments