Skip to content

Commit 90df3c3

Browse files
refactor: improve remake_initializeprob
1 parent afc0226 commit 90df3c3

File tree

1 file changed

+33
-12
lines changed

1 file changed

+33
-12
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,13 @@ end
178178
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
179179
if (u0 === missing || !(eltype(u0) <: Pair) || isempty(u0)) &&
180180
(p === missing || !(eltype(p) <: Pair) || isempty(p))
181-
return odefn.initializeprob, odefn.initializeprobmap, odefn.initializeprobpmap
181+
return odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap,
182+
odefn.initializeprobpmap
182183
end
183-
if u0 === missing
184+
if u0 === missing || isempty(u0)
184185
u0 = Dict()
186+
elseif !(eltype(u0) <: Pair)
187+
u0 = Dict(unknowns(sys) .=> u0)
185188
end
186189
if p === missing
187190
p = Dict()
@@ -190,15 +193,33 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
190193
t0 = 0.0
191194
end
192195
u0 = todict(u0)
196+
defs = defaults(sys)
197+
varmap = merge(defs, u0)
198+
varmap = canonicalize_varmap(varmap)
199+
missingvars = setdiff(unknowns(sys), collect(keys(varmap)))
200+
setobserved = filter(keys(varmap)) do var
201+
has_observed_with_lhs(sys, var) || has_observed_with_lhs(sys, default_toterm(var))
202+
end
193203
p = todict(p)
194-
initprob = InitializationProblem(sys, t0, u0, p)
195-
initprobmap = getu(initprob, unknowns(sys))
196-
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]
197-
getpunknowns = getu(initprob, punknowns)
198-
setpunknowns = setp(sys, punknowns)
199-
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
200-
reqd_syms = vcat(variable_symbols(initprob), parameter_symbols(initprob))
201-
update_initializeprob! = UpdateInitializeprob(
202-
getu(sys, reqd_syms), setu(initprob, reqd_syms))
203-
return initprob, update_initializeprob!, initprobmap, initprobpmap
204+
guesses = ModelingToolkit.guesses(sys)
205+
solvablepars = [par
206+
for par in parameters(sys)
207+
if is_parameter_solvable(par, p, defs, guesses)]
208+
if (((!isempty(missingvars) || !isempty(solvablepars) ||
209+
!isempty(setobserved)) &&
210+
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
211+
!isempty(initialization_equations(sys)))
212+
initprob = InitializationProblem(sys, t0, u0, p)
213+
initprobmap = getu(initprob, unknowns(sys))
214+
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]
215+
getpunknowns = getu(initprob, punknowns)
216+
setpunknowns = setp(sys, punknowns)
217+
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
218+
reqd_syms = parameter_symbols(initprob)
219+
update_initializeprob! = UpdateInitializeprob(
220+
getu(sys, reqd_syms), setu(initprob, reqd_syms))
221+
return initprob, update_initializeprob!, initprobmap, initprobpmap
222+
else
223+
return nothing, nothing, nothing, nothing
224+
end
204225
end

0 commit comments

Comments
 (0)