Skip to content

Commit 9e6c253

Browse files
test: update observable tests to use new return type in gradient
1 parent c5f82f9 commit 9e6c253

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

test/downstream/observables_autodiff.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ p = [σ => 28.0,
2626
β => 8 / 3]
2727

2828
tspan = (0.0, 100.0)
29-
prob = ODEProblem(sys, u0, tspan, p, jac = true)
29+
prob = ODEProblem(sys, u0, tspan, p)
3030
sol = solve(prob, Tsit5())
3131

3232
@testset "AutoDiff Observable Functions" begin
@@ -35,25 +35,27 @@ sol = solve(prob, Tsit5())
3535
end
3636
du_ = [0.0, 1.0, 1.0, 1.0]
3737
du = [du_ for _ in sol.u]
38-
@test du == gs
38+
@test du == gs.u
3939

4040
# Observable in a vector
4141
gs, = gradient(sol) do sol
4242
sum(sum.(sol[[sys.w, sys.x]]))
4343
end
4444
du_ = [0.0, 1.0, 1.0, 2.0]
4545
du = [du_ for _ in sol.u]
46-
@test du == gs
46+
@test du == gs.u
4747
end
4848

49-
# @testset "AD Observable Functions for Initialization" begin
50-
# iprob = prob.f.initialization_data.initalizeprob
51-
# isol = solve(iprob)
52-
# gs, = gradient(isol) do isol
53-
# isol[w]
54-
# end
49+
@testset "AD Observable Functions for Initialization" begin
50+
iprob = prob.f.initialization_data.initializeprob
51+
isol = solve(iprob)
52+
gs, = Zygote.gradient(isol) do isol
53+
isol[w]
54+
end
5555

56-
# end
56+
@test gs isa NamedTuple
57+
@test isempty(setdiff(fieldnames(typeof(gs)), fieldnames(typeof(isol))))
58+
end
5759

5860
# DAE
5961

@@ -92,7 +94,7 @@ end
9294
end
9395
du_ = [0.2, 1.0]
9496
du = [du_ for _ in sol.u]
95-
@test gs == du
97+
@test gs.u == du
9698

9799
@testset "DAE Initialization Observable function AD" begin
98100
iprob = prob.f.initialization_data.initializeprob

0 commit comments

Comments
 (0)