Skip to content

Commit 7dd1cc7

Browse files
Also test u0 gradients
1 parent b88f468 commit 7dd1cc7

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

test/mtk.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,12 @@ setups = [
128128
grads = map(setups) do setup
129129
prob, ps, init = setup
130130
@show init
131-
Zygote.gradient(ps) do p
131+
u0 = prob.u0
132+
Zygote.gradient(u0, ps) do u0,p
132133
if init === nothing
133-
new_sol = solve(prob, Rodas5P(); p = ps, sensealg, abstol = 1e-6, reltol = 1e-3)
134+
new_sol = solve(prob, Rodas5P(); u0 = u0, p = ps, sensealg, abstol = 1e-6, reltol = 1e-3)
134135
else
135-
new_sol = solve(prob, Rodas5P(); p = ps, initializealg = init, sensealg, abstol = 1e-6, reltol = 1e-3)
136+
new_sol = solve(prob, Rodas5P(); u0 = u0, p = ps, initializealg = init, sensealg, abstol = 1e-6, reltol = 1e-3)
136137
end
137138
gt = Zygote.ChainRules.ChainRulesCore.ignore_derivatives() do
138139
@test new_sol.retcode == SciMLBase.ReturnCode.Success
@@ -145,5 +146,7 @@ grads = map(setups) do setup
145146
end
146147
end
147148

148-
grads = getproperty.(grads, (:tunable,))
149-
@test all(x grads[1] for x in grads)
149+
u0grads = getindex.(grads,1)
150+
pgrads = getproperty.(getindex.(grads, 2), (:tunable,))
151+
@test all(x u0grads[1] for x in grads)
152+
@test all(x pgrads[1] for x in grads)

0 commit comments

Comments
 (0)