@@ -128,11 +128,12 @@ setups = [
128128grads = map (setups) do setup
129129 prob, ps, init = setup
130130 @show init
131- Zygote. gradient (ps) do p
131+ u0 = prob. u0
132+ Zygote. gradient (u0, ps) do u0,p
132133 if init === nothing
133- new_sol = solve (prob, Rodas5P (); p = ps, sensealg, abstol = 1e-6 , reltol = 1e-3 )
134+ new_sol = solve (prob, Rodas5P (); u0 = u0, p = ps, sensealg, abstol = 1e-6 , reltol = 1e-3 )
134135 else
135- new_sol = solve (prob, Rodas5P (); p = ps, initializealg = init, sensealg, abstol = 1e-6 , reltol = 1e-3 )
136+ new_sol = solve (prob, Rodas5P (); u0 = u0, p = ps, initializealg = init, sensealg, abstol = 1e-6 , reltol = 1e-3 )
136137 end
137138 gt = Zygote. ChainRules. ChainRulesCore. ignore_derivatives () do
138139 @test new_sol. retcode == SciMLBase. ReturnCode. Success
@@ -145,5 +146,7 @@ grads = map(setups) do setup
145146 end
146147end
147148
148- grads = getproperty .(grads, (:tunable ,))
149- @test all (x ≈ grads[1 ] for x in grads)
149+ u0grads = getindex .(grads,1 )
150+ pgrads = getproperty .(getindex .(grads, 2 ), (:tunable ,))
151+ @test all (x ≈ u0grads[1 ] for x in grads)
152+ @test all (x ≈ pgrads[1 ] for x in grads)
0 commit comments