Skip to content

Commit 0ea8894

Browse files
feat: preserve the u0map passed to ODEProblem and use it in remake_initializeprob
1 parent 5467cac commit 0ea8894

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,16 +170,23 @@ function generate_initializesystem(sys::ODESystem;
170170
nleqs = Symbolics.substitute.(nleqs, (paramsubs,))
171171
unks = [full_states; collect(values(paramsubs))]
172172
u0 = Dict(k => substitute(v, paramsubs) for (k, v) in u0)
173+
meta = InitializationSystemMetadata(Dict{Any, Any}(u0map), Dict{Any, Any}(pmap))
173174
sys_nl = NonlinearSystem(nleqs,
174175
unks,
175176
pars;
176177
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess, pmap),
177178
checks = check_units,
178179
name,
180+
metadata = meta,
179181
kwargs...)
180182
return sys_nl
181183
end
182184

185+
struct InitializationSystemMetadata
186+
u0map::Dict{Any, Any}
187+
pmap::Dict{Any, Any}
188+
end
189+
183190
function is_parameter_solvable(p, pmap, defs, guesses)
184191
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
185192
_val2 = get(defs, p, nothing)
@@ -267,6 +274,15 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
267274
!isempty(setobserved) || !isempty(setparobserved)) &&
268275
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
269276
!isempty(initialization_equations(sys)))
277+
if SciMLBase.has_initializeprob(odefn)
278+
oldsys = odefn.initializeprob.f.sys
279+
meta = get_metadata(oldsys)
280+
if meta isa InitializationSystemMetadata
281+
u0 = merge(meta.u0map, u0)
282+
p = merge(meta.pmap, p)
283+
end
284+
end
285+
270286
initprob = InitializationProblem(sys, t0, u0, p)
271287
initprobmap = getu(initprob, unknowns(sys))
272288
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]

test/initializationsystem.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,3 +770,19 @@ end
770770
@test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual
771771
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
772772
end
773+
774+
@testset "`remake` preserves old u0map and pmap" begin
775+
@variables x(t) y(t)
776+
@parameters p
777+
@mtkbuild sys = ODESystem(
778+
[D(x) ~ x + p * y, y^2 + 4y * p^2 ~ x], t; guesses = [y => 1.0, p => 1.0])
779+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0])
780+
@test is_variable(prob.f.initializeprob, y)
781+
prob2 = @test_nowarn remake(prob; p = [p => 3.0]) # ensure no over/under-determined warning
782+
@test is_variable(prob.f.initializeprob, y)
783+
784+
prob = ODEProblem(sys, [y => 1.0, x => 2.0], (0.0, 1.0), [p => missing])
785+
@test is_variable(prob.f.initializeprob, p)
786+
prob2 = @test_nowarn remake(prob; u0 = [y => 0.5])
787+
@test is_variable(prob.f.initializeprob, p)
788+
end

0 commit comments

Comments
 (0)