diff --git a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl index 306333435b..b1a014ef71 100644 --- a/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl +++ b/lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl @@ -321,6 +321,8 @@ function terminate!(integrator::ODEIntegrator, retcode = ReturnCode.Terminated) integrator.opts.tstops.valtree = typeof(integrator.opts.tstops.valtree)() end +const EMPTY_ARRAY_OF_PAIRS = Pair[] + DiffEqBase.has_reinit(integrator::ODEIntegrator) = true function DiffEqBase.reinit!(integrator::ODEIntegrator, u0 = integrator.sol.prob.u0; t0 = integrator.sol.prob.tspan[1], @@ -335,6 +337,23 @@ function DiffEqBase.reinit!(integrator::ODEIntegrator, u0 = integrator.sol.prob. reinit_callbacks = true, initialize_save = true, reinit_cache = true, reinit_retcode = true) + if reinit_dae && SciMLBase.has_initializeprob(integrator.sol.prob.f) + # This is `remake` infrastructure. `reinit!` is somewhat like `remake` for + # integrators, so we reuse some of the same pieces. If we pass `integrator.p` + # for `p`, it means we don't want to change it. If we pass `missing`, this + # function may (correctly) assume `newp` aliases `prob.p` and copy it, which we + # want to avoid. So we pass an empty array of pairs to make it think this is + # a symbolic `remake` and it can modify `newp` inplace. The array of pairs is a + # const global to avoid allocating every time this function is called. + u0, newp = SciMLBase.late_binding_update_u0_p(integrator.sol.prob, u0, + EMPTY_ARRAY_OF_PAIRS, t0, u0, integrator.p) + if newp !== integrator.p + integrator.p = newp + sol = integrator.sol + @reset sol.prob.p = newp + integrator.sol = sol + end + end if isinplace(integrator.sol.prob) recursivecopy!(integrator.u, u0) recursivecopy!(integrator.uprev, integrator.u) diff --git a/test/interface/dae_initialize_integration.jl b/test/interface/dae_initialize_integration.jl index 85a119d859..4b223703e6 100644 --- a/test/interface/dae_initialize_integration.jl +++ b/test/interface/dae_initialize_integration.jl @@ -76,3 +76,20 @@ sol = solve(prob, Rodas5P(), dt = 1e-10) @test sol[1] == [1.0] @test sol[2] ≈ [0.9999999998] @test sol[end] ≈ [-1.0] + +@testset "`reinit!` updates initial parameters" begin + # https://github.com/SciML/ModelingToolkit.jl/issues/3451 + # https://github.com/SciML/ModelingToolkit.jl/issues/3504 + @variables x(t) y(t) + @parameters c1 c2 + @mtkbuild sys = ODESystem([D(x) ~ -c1 * x + c2 * y, D(y) ~ c1 * x - c2 * y], t) + prob = ODEProblem(sys, [1.0, 2.0], (0.0, 1.0), [c1 => 1.0, c2 => 2.0]) + @test prob.ps[Initial(x)] ≈ 1.0 + @test prob.ps[Initial(y)] ≈ 2.0 + integ = init(prob, Tsit5()) + @test integ.ps[Initial(x)] ≈ 1.0 + @test integ.ps[Initial(y)] ≈ 2.0 + reinit!(integ, [2.0, 3.0]) + @test integ.ps[Initial(x)] ≈ 2.0 + @test integ.ps[Initial(y)] ≈ 3.0 +end