@@ -57,8 +57,7 @@ eqs = [D(D(x)) ~ σ * (y - x),
5757u0_incorrect = [D (x) => 2.0 ,
5858 x => 1.0 ,
5959 y => 0.0 ,
60- z => 0.0 ,
61- w2 => 0.0 ,]
60+ z => 0.0 ]
6261
6362p = [σ => 28.0 ,
6463 ρ => 10.0 ,
@@ -69,7 +68,7 @@ tspan = (0.0, 100.0)
6968# Check that the gradients for the `solve` are the same both for an initialization
7069# (for the algebraic variables) initialized poorly (therefore needs correction with BrownBasicInit)
7170# and with the initialization corrected to satisfy the algebraic equation
72- prob_incorrectu0 = ODEProblem (sys, u0_incorrect, tspan, p, jac = true )
71+ prob_incorrectu0 = ODEProblem (sys, u0_incorrect, tspan, p, jac = true , guesses = [w2 => 0.0 ] )
7372mtkparams_incorrectu0 = SciMLSensitivity. parameter_values (prob_incorrectu0)
7473
7574u0_timedep = [D (x) => 2.0 ,
@@ -80,16 +79,25 @@ u0_timedep = [D(x) => 2.0,
8079# this ensures that `y => t` is not applied in the adjoint equation
8180# If the MTK init is called for the reverse, then `y0` in the backwards
8281# pass will be extremely far off and cause an incorrect gradient
83- prob_timedepu0 = ODEProblem (sys, u0_timedep, tspan, p, jac = true )
82+ prob_timedepu0 = ODEProblem (sys, u0_timedep, tspan, p, jac = true , guesses = [w2 => 0.0 ] )
8483mtkparams_timedepu0 = SciMLSensitivity. parameter_values (prob_incorrectu0)
8584
8685u0_correct = [D (x) => 2.0 ,
8786 x => 1.0 ,
8887 y => 0.0 ,
8988 z => 0.0 ,
9089 w2 => - 1.0 ,]
91- prob_correctu0 = remake (prob_incorrectu0, u0 = u0_correct )
90+ prob_correctu0 = ODEProblem (sys, u0_correct, tspan, p, jac = true , guesses = [w2 => - 1.0 ] )
9291mtkparams_correctu0 = SciMLSensitivity. parameter_values (prob_correctu0)
92+ prob_correctu0. u0[5 ] = - 1.0
93+
94+ u0_overdetermined = [D (x) => 2.0 ,
95+ x => 1.0 ,
96+ y => 0.0 ,
97+ z => 0.0 ,
98+ w2 => - 1.0 ,]
99+ prob_overdetermined = ODEProblem (sys, u0_overdetermined, tspan, p, jac = true )
100+ mtkparams_overdetermined = SciMLSensitivity. parameter_values (prob_overdetermined)
93101
94102sensealg = GaussAdjoint (; autojacvec = SciMLSensitivity. ZygoteVJP ())
95103
@@ -108,17 +116,23 @@ setups = [
108116 (prob_correctu0, mtkparams_correctu0, OrdinaryDiffEqCore. DefaultInit ()),
109117
110118 (prob_correctu0, mtkparams_correctu0, NoInit ()),
111- (prob_correctu0, mtkparams_correctu0, nothing ),
119+ (prob_correctu0, mtkparams_correctu0, nothing ),
120+
121+ (prob_overdetermined, mtkparams_overdetermined, BrownFullBasicInit ()),
122+ (prob_overdetermined, mtkparams_overdetermined, OrdinaryDiffEq. OrdinaryDiffEqCore. DefaultInit ()),
123+
124+ (prob_overdetermined, mtkparams_overdetermined, NoInit ()),
125+ (prob_overdetermined, mtkparams_overdetermined, nothing ),
112126]
113127
114128grads = map (setups) do setup
115129 prob, ps, init = setup
116130 @show init
117131 Zygote. gradient (ps) do p
118132 if init === nothing
119- new_sol = solve (prob, Rodas5P (); p = p , sensealg, abstol = 1e-6 , reltol = 1e-3 )
133+ new_sol = solve (prob, Rodas5P (); p = ps , sensealg, abstol = 1e-6 , reltol = 1e-3 )
120134 else
121- new_sol = solve (prob, Rodas5P (); p = p , initializealg = init, sensealg, abstol = 1e-6 , reltol = 1e-3 )
135+ new_sol = solve (prob, Rodas5P (); p = ps , initializealg = init, sensealg, abstol = 1e-6 , reltol = 1e-3 )
122136 end
123137 gt = Zygote. ChainRules. ChainRulesCore. ignore_derivatives () do
124138 @test new_sol. retcode == SciMLBase. ReturnCode. Success
0 commit comments