Skip to content

Commit 9a50122

Browse files
test: check parameter gradients are passed through
1 parent 9e6c253 commit 9a50122

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

test/downstream/observables_autodiff.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ end
5555

5656
@test gs isa NamedTuple
5757
@test isempty(setdiff(fieldnames(typeof(gs)), fieldnames(typeof(isol))))
58+
59+
# Compare gradient for parameters match from observed function
60+
# to ensure parameter gradients are passed through the observed function
61+
f = SII.observed(iprob.f.sys, w)
62+
gu0, gp = gradient(SII.state_values(iprob), SII.parameter_values(iprob)) do u0, p
63+
f(u0, p)
64+
end
65+
66+
@test gs.prob.p == gp
5867
end
5968

6069
# DAE

0 commit comments

Comments
 (0)