Skip to content

Commit 8d4f542

Browse files
feat: propagate ODEProblem guesses to remake
1 parent 9589a1f commit 8d4f542

File tree

4 files changed

+60
-71
lines changed

4 files changed

+60
-71
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,11 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
13101310
elseif isempty(u0map) && get_initializesystem(sys) === nothing
13111311
isys = structural_simplify(
13121312
generate_initializesystem(
1313-
sys; initialization_eqs, check_units, pmap = parammap); fully_determined)
1313+
sys; initialization_eqs, check_units, pmap = parammap, guesses); fully_determined)
13141314
else
13151315
isys = structural_simplify(
13161316
generate_initializesystem(
1317-
sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
1317+
sys; u0map, initialization_eqs, check_units, pmap = parammap, guesses); fully_determined)
13181318
end
13191319

13201320
ts = get_tearing_state(isys)

src/systems/nonlinear/initializesystem.jl

Lines changed: 23 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ function generate_initializesystem(sys::ODESystem;
3030
# 1) process dummy derivatives and u0map into initialization system
3131
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
3232
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
33-
guesses = merge(get_guesses(sys), todict(guesses))
33+
additional_guesses = anydict(guesses)
34+
guesses = merge(get_guesses(sys), additional_guesses)
3435
schedule = getfield(sys, :schedule)
3536
if !isnothing(schedule)
3637
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
@@ -178,7 +179,7 @@ function generate_initializesystem(sys::ODESystem;
178179
for k in keys(defs)
179180
defs[k] = substitute(defs[k], paramsubs)
180181
end
181-
meta = InitializationSystemMetadata(Dict{Any, Any}(u0map), Dict{Any, Any}(pmap))
182+
meta = InitializationSystemMetadata(anydict(u0map), anydict(pmap), additional_guesses)
182183
return NonlinearSystem(eqs_ics,
183184
vars,
184185
pars;
@@ -193,6 +194,7 @@ end
193194
struct InitializationSystemMetadata
194195
u0map::Dict{Any, Any}
195196
pmap::Dict{Any, Any}
197+
additional_guesses::Dict{Any, Any}
196198
end
197199

198200
function is_parameter_solvable(p, pmap, defs, guesses)
@@ -263,75 +265,29 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
263265
return initprob, odefn.update_initializeprob!, odefn.initializeprobmap,
264266
odefn.initializeprobpmap
265267
end
266-
if u0 === missing || isempty(u0)
267-
u0 = Dict()
268-
elseif !(eltype(u0) <: Pair)
269-
u0 = Dict(unknowns(sys) .=> u0)
270-
end
271-
if p === missing
272-
p = Dict()
268+
dvs = unknowns(sys)
269+
ps = parameters(sys)
270+
u0map = to_varmap(u0, dvs)
271+
pmap = to_varmap(p, ps)
272+
guesses = Dict()
273+
if SciMLBase.has_initializeprob(odefn)
274+
oldsys = odefn.initializeprob.f.sys
275+
meta = get_metadata(oldsys)
276+
if meta isa InitializationSystemMetadata
277+
u0map = merge(meta.u0map, u0map)
278+
pmap = merge(meta.pmap, pmap)
279+
merge!(guesses, meta.additional_guesses)
280+
end
273281
end
274282
if t0 === nothing
275283
t0 = 0.0
276284
end
277-
u0 = todict(u0)
278-
defs = defaults(sys)
279-
varmap = merge(defs, u0)
280-
for k in collect(keys(varmap))
281-
if varmap[k] === nothing
282-
delete!(varmap, k)
283-
end
284-
end
285-
varmap = canonicalize_varmap(varmap)
286-
missingvars = setdiff(unknowns(sys), collect(keys(varmap)))
287-
setobserved = filter(keys(varmap)) do var
288-
has_observed_with_lhs(sys, var) || has_observed_with_lhs(sys, default_toterm(var))
289-
end
290-
p = todict(p)
291-
guesses = ModelingToolkit.guesses(sys)
292-
solvablepars = [par
293-
for par in parameters(sys)
294-
if is_parameter_solvable(par, p, defs, guesses)]
295-
pvarmap = merge(defs, p)
296-
setparobserved = filter(keys(pvarmap)) do var
297-
has_parameter_dependency_with_lhs(sys, var)
298-
end
299-
if (((!isempty(missingvars) || !isempty(solvablepars) ||
300-
!isempty(setobserved) || !isempty(setparobserved)) &&
301-
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
302-
!isempty(initialization_equations(sys)))
303-
if SciMLBase.has_initializeprob(odefn)
304-
oldsys = odefn.initializeprob.f.sys
305-
meta = get_metadata(oldsys)
306-
if meta isa InitializationSystemMetadata
307-
u0 = merge(meta.u0map, u0)
308-
p = merge(meta.pmap, p)
309-
end
310-
end
311-
for k in collect(keys(u0))
312-
if u0[k] === nothing
313-
delete!(u0, k)
314-
end
315-
end
316-
for k in collect(keys(p))
317-
if p[k] === nothing
318-
delete!(p, k)
319-
end
320-
end
321-
322-
initprob = InitializationProblem(sys, t0, u0, p)
323-
initprobmap = getu(initprob, unknowns(sys))
324-
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]
325-
getpunknowns = getu(initprob, punknowns)
326-
setpunknowns = setp(sys, punknowns)
327-
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
328-
reqd_syms = parameter_symbols(initprob)
329-
update_initializeprob! = UpdateInitializeprob(
330-
getu(sys, reqd_syms), setu(initprob, reqd_syms))
331-
return initprob, update_initializeprob!, initprobmap, initprobpmap
332-
else
333-
return nothing, nothing, nothing, nothing
334-
end
285+
filter_missing_values!(u0map)
286+
filter_missing_values!(pmap)
287+
f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0)
288+
kws = f.kwargs
289+
return get(kws, :initializeprob, nothing), get(kws, :update_initializeprob!, nothing), get(kws, :initializeprobmap, nothing),
290+
get(kws, :initializeprobpmap, nothing)
335291
end
336292

337293
"""

src/systems/problem_utils.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ const AnyDict = Dict{Any, Any}
44
$(TYPEDSIGNATURES)
55
66
If called without arguments, return `Dict{Any, Any}`. Otherwise, interpret the input
7-
as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`
8-
and `nothing`.
7+
as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`,
8+
`missing` and `nothing`.
99
"""
1010
anydict() = AnyDict()
1111
anydict(::SciMLBase.NullParameters) = AnyDict()
1212
anydict(::Nothing) = AnyDict()
13+
anydict(::Missing) = AnyDict()
1314
anydict(x::AnyDict) = x
1415
anydict(x) = AnyDict(x)
1516

@@ -388,6 +389,15 @@ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100)
388389
end
389390
end
390391

392+
"""
393+
$(TYPEDSIGNATURES)
394+
395+
Remove keys in `varmap` whose values are `nothing`.
396+
"""
397+
function filter_missing_values!(varmap::AbstractDict)
398+
filter!(kvp -> kvp[2] !== nothing, varmap)
399+
end
400+
391401
struct GetUpdatedMTKParameters{G, S}
392402
# `getu` functor which gets parameters that are unknowns during initialization
393403
getpunknowns::G

test/initializationsystem.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,3 +975,26 @@ end
975975
@test integ.ps[p] 1.0
976976
@test integ.ps[q]cbrt(2) rtol=1e-6
977977
end
978+
979+
@testset "Guesses provided to `ODEProblem` are used in `remake`" begin
980+
@variables x(t) y(t)=2x
981+
@parameters p q=3x
982+
@mtkbuild sys = ODESystem([D(x) ~ x * p + q, x^3 + y^3 ~ 3], t)
983+
prob = ODEProblem(
984+
sys, [], (0.0, 1.0), [p => 1.0]; guesses = [x => 1.0, y => 1.0, q => 1.0])
985+
@test prob[x] == 0.0
986+
@test prob[y] == 0.0
987+
@test prob.ps[p] == 1.0
988+
@test prob.ps[q] == 0.0
989+
integ = init(prob)
990+
@test integ[x] 1 / cbrt(3)
991+
@test integ[y] 2 / cbrt(3)
992+
@test integ.ps[p] == 1.0
993+
@test integ.ps[q] 3 / cbrt(3)
994+
prob2 = remake(prob; u0 = [y => 3x], p = [q => 2x])
995+
integ2 = init(prob2)
996+
@test integ2[x] cbrt(3 / 28)
997+
@test integ2[y] 3cbrt(3 / 28)
998+
@test integ2.ps[p] == 1.0
999+
@test integ2.ps[q] 2cbrt(3 / 28)
1000+
end

0 commit comments

Comments
 (0)