Skip to content

Commit d3608c4

Browse files
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl into dg/initprob
2 parents cdaa2c7 + 7dd1cc7 commit d3608c4

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

test/mtk.jl

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ eqs = [D(D(x)) ~ σ * (y - x),
5757
u0_incorrect = [D(x) => 2.0,
5858
x => 1.0,
5959
y => 0.0,
60-
z => 0.0,
61-
w2 => 0.0,]
60+
z => 0.0]
6261

6362
p ==> 28.0,
6463
ρ => 10.0,
@@ -69,7 +68,7 @@ tspan = (0.0, 100.0)
6968
# Check that the gradients for the `solve` are the same both for an initialization
7069
# (for the algebraic variables) initialized poorly (therefore needs correction with BrownBasicInit)
7170
# and with the initialization corrected to satisfy the algebraic equation
72-
prob_incorrectu0 = ODEProblem(sys, u0_incorrect, tspan, p, jac = true)
71+
prob_incorrectu0 = ODEProblem(sys, u0_incorrect, tspan, p, jac = true, guesses = [w2 => 0.0])
7372
mtkparams_incorrectu0 = SciMLSensitivity.parameter_values(prob_incorrectu0)
7473

7574
u0_timedep = [D(x) => 2.0,
@@ -80,16 +79,25 @@ u0_timedep = [D(x) => 2.0,
8079
# this ensures that `y => t` is not applied in the adjoint equation
8180
# If the MTK init is called for the reverse, then `y0` in the backwards
8281
# pass will be extremely far off and cause an incorrect gradient
83-
prob_timedepu0 = ODEProblem(sys, u0_timedep, tspan, p, jac = true)
82+
prob_timedepu0 = ODEProblem(sys, u0_timedep, tspan, p, jac = true, guesses = [w2 => 0.0])
8483
mtkparams_timedepu0 = SciMLSensitivity.parameter_values(prob_incorrectu0)
8584

8685
u0_correct = [D(x) => 2.0,
8786
x => 1.0,
8887
y => 0.0,
8988
z => 0.0,
9089
w2 => -1.0,]
91-
prob_correctu0 = remake(prob_incorrectu0, u0 = u0_correct)
90+
prob_correctu0 = ODEProblem(sys, u0_correct, tspan, p, jac = true, guesses = [w2 => -1.0])
9291
mtkparams_correctu0 = SciMLSensitivity.parameter_values(prob_correctu0)
92+
prob_correctu0.u0[5] = -1.0
93+
94+
u0_overdetermined = [D(x) => 2.0,
95+
x => 1.0,
96+
y => 0.0,
97+
z => 0.0,
98+
w2 => -1.0,]
99+
prob_overdetermined = ODEProblem(sys, u0_overdetermined, tspan, p, jac = true)
100+
mtkparams_overdetermined = SciMLSensitivity.parameter_values(prob_overdetermined)
93101

94102
sensealg = GaussAdjoint(; autojacvec = SciMLSensitivity.ZygoteVJP())
95103

@@ -108,17 +116,24 @@ setups = [
108116
(prob_correctu0, mtkparams_correctu0, OrdinaryDiffEqCore.DefaultInit()),
109117

110118
(prob_correctu0, mtkparams_correctu0, NoInit()),
111-
(prob_correctu0, mtkparams_correctu0, nothing),
119+
(prob_correctu0, mtkparams_correctu0, nothing),
120+
121+
(prob_overdetermined, mtkparams_overdetermined, BrownFullBasicInit()),
122+
(prob_overdetermined, mtkparams_overdetermined, OrdinaryDiffEq.OrdinaryDiffEqCore.DefaultInit()),
123+
124+
(prob_overdetermined, mtkparams_overdetermined, NoInit()),
125+
(prob_overdetermined, mtkparams_overdetermined, nothing),
112126
]
113127

114128
grads = map(setups) do setup
115129
prob, ps, init = setup
116130
@show init
117-
Zygote.gradient(ps) do p
131+
u0 = prob.u0
132+
Zygote.gradient(u0, ps) do u0,p
118133
if init === nothing
119-
new_sol = solve(prob, Rodas5P(); p = p, sensealg, abstol = 1e-6, reltol = 1e-3)
134+
new_sol = solve(prob, Rodas5P(); u0 = u0, p = ps, sensealg, abstol = 1e-6, reltol = 1e-3)
120135
else
121-
new_sol = solve(prob, Rodas5P(); p = p, initializealg = init, sensealg, abstol = 1e-6, reltol = 1e-3)
136+
new_sol = solve(prob, Rodas5P(); u0 = u0, p = ps, initializealg = init, sensealg, abstol = 1e-6, reltol = 1e-3)
122137
end
123138
gt = Zygote.ChainRules.ChainRulesCore.ignore_derivatives() do
124139
@test new_sol.retcode == SciMLBase.ReturnCode.Success
@@ -131,5 +146,7 @@ grads = map(setups) do setup
131146
end
132147
end
133148

134-
grads = getproperty.(grads, (:tunable,))
135-
@test all(x grads[1] for x in grads)
149+
u0grads = getindex.(grads,1)
150+
pgrads = getproperty.(getindex.(grads, 2), (:tunable,))
151+
@test all(x u0grads[1] for x in grads)
152+
@test all(x pgrads[1] for x in grads)

0 commit comments

Comments
 (0)