Skip to content

Commit 70ae7fb

Browse files
Merge pull request #808 from AayushSabharwal/as/fix-tests
fix: use `split = false` system for remake autodiff tests
2 parents b064f2b + 6184c69 commit 70ae7fb

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

ext/SciMLBaseZygoteExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,12 @@ end
106106
if is_observed(VA, sym)
107107
f = observed(VA, sym)
108108
p = parameter_values(VA)
109-
tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
109+
tunables, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
110110
u = state_values(VA)
111111
t = current_time(VA)
112112
y, back = Zygote.pullback(u, tunables) do u, tunables
113-
f.(u, Ref(tunables), t)
113+
_p = repack(tunables)
114+
f.(u, Ref(_p), t)
114115
end
115116
gs = back(Δ)
116117
(gs[1], nothing)

test/downstream/remake_autodiff.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ function lotka_volterra(; name = name)
1414
end
1515

1616
@named lotka_volterra_sys = lotka_volterra()
17-
lotka_volterra_sys = structural_simplify(lotka_volterra_sys)
17+
lotka_volterra_sys = structural_simplify(lotka_volterra_sys, split = false)
1818
prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), [])
1919
sol = solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6)
20-
u0 = [1.0 1.0]
21-
p = [1.5 1.0 1.0 1.0]
20+
u0 = [1.0, 1.0]
21+
p = [1.5, 1.0, 1.0, 1.0]
2222

2323
function sum_of_solution(u0, p)
2424
_prob = remake(prob, u0 = u0, p = p)

0 commit comments

Comments
 (0)