@@ -36,39 +36,44 @@ sol = solve(prob, Tsit5())
3636tunables, repack, _ = SS. canonicalize (SS. Tunable (), parameter_values (prob))
3737
3838@testset " Adjoint through Parameter Initialization" begin
39- @testset " Forward Mode" begin
40- gs_fwd, = Zygote. gradient (tunables) do tunables
41- new_prob = remake (prob; p = repack (tunables))
42- iprob = new_prob. f. initialization_data. initializeprob
43- isol = solve (iprob)
44- isol[w]
39+ fn = function (tunables)
40+ new_prob = remake (prob; p = repack (tunables))
41+ initdata = new_prob. f. initialization_data
42+ iprob = initdata. initializeprob
43+ iprob = if initdata. is_update_oop === Val (true )
44+ initdata. update_initializeprob! (iprob, new_prob)
45+ else
46+ initdata. update_initializeprob! (iprob, new_prob)
47+ iprob
4548 end
49+ isol = solve (iprob)
50+ isol[w]
51+ end
52+ @testset " Forward Mode" begin
53+ gs_fwd, = Zygote. gradient (fn, tunables)
4654 @test any (! iszero, gs_fwd)
4755 end
4856
4957 @testset " Reverse Mode" begin
5058 sensealg = SciMLSensitivity. SteadyStateAdjoint (autojacvec = SciMLSensitivity. ReverseDiffVJP ())
51- gs_reverse, = Zygote. gradient (tunables) do tunables
52- new_prob = remake (prob; p = repack (tunables))
53- iprob = new_prob. f. initialization_data. initializeprob
54- isol = solve (iprob; sensealg)
55- isol[w]
56- end
59+ gs_reverse, = Zygote. gradient (fn, tunables)
5760 @test any (! iszero, gs_reverse)
5861
5962 sensealg = SciMLSensitivity. SteadyStateAdjoint (autojacvec = SciMLSensitivity. ZygoteVJP ())
60- gs_zyg, = Zygote. gradient (tunables) do tunables
61- new_prob = remake (prob; p = repack (tunables))
62- iprob = new_prob. f. initialization_data. initializeprob
63- isol = solve (iprob; sensealg)
64- isol[w]
65- end
63+ gs_zyg, = Zygote. gradient (fn, tunables)
6664 @test any (! iszero, gs_zyg)
6765
6866 sensealg = SciMLSensitivity. SteadyStateAdjoint (autojacvec = SciMLSensitivity. ZygoteVJP ())
6967 gs_obs, = Zygote. gradient (tunables) do tunables
7068 new_prob = remake (prob; p = repack (tunables))
71- iprob = new_prob. f. initialization_data. initializeprob
69+ initdata = new_prob. f. initialization_data
70+ iprob = initdata. initializeprob
71+ iprob = if initdata. is_update_oop === Val (true )
72+ initdata. update_initializeprob! (iprob, new_prob)
73+ else
74+ initdata. update_initializeprob! (iprob, new_prob)
75+ iprob
76+ end
7277 isol = solve (iprob; sensealg)
7378 obsfn = Zygote. ignore () do
7479 SII. observed (isol. prob. f. sys, w). f_oop
0 commit comments