Skip to content

Commit acef9a3

Browse files
test: fix parameter initialization test
1 parent 4579890 commit acef9a3

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

test/parameter_initialization.jl

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,39 +36,44 @@ sol = solve(prob, Tsit5())
3636
tunables, repack, _ = SS.canonicalize(SS.Tunable(), parameter_values(prob))
3737

3838
@testset "Adjoint through Parameter Initialization" begin
39-
@testset "Forward Mode" begin
40-
gs_fwd, = Zygote.gradient(tunables) do tunables
41-
new_prob = remake(prob; p = repack(tunables))
42-
iprob = new_prob.f.initialization_data.initializeprob
43-
isol = solve(iprob)
44-
isol[w]
39+
fn = function (tunables)
40+
new_prob = remake(prob; p = repack(tunables))
41+
initdata = new_prob.f.initialization_data
42+
iprob = initdata.initializeprob
43+
iprob = if initdata.is_update_oop === Val(true)
44+
initdata.update_initializeprob!(iprob, new_prob)
45+
else
46+
initdata.update_initializeprob!(iprob, new_prob)
47+
iprob
4548
end
49+
isol = solve(iprob)
50+
isol[w]
51+
end
52+
@testset "Forward Mode" begin
53+
gs_fwd, = Zygote.gradient(fn, tunables)
4654
@test any(!iszero, gs_fwd)
4755
end
4856

4957
@testset "Reverse Mode" begin
5058
sensealg = SciMLSensitivity.SteadyStateAdjoint(autojacvec = SciMLSensitivity.ReverseDiffVJP())
51-
gs_reverse, = Zygote.gradient(tunables) do tunables
52-
new_prob = remake(prob; p = repack(tunables))
53-
iprob = new_prob.f.initialization_data.initializeprob
54-
isol = solve(iprob; sensealg)
55-
isol[w]
56-
end
59+
gs_reverse, = Zygote.gradient(fn, tunables)
5760
@test any(!iszero, gs_reverse)
5861

5962
sensealg = SciMLSensitivity.SteadyStateAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP())
60-
gs_zyg, = Zygote.gradient(tunables) do tunables
61-
new_prob = remake(prob; p = repack(tunables))
62-
iprob = new_prob.f.initialization_data.initializeprob
63-
isol = solve(iprob; sensealg)
64-
isol[w]
65-
end
63+
gs_zyg, = Zygote.gradient(fn, tunables)
6664
@test any(!iszero, gs_zyg)
6765

6866
sensealg = SciMLSensitivity.SteadyStateAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP())
6967
gs_obs, = Zygote.gradient(tunables) do tunables
7068
new_prob = remake(prob; p = repack(tunables))
71-
iprob = new_prob.f.initialization_data.initializeprob
69+
initdata = new_prob.f.initialization_data
70+
iprob = initdata.initializeprob
71+
iprob = if initdata.is_update_oop === Val(true)
72+
initdata.update_initializeprob!(iprob, new_prob)
73+
else
74+
initdata.update_initializeprob!(iprob, new_prob)
75+
iprob
76+
end
7277
isol = solve(iprob; sensealg)
7378
obsfn = Zygote.ignore() do
7479
SII.observed(isol.prob.f.sys, w).f_oop

0 commit comments

Comments
 (0)