File tree Expand file tree Collapse file tree 2 files changed +30
-0
lines changed Expand file tree Collapse file tree 2 files changed +30
-0
lines changed Original file line number Diff line number Diff line change 9999 end
100100end
101101
102+ Zygote. @adjoint function Zygote. literal_getproperty (A:: RecursiveArrayTools.AbstractVectorOfArray , :: Val{:u} )
103+ function literal_AbstractVofA_u_adjoint (d)
104+ dA = vofa_u_adjoint (d, A)
105+ (dA, nothing )
106+ end
107+ A. u, literal_AbstractVofA_u_adjoint
108+ end
109+
110+ function vofa_u_adjoint (d, A:: RecursiveArrayTools.AbstractVectorOfArray )
111+ m = map (enumerate (d)) do (idx, d_i)
112+ isnothing (d_i) && return zero (A. u[idx])
113+ d_i
114+ end
115+ VectorOfArray (m)
116+ end
117+
118+ function vofa_u_adjoint (d, A:: RecursiveArrayTools.AbstractDiffEqArray )
119+ m = map (enumerate (d)) do (idx, d_i)
120+ isnothing (d_i) && return zero (A. u[idx])
121+ d_i
122+ end
123+ DiffEqArray (m, A. t)
124+ end
125+
102126@adjoint function literal_getproperty (A:: ArrayPartition , :: Val{:x} )
103127 function literal_ArrayPartition_x_adjoint (d)
104128 (ArrayPartition ((isnothing (d[i]) ? zero (A. x[i]) : d[i] for i in 1 : length (d)). .. ),)
Original file line number Diff line number Diff line change @@ -92,3 +92,9 @@ loss(x)
9292 VectorOfArray ([collect ((3 i): (3 i + 3 )) for i in 1 : 5 ])
9393@test Zygote. gradient (loss10, x)[1 ] == ForwardDiff. gradient (loss10, x)
9494@test Zygote. gradient (loss11, x)[1 ] == ForwardDiff. gradient (loss11, x)
95+
96+ voa = RecursiveArrayTools. VectorOfArray (fill (rand (3 ), 3 ))
97+ voa_gs, = Zygote. gradient (voa) do x
98+ sum (sum .(x. u))
99+ end
100+ @test voa_gs isa RecursiveArrayTools. VectorOfArray
You can’t perform that action at this time.
0 commit comments