Skip to content

Commit 74d08d9

Browse files
test: add test for adjoints for parameters in initialization
1 parent 66c0730 commit 74d08d9

File tree

1 file changed

+33
-13
lines changed

1 file changed

+33
-13
lines changed

test/downstream/observables_autodiff.jl

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ sol = solve(prob, Tsit5())
4646
@test du == gs
4747
end
4848

49+
# @testset "AD Observable Functions for Initialization" begin
50+
# iprob = prob.f.initialization_data.initalizeprob
51+
# isol = solve(iprob)
52+
# gs, = gradient(isol) do isol
53+
# isol[w]
54+
# end
55+
56+
# end
57+
4958
# DAE
5059

5160
function create_model(; C₁ = 3e-5, C₂ = 1e-6)
@@ -84,18 +93,29 @@ end
8493
du_ = [0.2, 1.0]
8594
du = [du_ for _ in sol.u]
8695
@test gs == du
96+
97+
@testset "DAE Initialization Observable function AD" begin
98+
iprob = prob.f.initialization_data.initializeprob
99+
isol = solve(iprob)
100+
tunables, repack, _ = SS.canonicalize(SS.Tunable(), parameter_values(iprob))
101+
gs, = gradient(isol) do isol
102+
isol[sys.ampermeter.i]
103+
end
104+
gt = gs.prob.p
105+
@test findall(!iszero, gt) == [22]
106+
end
87107
end
88108

89-
# @testset "Adjoints with DAE" begin
90-
# gs_mtkp, gs_p_new = gradient(mtkparams, p_new) do p, new_tunables
91-
# new_p = SciMLStructures.replace(SciMLStructures.Tunable(), p, new_tunables)
92-
# new_prob = remake(prob, p = new_p)
93-
# sol = solve(new_prob, Rodas4())
94-
# @show size(sol)
95-
# # mean(abs.(sol[sys.ampermeter.i] .- gt))
96-
# sum(sol[sys.ampermeter.i])
97-
# end
98-
#
99-
# @test isnothing(gs_mtkp)
100-
# @test length(gs_p_new) == length(p_new)
101-
# end
109+
@testset "Adjoints with DAE" begin
110+
gs_mtkp, gs_p_new = gradient(prob.p, prob.p.tunable) do p, new_tunables
111+
new_p = SciMLStructures.replace(SciMLStructures.Tunable(), p, new_tunables)
112+
new_prob = remake(prob, p = new_p)
113+
sol = solve(new_prob, Rodas4())
114+
@show size(sol)
115+
# mean(abs.(sol[sys.ampermeter.i] .- gt))
116+
sum(sol[sys.ampermeter.i])
117+
end
118+
119+
@test isnothing(gs_mtkp)
120+
@test length(gs_p_new) == length(p_new)
121+
end

0 commit comments

Comments
 (0)