@@ -70,34 +70,34 @@ tspan = (0.0, 100.0)
7070#  and with the initialization corrected to satisfy the algebraic equation
7171prob_incorrectu0 =  ODEProblem (sys, u0_incorrect, tspan, p, jac =  true , guesses =  [w2 =>  0.0 ])
7272mtkparams_incorrectu0 =  SciMLSensitivity. parameter_values (prob_incorrectu0)
73+ test_sol =  solve (prob_incorrectu0, Rodas5P (), abstol =  1e-6 , reltol =  1e-3 )
7374
7475u0_timedep =  [D (x) =>  2.0 ,
7576    x =>  1.0 ,
7677    y =>  t,
77-     z =>  0.0 ,
78-     w2 =>  0.0 ,]
78+     z =>  0.0 ]
7979#  this ensures that `y => t` is not applied in the adjoint equation
8080#  If the MTK init is called for the reverse, then `y0` in the backwards
8181#  pass will be extremely far off and cause an incorrect gradient
8282prob_timedepu0 =  ODEProblem (sys, u0_timedep, tspan, p, jac =  true , guesses =  [w2 =>  0.0 ])
8383mtkparams_timedepu0 =  SciMLSensitivity. parameter_values (prob_incorrectu0)
84+ test_sol =  solve (prob_timedepu0, Rodas5P (), abstol =  1e-6 , reltol =  1e-3 )
8485
8586u0_correct =  [D (x) =>  2.0 ,
8687    x =>  1.0 ,
8788    y =>  0.0 ,
88-     z =>  0.0 ,
89-     w2 =>  - 1.0 ,]
89+     z =>  0.0 ,]
9090prob_correctu0 =  ODEProblem (sys, u0_correct, tspan, p, jac =  true , guesses =  [w2 =>  - 1.0 ])
9191mtkparams_correctu0 =  SciMLSensitivity. parameter_values (prob_correctu0)
92- prob_correctu0. u0[5 ] =  - 1.0 
93- 
92+ test_sol =  solve (prob_correctu0, Rodas5P (), abstol =  1e-6 , reltol =  1e-3 )
9493u0_overdetermined =  [D (x) =>  2.0 ,
9594    x =>  1.0 ,
9695    y =>  0.0 ,
9796    z =>  0.0 ,
9897    w2 =>  - 1.0 ,]
9998prob_overdetermined =  ODEProblem (sys, u0_overdetermined, tspan, p, jac =  true )
10099mtkparams_overdetermined =  SciMLSensitivity. parameter_values (prob_overdetermined)
100+ test_sol =  solve (prob_overdetermined, Rodas5P (), abstol =  1e-6 , reltol =  1e-3 )
101101
102102sensealg =  GaussAdjoint (; autojacvec =  SciMLSensitivity. ZygoteVJP ())
103103
@@ -115,25 +115,26 @@ setups = [
115115          (prob_correctu0, mtkparams_correctu0, BrownFullBasicInit ()),
116116          (prob_correctu0, mtkparams_correctu0, OrdinaryDiffEqCore. DefaultInit ()),
117117
118-            (prob_correctu0, mtkparams_correctu0, NoInit ()),  
118+         (prob_correctu0, mtkparams_correctu0, NoInit ()),
119119          (prob_correctu0, mtkparams_correctu0, nothing ),
120120
121121          (prob_overdetermined, mtkparams_overdetermined, BrownFullBasicInit ()),
122122          (prob_overdetermined, mtkparams_overdetermined, OrdinaryDiffEq. OrdinaryDiffEqCore. DefaultInit ()),
123123
124124          (prob_overdetermined, mtkparams_overdetermined, NoInit ()),
125125          (prob_overdetermined, mtkparams_overdetermined, nothing ),
126- ]
126+ ]; 
127127
128128grads =  map (setups) do  setup
129129    prob, ps, init =  setup
130130    @show  init
131131    u0 =  prob. u0
132132    Zygote. gradient (u0, ps) do  u0,p
133+         new_prob =  remake (prob, u0 =  u0, p =  p)
133134        if  init ===  nothing 
134-             new_sol =  solve (prob , Rodas5P (); u0  =  u0, p  =  ps,  sensealg, abstol =  1e-6 , reltol =  1e-3 )
135+             new_sol =  solve (new_prob , Rodas5P (); sensealg, abstol =  1e-6 , reltol =  1e-3 )
135136        else 
136-             new_sol =  solve (prob , Rodas5P (); u0  =  u0, p  =  ps,  initializealg =  init, sensealg, abstol =  1e-6 , reltol =  1e-3 )
137+             new_sol =  solve (new_prob , Rodas5P (); initializealg =  init, sensealg, abstol =  1e-6 , reltol =  1e-3 )
137138        end 
138139        gt =  Zygote. ChainRules. ChainRulesCore. ignore_derivatives () do 
139140            @test  new_sol. retcode ==  SciMLBase. ReturnCode. Success
148149
149150u0grads =  getindex .(grads,1 )
150151pgrads =  getproperty .(getindex .(grads, 2 ), (:tunable ,))
151- @test  all (x ≈  u0grads[1 ] for  x in  grads )
152- @test  all (x ≈  pgrads[1 ] for  x in  grads )
152+ @test  all (x ≈  u0grads[1 ] for  x in  u0grads )
153+ @test  all (x ≈  pgrads[1 ] for  x in  pgrads )
0 commit comments