@@ -33,15 +33,15 @@ sol = solve(prob, Tsit5())
3333 gs, = gradient (sol) do sol
3434 sum (sol[sys. w])
3535 end
36- du_ = [0 .0 , 1.0 , 1.0 , 1 .0 ]
36+ du_ = [1 .0 , 1.0 , 1.0 , 0 .0 ]
3737 du = [du_ for _ in sol. u]
3838 @test du == gs. u
3939
4040 # Observable in a vector
4141 gs, = gradient (sol) do sol
4242 sum (sum .(sol[[sys. w, sys. x]]))
4343 end
44- du_ = [0 .0 , 1.0 , 1 .0 , 2 .0 ]
44+ du_ = [1 .0 , 1.0 , 2 .0 , 0 .0 ]
4545 du = [du_ for _ in sol. u]
4646 @test du == gs. u
4747end
@@ -118,14 +118,19 @@ end
118118end
119119
120120@testset " Adjoints with DAE" begin
121- gs_mtkp, gs_p_new = gradient (prob. p, prob. p. tunable) do p, new_tunables
121+ model = create_model ()
122+ sys = mtkcompile (model)
123+ prob = ODEProblem (sys, [], (0.0 , 1.0 ))
124+ tunables, _, _ = SS. canonicalize (SS. Tunable (), prob. p)
125+
126+ gs_mtkp, gs_p_new = gradient (prob. p, tunables) do p, new_tunables
122127 new_p = SS. replace (SS. Tunable (), p, new_tunables)
123128 new_prob = remake (prob, p = new_p)
124129 sol = solve (new_prob, Rodas4 ())
125- mean (abs .(sol[sys. ampermeter. i] .- gt))
126130 sum (sol[sys. ampermeter. i])
127131 end
128132
129133 @test isnothing (gs_mtkp)
130- @test length (gs_p_new) == length (p_new)
134+ @test ! isnothing (gs_p_new)
135+ @test length (gs_p_new) == length (tunables)
131136end
0 commit comments