@@ -33,15 +33,15 @@ sol = solve(prob, Tsit5())
33
33
gs, = gradient (sol) do sol
34
34
sum (sol[sys. w])
35
35
end
36
- du_ = [0 .0 , 1.0 , 1.0 , 1 .0 ]
36
+ du_ = [1 .0 , 1.0 , 1.0 , 0 .0 ]
37
37
du = [du_ for _ in sol. u]
38
38
@test du == gs. u
39
39
40
40
# Observable in a vector
41
41
gs, = gradient (sol) do sol
42
42
sum (sum .(sol[[sys. w, sys. x]]))
43
43
end
44
- du_ = [0 .0 , 1.0 , 1 .0 , 2 .0 ]
44
+ du_ = [1 .0 , 1.0 , 2 .0 , 0 .0 ]
45
45
du = [du_ for _ in sol. u]
46
46
@test du == gs. u
47
47
end
@@ -118,14 +118,19 @@ end
118
118
end
119
119
120
120
@testset " Adjoints with DAE" begin
121
- gs_mtkp, gs_p_new = gradient (prob. p, prob. p. tunable) do p, new_tunables
121
+ model = create_model ()
122
+ sys = mtkcompile (model)
123
+ prob = ODEProblem (sys, [], (0.0 , 1.0 ))
124
+ tunables, _, _ = SS. canonicalize (SS. Tunable (), prob. p)
125
+
126
+ gs_mtkp, gs_p_new = gradient (prob. p, tunables) do p, new_tunables
122
127
new_p = SS. replace (SS. Tunable (), p, new_tunables)
123
128
new_prob = remake (prob, p = new_p)
124
129
sol = solve (new_prob, Rodas4 ())
125
- mean (abs .(sol[sys. ampermeter. i] .- gt))
126
130
sum (sol[sys. ampermeter. i])
127
131
end
128
132
129
133
@test isnothing (gs_mtkp)
130
- @test length (gs_p_new) == length (p_new)
134
+ @test ! isnothing (gs_p_new)
135
+ @test length (gs_p_new) == length (tunables)
131
136
end
0 commit comments