From 5b81a118e35c662f04b8aead8b401957a0637fec Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 3 Apr 2025 01:00:35 +0530 Subject: [PATCH 1/7] fix: update initials with non-symbolic `u0` in `remake` --- src/systems/nonlinear/initializesystem.jl | 12 +++++++++++- test/initializationsystem.jl | 11 +++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index c03524c61b..23573673b0 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -655,7 +655,17 @@ end function SciMLBase.late_binding_update_u0_p( prob, sys::AbstractSystem, u0, p, t0, newu0, newp) u0 === missing && return newu0, (p === missing ? copy(newp) : newp) - eltype(u0) <: Pair || return newu0, (p === missing ? copy(newp) : newp) + if !(eltype(u0) <: Pair) + p === missing || return newu0, newp + newu0 === nothing && return newu0, newp + newp = p === missing ? copy(newp) : newp + initials, repack, alias = SciMLStructures.canonicalize( + SciMLStructures.Initials(), newp) + initials = DiffEqBase.promote_u0(initials, newu0, t0) + newp = repack(initials) + setp(sys, Initial.(unknowns(sys)))(newp, newu0) + return newu0, newp + end newp = p === missing ? copy(newp) : newp newu0 = DiffEqBase.promote_u0(newu0, newp, t0) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 54b847c64a..1120394b01 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1476,3 +1476,14 @@ end @test sol.ps[Γ[1]] ≈ 5.0 end end + +@testset "Issue#3504: Update initials when `remake` called with non-symbolic `u0`" begin + @variables x(t) y(t) + @parameters c1 c2 + @mtkbuild sys = ODESystem([D(x) ~ -c1 * x + c2 * y, D(y) ~ c1 * x - c2 * y], t) + prob1 = ODEProblem(sys, [1.0, 2.0], (0.0, 1.0), [c1 => 1.0, c2 => 2.0]) + prob2 = remake(prob1, u0 = [2.0, 3.0]) + integ1 = init(prob1, Tsit5()) + integ2 = init(prob2, Tsit5()) + @test integ2.u ≈ [2.0, 3.0] +end From 68e06054b83f63af76ed4f0a6b9ea3f56cb559dd Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 3 Apr 2025 01:10:43 +0530 Subject: [PATCH 2/7] fix: update initials in `remake` with non-symbolic `u0` and symbolic `p` --- src/systems/nonlinear/initializesystem.jl | 4 +++- test/initializationsystem.jl | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 23573673b0..daeef1def6 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -655,8 +655,10 @@ end function SciMLBase.late_binding_update_u0_p( prob, sys::AbstractSystem, u0, p, t0, newu0, newp) u0 === missing && return newu0, (p === missing ? copy(newp) : newp) + # non-symbolic u0 updates initials... if !(eltype(u0) <: Pair) - p === missing || return newu0, newp + # if `p` is not provided or is symbolic + p === missing || eltype(p) <: Pair || return newu0, newp newu0 === nothing && return newu0, newp newp = p === missing ? copy(newp) : newp initials, repack, alias = SciMLStructures.canonicalize( diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 1120394b01..94a1f4c14f 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1483,7 +1483,11 @@ end @mtkbuild sys = ODESystem([D(x) ~ -c1 * x + c2 * y, D(y) ~ c1 * x - c2 * y], t) prob1 = ODEProblem(sys, [1.0, 2.0], (0.0, 1.0), [c1 => 1.0, c2 => 2.0]) prob2 = remake(prob1, u0 = [2.0, 3.0]) + prob3 = remake(prob1, u0 = [2.0, 3.0], p = [c1 => 2.0]) integ1 = init(prob1, Tsit5()) integ2 = init(prob2, Tsit5()) + integ3 = init(prob3, Tsit5()) @test integ2.u ≈ [2.0, 3.0] + @test integ3.u ≈ [2.0, 3.0] + @test integ3.ps[c1] ≈ 2.0 end From c70f1015cd4be94c858904fce407424edecffe6c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 3 Apr 2025 11:42:25 +0530 Subject: [PATCH 3/7] fix: early exit `late_binding_update_u0_p` if system doesn't support initialization --- src/systems/nonlinear/initializesystem.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index daeef1def6..52eb886bf9 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -654,6 +654,7 @@ end function SciMLBase.late_binding_update_u0_p( prob, sys::AbstractSystem, u0, p, t0, newu0, newp) + supports_initialization(sys) || return newu0, newp u0 === missing && return newu0, (p === missing ? copy(newp) : newp) # non-symbolic u0 updates initials... if !(eltype(u0) <: Pair) From 69509276a2954665fab4a40c9496e5b64c196ef8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 3 Apr 2025 11:42:39 +0530 Subject: [PATCH 4/7] test: update test to account for new new initial propogation --- test/nonlinearsystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/nonlinearsystem.jl b/test/nonlinearsystem.jl index 5cf3b4b1be..9475f24006 100644 --- a/test/nonlinearsystem.jl +++ b/test/nonlinearsystem.jl @@ -239,7 +239,7 @@ testdict = Dict([:test => 1]) prob_ = remake(prob, u0 = [1.0, 2.0, 3.0], p = [a => 1.1, b => 1.2, c => 1.3]) @test prob_.u0 == [1.0, 2.0, 3.0] - initials = unknowns(sys) .=> ones(3) + initials = unknowns(sys) .=> [1.0, 2.0, 3.0] @test prob_.p == MTKParameters(sys, [a => 1.1, b => 1.2, c => 1.3, initials...]) prob_ = remake(prob, u0 = Dict(y => 2.0), p = Dict(a => 2.0)) From a9ca3e7dca3cfab4e38ec2805a3647ee3b491682 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 3 Apr 2025 12:17:23 +0530 Subject: [PATCH 5/7] fix: remove unnecessary copies in `late_binding_update_u0_p` --- src/systems/nonlinear/initializesystem.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 52eb886bf9..99fc02e5e7 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -664,8 +664,10 @@ function SciMLBase.late_binding_update_u0_p( newp = p === missing ? copy(newp) : newp initials, repack, alias = SciMLStructures.canonicalize( SciMLStructures.Initials(), newp) - initials = DiffEqBase.promote_u0(initials, newu0, t0) - newp = repack(initials) + if eltype(initials) != eltype(newu0) + initials = DiffEqBase.promote_u0(initials, newu0, t0) + newp = repack(initials) + end setp(sys, Initial.(unknowns(sys)))(newp, newu0) return newu0, newp end @@ -673,11 +675,15 @@ function SciMLBase.late_binding_update_u0_p( newp = p === missing ? copy(newp) : newp newu0 = DiffEqBase.promote_u0(newu0, newp, t0) tunables, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Tunable(), newp) - tunables = DiffEqBase.promote_u0(tunables, newu0, t0) - newp = repack(tunables) + if eltype(tunables) != eltype(newu0) + tunables = DiffEqBase.promote_u0(tunables, newu0, t0) + newp = repack(tunables) + end initials, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp) - initials = DiffEqBase.promote_u0(initials, newu0, t0) - newp = repack(initials) + if eltype(initials) != eltype(newu0) + initials = DiffEqBase.promote_u0(initials, newu0, t0) + newp = repack(initials) + end allsyms = all_symbols(sys) for (k, v) in u0 From 216aa73e5c0ae76e4642f29aa7de22b1ff75d92a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 3 Apr 2025 16:03:16 +0530 Subject: [PATCH 6/7] fix: handle edge case in `late_binding_update_u0_p` when `Initial` parameters don't exist --- src/systems/nonlinear/initializesystem.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 99fc02e5e7..2283e7a1d4 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -661,6 +661,7 @@ function SciMLBase.late_binding_update_u0_p( # if `p` is not provided or is symbolic p === missing || eltype(p) <: Pair || return newu0, newp newu0 === nothing && return newu0, newp + all(is_parameter(sys, Initial(x)) for x in unknowns(sys)) || return newu0, newp newp = p === missing ? copy(newp) : newp initials, repack, alias = SciMLStructures.canonicalize( SciMLStructures.Initials(), newp) From 01a43bec2f4bf1e1f2af0444dcc7abfdedaa6d77 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 3 Apr 2025 17:41:43 +0530 Subject: [PATCH 7/7] feat: error in `late_binding_update_u0_p` if `newu0` is of incorrect length --- src/systems/nonlinear/initializesystem.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 2283e7a1d4..9c85b9b324 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -669,6 +669,9 @@ function SciMLBase.late_binding_update_u0_p( initials = DiffEqBase.promote_u0(initials, newu0, t0) newp = repack(initials) end + if length(newu0) != length(unknowns(sys)) + throw(ArgumentError("Expected `newu0` to be of same length as unknowns ($(length(unknowns(sys)))). Got $(typeof(newu0)) of length $(length(newu0))")) + end setp(sys, Initial.(unknowns(sys)))(newp, newu0) return newu0, newp end