Skip to content

Commit bd964a8

Browse files
Fix last few tests
1 parent e643273 commit bd964a8

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

src/concrete_solve.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,13 +1301,15 @@ function DiffEqBase._concrete_solve_adjoint(
13011301
SciMLBase.FullSpecialize
13021302
}(_f,
13031303
_g),
1304+
g = _g,
13041305
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p),
13051306
tspan = _tspan, callback = nothing)
13061307
else
13071308
_prob = remake(prob,
13081309
f = ArrayInterface.parameterless_type(prob.f){false,
13091310
SciMLBase.FullSpecialize
13101311
}(_f),
1312+
g = _g,
13111313
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p),
13121314
tspan = _tspan, callback = nothing)
13131315
end
@@ -1335,13 +1337,15 @@ function DiffEqBase._concrete_solve_adjoint(
13351337
SciMLBase.FullSpecialize
13361338
}(_f,
13371339
_g),
1340+
g = _g,
13381341
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p),
13391342
tspan = _tspan, callback = nothing)
13401343
else
13411344
_prob = remake(prob,
13421345
f = ArrayInterface.parameterless_type(prob.f){false,
13431346
SciMLBase.FullSpecialize
13441347
}(_f),
1348+
g = _g,
13451349
u0 = _u0, p = SciMLStructures.replace(Tunable(), p, _p),
13461350
tspan = _tspan, callback = nothing)
13471351
end

test/sparse_adjoint.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ p = collect(1.0:n)
1515
u0 = ones(n)
1616
tspan = [0.0, 1]
1717
odef = ODEFunction(foop; jac = jac, jac_prototype = jac(u0, p, 0.0), paramjac = paramjac)
18-
function g_helper(p; alg = Rosenbrock23(linsolve = LUFactorization()))
18+
function g_helper(p; alg = Rosenbrock23(linsolve = QRFactorization()))
1919
prob = ODEProblem(odef, u0, tspan, p)
2020
soln = Array(solve(prob, alg; u0 = prob.u0, p = prob.p, abstol = 1e-4, reltol = 1e-4,
2121
sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP())))[:, end]
@@ -35,7 +35,7 @@ end
3535
@test isapprox(exp.(p), g_helper(p; alg = ImplicitEuler(linsolve = LUFactorization()));
3636
atol = 1e-1, rtol = 1e-1)
3737
@test isapprox(exp.(p),
38-
Zygote.gradient(p -> g(p; alg = ImplicitEuler(linsolve = LUFactorization())),
38+
Zygote.gradient(p -> g(p; alg = ImplicitEuler(linsolve = QRFactorization())),
3939
p)[1]; atol = 1e-1, rtol = 1e-1)
4040
@test isapprox(
4141
exp.(p), g_helper(p; alg = ImplicitEuler(linsolve = UMFPACKFactorization()));

0 commit comments

Comments
 (0)