Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ function terminate!(integrator::ODEIntegrator, retcode = ReturnCode.Terminated)
integrator.opts.tstops.valtree = typeof(integrator.opts.tstops.valtree)()
end

const EMPTY_ARRAY_OF_PAIRS = Pair[]

DiffEqBase.has_reinit(integrator::ODEIntegrator) = true
function DiffEqBase.reinit!(integrator::ODEIntegrator, u0 = integrator.sol.prob.u0;
t0 = integrator.sol.prob.tspan[1],
Expand All @@ -335,6 +337,23 @@ function DiffEqBase.reinit!(integrator::ODEIntegrator, u0 = integrator.sol.prob.
reinit_callbacks = true, initialize_save = true,
reinit_cache = true,
reinit_retcode = true)
if reinit_dae && SciMLBase.has_initializeprob(integrator.sol.prob.f)
# This is `remake` infrastructure. `reinit!` is somewhat like `remake` for
# integrators, so we reuse some of the same pieces. If we pass `integrator.p`
# for `p`, it means we don't want to change it. If we pass `missing`, this
# function may (correctly) assume `newp` aliases `prob.p` and copy it, which we
# want to avoid. So we pass an empty array of pairs to make it think this is
# a symbolic `remake` and it can modify `newp` inplace. The array of pairs is a
# const global to avoid allocating every time this function is called.
u0, newp = SciMLBase.late_binding_update_u0_p(integrator.sol.prob, u0,
EMPTY_ARRAY_OF_PAIRS, t0, u0, integrator.p)
if newp !== integrator.p
integrator.p = newp
sol = integrator.sol
@reset sol.prob.p = newp
integrator.sol = sol
end
end
if isinplace(integrator.sol.prob)
recursivecopy!(integrator.u, u0)
recursivecopy!(integrator.uprev, integrator.u)
Expand Down
17 changes: 17 additions & 0 deletions test/interface/dae_initialize_integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,20 @@ sol = solve(prob, Rodas5P(), dt = 1e-10)
@test sol[1] == [1.0]
@test sol[2] ≈ [0.9999999998]
@test sol[end] ≈ [-1.0]

@testset "`reinit!` updates initial parameters" begin
# https://github.com/SciML/ModelingToolkit.jl/issues/3451
# https://github.com/SciML/ModelingToolkit.jl/issues/3504
@variables x(t) y(t)
@parameters c1 c2
@mtkbuild sys = ODESystem([D(x) ~ -c1 * x + c2 * y, D(y) ~ c1 * x - c2 * y], t)
prob = ODEProblem(sys, [1.0, 2.0], (0.0, 1.0), [c1 => 1.0, c2 => 2.0])
@test prob.ps[Initial(x)] ≈ 1.0
@test prob.ps[Initial(y)] ≈ 2.0
integ = init(prob, Tsit5())
@test integ.ps[Initial(x)] ≈ 1.0
@test integ.ps[Initial(y)] ≈ 2.0
reinit!(integ, [2.0, 3.0])
@test integ.ps[Initial(x)] ≈ 2.0
@test integ.ps[Initial(y)] ≈ 3.0
Comment on lines +92 to +94
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AayushSabharwal shouldn't this also test integ.u? That's what was raised in SciML/ModelingToolkit.jl#3451

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be integ.u after DAE init is ran here: https://github.com/SciML/OrdinaryDiffEq.jl/pull/2658/files#diff-979c91992be010582f7db3acfd7bc376632fb3965e9f19ddddc0647986f3de4cR431 . So yeah it probably would be good to add a test that is correctly handled.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll PR soon

end
Loading