Skip to content

Commit f6aa73f

Browse files
DhairyaLGandhiAayushSabharwal
authored andcommitted
test: update for MTK v10
1 parent dc7e5bf commit f6aa73f

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

test/downstream/observables_autodiff.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ sol = solve(prob, Tsit5())
3333
gs, = gradient(sol) do sol
3434
sum(sol[sys.w])
3535
end
36-
du_ = [0.0, 1.0, 1.0, 1.0]
36+
du_ = [1.0, 1.0, 1.0, 0.0]
3737
du = [du_ for _ in sol.u]
3838
@test du == gs.u
3939

4040
# Observable in a vector
4141
gs, = gradient(sol) do sol
4242
sum(sum.(sol[[sys.w, sys.x]]))
4343
end
44-
du_ = [0.0, 1.0, 1.0, 2.0]
44+
du_ = [1.0, 1.0, 2.0, 0.0]
4545
du = [du_ for _ in sol.u]
4646
@test du == gs.u
4747
end
@@ -118,14 +118,19 @@ end
118118
end
119119

120120
@testset "Adjoints with DAE" begin
121-
gs_mtkp, gs_p_new = gradient(prob.p, prob.p.tunable) do p, new_tunables
121+
model = create_model()
122+
sys = mtkcompile(model)
123+
prob = ODEProblem(sys, [], (0.0, 1.0))
124+
tunables, _, _ = SS.canonicalize(SS.Tunable(), prob.p)
125+
126+
gs_mtkp, gs_p_new = gradient(prob.p, tunables) do p, new_tunables
122127
new_p = SS.replace(SS.Tunable(), p, new_tunables)
123128
new_prob = remake(prob, p = new_p)
124129
sol = solve(new_prob, Rodas4())
125-
mean(abs.(sol[sys.ampermeter.i] .- gt))
126130
sum(sol[sys.ampermeter.i])
127131
end
128132

129133
@test isnothing(gs_mtkp)
130-
@test length(gs_p_new) == length(p_new)
134+
@test !isnothing(gs_p_new)
135+
@test length(gs_p_new) == length(tunables)
131136
end

0 commit comments

Comments
 (0)