diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index c03524c61b..9c85b9b324 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -654,17 +654,40 @@ end function SciMLBase.late_binding_update_u0_p( prob, sys::AbstractSystem, u0, p, t0, newu0, newp) + supports_initialization(sys) || return newu0, newp u0 === missing && return newu0, (p === missing ? copy(newp) : newp) - eltype(u0) <: Pair || return newu0, (p === missing ? copy(newp) : newp) + # non-symbolic u0 updates initials... + if !(eltype(u0) <: Pair) + # if `p` is not provided or is symbolic + p === missing || eltype(p) <: Pair || return newu0, newp + newu0 === nothing && return newu0, newp + all(is_parameter(sys, Initial(x)) for x in unknowns(sys)) || return newu0, newp + newp = p === missing ? copy(newp) : newp + initials, repack, alias = SciMLStructures.canonicalize( + SciMLStructures.Initials(), newp) + if eltype(initials) != eltype(newu0) + initials = DiffEqBase.promote_u0(initials, newu0, t0) + newp = repack(initials) + end + if length(newu0) != length(unknowns(sys)) + throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(unknowns(sys)))). Got $(typeof(newu0)) of length $(length(newu0))")) + end + setp(sys, Initial.(unknowns(sys)))(newp, newu0) + return newu0, newp + end newp = p === missing ? copy(newp) : newp newu0 = DiffEqBase.promote_u0(newu0, newp, t0) tunables, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp) - tunables = DiffEqBase.promote_u0(tunables, newu0, t0) - newp = repack(tunables) + if eltype(tunables) != eltype(newu0) + tunables = DiffEqBase.promote_u0(tunables, newu0, t0) + newp = repack(tunables) + end initials, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp) - initials = DiffEqBase.promote_u0(initials, newu0, t0) - newp = repack(initials) + if eltype(initials) != eltype(newu0) + initials = DiffEqBase.promote_u0(initials, newu0, t0) + newp = repack(initials) + end allsyms = all_symbols(sys) for (k, v) in u0 diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 54b847c64a..94a1f4c14f 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1476,3 +1476,18 @@ end @test sol.ps[Γ[1]] ≈ 5.0 end end + +@testset "Issue#3504: Update initials when `remake` called with non-symbolic `u0`" begin + @variables x(t) y(t) + @parameters c1 c2 + @mtkbuild sys = ODESystem([D(x) ~ -c1 * x + c2 * y, D(y) ~ c1 * x - c2 * y], t) + prob1 = ODEProblem(sys, [1.0, 2.0], (0.0, 1.0), [c1 => 1.0, c2 => 2.0]) + prob2 = remake(prob1, u0 = [2.0, 3.0]) + prob3 = remake(prob1, u0 = [2.0, 3.0], p = [c1 => 2.0]) + integ1 = init(prob1, Tsit5()) + integ2 = init(prob2, Tsit5()) + integ3 = init(prob3, Tsit5()) + @test integ2.u ≈ [2.0, 3.0] + @test integ3.u ≈ [2.0, 3.0] + @test integ3.ps[c1] ≈ 2.0 +end diff --git a/test/nonlinearsystem.jl b/test/nonlinearsystem.jl index 5cf3b4b1be..9475f24006 100644 --- a/test/nonlinearsystem.jl +++ b/test/nonlinearsystem.jl @@ -239,7 +239,7 @@ testdict = Dict([:test => 1]) prob_ = remake(prob, u0 = [1.0, 2.0, 3.0], p = [a => 1.1, b => 1.2, c => 1.3]) @test prob_.u0 == [1.0, 2.0, 3.0] - initials = unknowns(sys) .=> ones(3) + initials = unknowns(sys) .=> [1.0, 2.0, 3.0] @test prob_.p == MTKParameters(sys, [a => 1.1, b => 1.2, c => 1.3, initials...]) prob_ = remake(prob, u0 = Dict(y => 2.0), p = Dict(a => 2.0))