Skip to content

Commit 5a538c3

Browse files
feat: support SciMLBase.remake_initializeprob
1 parent 8425bb8 commit 5a538c3

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,56 @@ function is_parameter_solvable(p, pmap, defs, guesses)
175175
return ((_val1 === missing || _val2 === missing) ||
176176
(_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
177177
end
178+
179+
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
180+
if (u0 === missing || !(eltype(u0) <: Pair) || isempty(u0)) &&
181+
(p === missing || !(eltype(p) <: Pair) || isempty(p))
182+
return odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap,
183+
odefn.initializeprobpmap
184+
end
185+
if u0 === missing || isempty(u0)
186+
u0 = Dict()
187+
elseif !(eltype(u0) <: Pair)
188+
u0 = Dict(unknowns(sys) .=> u0)
189+
end
190+
if p === missing
191+
p = Dict()
192+
end
193+
if t0 === nothing
194+
t0 = 0.0
195+
end
196+
u0 = todict(u0)
197+
defs = defaults(sys)
198+
varmap = merge(defs, u0)
199+
varmap = canonicalize_varmap(varmap)
200+
missingvars = setdiff(unknowns(sys), collect(keys(varmap)))
201+
setobserved = filter(keys(varmap)) do var
202+
has_observed_with_lhs(sys, var) || has_observed_with_lhs(sys, default_toterm(var))
203+
end
204+
p = todict(p)
205+
guesses = ModelingToolkit.guesses(sys)
206+
solvablepars = [par
207+
for par in parameters(sys)
208+
if is_parameter_solvable(par, p, defs, guesses)]
209+
pvarmap = merge(defs, p)
210+
setparobserved = filter(keys(pvarmap)) do var
211+
has_parameter_dependency_with_lhs(sys, var)
212+
end
213+
if (((!isempty(missingvars) || !isempty(solvablepars) ||
214+
!isempty(setobserved) || !isempty(setparobserved)) &&
215+
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
216+
!isempty(initialization_equations(sys)))
217+
initprob = InitializationProblem(sys, t0, u0, p)
218+
initprobmap = getu(initprob, unknowns(sys))
219+
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]
220+
getpunknowns = getu(initprob, punknowns)
221+
setpunknowns = setp(sys, punknowns)
222+
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
223+
reqd_syms = parameter_symbols(initprob)
224+
update_initializeprob! = UpdateInitializeprob(
225+
getu(sys, reqd_syms), setu(initprob, reqd_syms))
226+
return initprob, update_initializeprob!, initprobmap, initprobpmap
227+
else
228+
return nothing, nothing, nothing, nothing
229+
end
230+
end

test/initializationsystem.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,11 @@ end
693693
@test prob.f.initializeprob.ps[p] 3.0
694694
@test init(prob, Tsit5())[x] 1.0
695695
ModelingToolkit.defaults(prob.f.sys)[p] = missing
696+
prob2 = remake(prob; u0 = [y => 1.0], p = [p => 3x])
697+
@test !is_variable(prob2.f.initializeprob, p) &&
698+
!is_parameter(prob2.f.initializeprob, p)
699+
@test init(prob2, Tsit5())[x] 0.5
700+
@test_nowarn solve(prob2, Tsit5())
696701
end
697702

698703
@testset "Equations for dependent parameters" begin

0 commit comments

Comments
 (0)