Skip to content

Commit eb9d066

Browse files
feat: propagate ODEProblem guesses to remake
1 parent 46c173d commit eb9d066

File tree

4 files changed

+60
-71
lines changed

4 files changed

+60
-71
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,11 +1306,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
13061306
elseif isempty(u0map) && get_initializesystem(sys) === nothing
13071307
isys = structural_simplify(
13081308
generate_initializesystem(
1309-
sys; initialization_eqs, check_units, pmap = parammap); fully_determined)
1309+
sys; initialization_eqs, check_units, pmap = parammap, guesses); fully_determined)
13101310
else
13111311
isys = structural_simplify(
13121312
generate_initializesystem(
1313-
sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
1313+
sys; u0map, initialization_eqs, check_units, pmap = parammap, guesses); fully_determined)
13141314
end
13151315

13161316
uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])

src/systems/nonlinear/initializesystem.jl

Lines changed: 23 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ function generate_initializesystem(sys::ODESystem;
3030
# 1) process dummy derivatives and u0map into initialization system
3131
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
3232
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
33-
guesses = merge(get_guesses(sys), todict(guesses))
33+
additional_guesses = anydict(guesses)
34+
guesses = merge(get_guesses(sys), additional_guesses)
3435
schedule = getfield(sys, :schedule)
3536
if !isnothing(schedule)
3637
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
@@ -178,7 +179,7 @@ function generate_initializesystem(sys::ODESystem;
178179
for k in keys(defs)
179180
defs[k] = substitute(defs[k], paramsubs)
180181
end
181-
meta = InitializationSystemMetadata(Dict{Any, Any}(u0map), Dict{Any, Any}(pmap))
182+
meta = InitializationSystemMetadata(anydict(u0map), anydict(pmap), additional_guesses)
182183
return NonlinearSystem(eqs_ics,
183184
vars,
184185
pars;
@@ -193,6 +194,7 @@ end
193194
struct InitializationSystemMetadata
194195
u0map::Dict{Any, Any}
195196
pmap::Dict{Any, Any}
197+
additional_guesses::Dict{Any, Any}
196198
end
197199

198200
function is_parameter_solvable(p, pmap, defs, guesses)
@@ -263,75 +265,29 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
263265
return initprob, odefn.update_initializeprob!, odefn.initializeprobmap,
264266
odefn.initializeprobpmap
265267
end
266-
if u0 === missing || isempty(u0)
267-
u0 = Dict()
268-
elseif !(eltype(u0) <: Pair)
269-
u0 = Dict(unknowns(sys) .=> u0)
270-
end
271-
if p === missing
272-
p = Dict()
268+
dvs = unknowns(sys)
269+
ps = parameters(sys)
270+
u0map = to_varmap(u0, dvs)
271+
pmap = to_varmap(p, ps)
272+
guesses = Dict()
273+
if SciMLBase.has_initializeprob(odefn)
274+
oldsys = odefn.initializeprob.f.sys
275+
meta = get_metadata(oldsys)
276+
if meta isa InitializationSystemMetadata
277+
u0map = merge(meta.u0map, u0map)
278+
pmap = merge(meta.pmap, pmap)
279+
merge!(guesses, meta.additional_guesses)
280+
end
273281
end
274282
if t0 === nothing
275283
t0 = 0.0
276284
end
277-
u0 = todict(u0)
278-
defs = defaults(sys)
279-
varmap = merge(defs, u0)
280-
for k in collect(keys(varmap))
281-
if varmap[k] === nothing
282-
delete!(varmap, k)
283-
end
284-
end
285-
varmap = canonicalize_varmap(varmap)
286-
missingvars = setdiff(unknowns(sys), collect(keys(varmap)))
287-
setobserved = filter(keys(varmap)) do var
288-
has_observed_with_lhs(sys, var) || has_observed_with_lhs(sys, default_toterm(var))
289-
end
290-
p = todict(p)
291-
guesses = ModelingToolkit.guesses(sys)
292-
solvablepars = [par
293-
for par in parameters(sys)
294-
if is_parameter_solvable(par, p, defs, guesses)]
295-
pvarmap = merge(defs, p)
296-
setparobserved = filter(keys(pvarmap)) do var
297-
has_parameter_dependency_with_lhs(sys, var)
298-
end
299-
if (((!isempty(missingvars) || !isempty(solvablepars) ||
300-
!isempty(setobserved) || !isempty(setparobserved)) &&
301-
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
302-
!isempty(initialization_equations(sys)))
303-
if SciMLBase.has_initializeprob(odefn)
304-
oldsys = odefn.initializeprob.f.sys
305-
meta = get_metadata(oldsys)
306-
if meta isa InitializationSystemMetadata
307-
u0 = merge(meta.u0map, u0)
308-
p = merge(meta.pmap, p)
309-
end
310-
end
311-
for k in collect(keys(u0))
312-
if u0[k] === nothing
313-
delete!(u0, k)
314-
end
315-
end
316-
for k in collect(keys(p))
317-
if p[k] === nothing
318-
delete!(p, k)
319-
end
320-
end
321-
322-
initprob = InitializationProblem(sys, t0, u0, p)
323-
initprobmap = getu(initprob, unknowns(sys))
324-
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]
325-
getpunknowns = getu(initprob, punknowns)
326-
setpunknowns = setp(sys, punknowns)
327-
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
328-
reqd_syms = parameter_symbols(initprob)
329-
update_initializeprob! = UpdateInitializeprob(
330-
getu(sys, reqd_syms), setu(initprob, reqd_syms))
331-
return initprob, update_initializeprob!, initprobmap, initprobpmap
332-
else
333-
return nothing, nothing, nothing, nothing
334-
end
285+
filter_missing_values!(u0map)
286+
filter_missing_values!(pmap)
287+
f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0)
288+
kws = f.kwargs
289+
return get(kws, :initializeprob, nothing), get(kws, :update_initializeprob!, nothing), get(kws, :initializeprobmap, nothing),
290+
get(kws, :initializeprobpmap, nothing)
335291
end
336292

337293
"""

src/systems/problem_utils.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ const AnyDict = Dict{Any, Any}
44
$(TYPEDSIGNATURES)
55
66
If called without arguments, return `Dict{Any, Any}`. Otherwise, interpret the input
7-
as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`
8-
and `nothing`.
7+
as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`,
8+
`missing` and `nothing`.
99
"""
1010
anydict() = AnyDict()
1111
anydict(::SciMLBase.NullParameters) = AnyDict()
1212
anydict(::Nothing) = AnyDict()
13+
anydict(::Missing) = AnyDict()
1314
anydict(x::AnyDict) = x
1415
anydict(x) = AnyDict(x)
1516

@@ -388,6 +389,15 @@ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100)
388389
end
389390
end
390391

392+
"""
393+
$(TYPEDSIGNATURES)
394+
395+
Remove keys in `varmap` whose values are `nothing`.
396+
"""
397+
function filter_missing_values!(varmap::AbstractDict)
398+
filter!(kvp -> kvp[2] !== nothing, varmap)
399+
end
400+
391401
struct GetUpdatedMTKParameters{G, S}
392402
# `getu` functor which gets parameters that are unknowns during initialization
393403
getpunknowns::G

test/initializationsystem.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,3 +947,26 @@ end
947947

948948
@test_nowarn remake(prob, p = prob.p)
949949
end
950+
951+
@testset "Guesses provided to `ODEProblem` are used in `remake`" begin
952+
@variables x(t) y(t)=2x
953+
@parameters p q=3x
954+
@mtkbuild sys = ODESystem([D(x) ~ x * p + q, x^3 + y^3 ~ 3], t)
955+
prob = ODEProblem(
956+
sys, [], (0.0, 1.0), [p => 1.0]; guesses = [x => 1.0, y => 1.0, q => 1.0])
957+
@test prob[x] == 0.0
958+
@test prob[y] == 0.0
959+
@test prob.ps[p] == 1.0
960+
@test prob.ps[q] == 0.0
961+
integ = init(prob)
962+
@test integ[x] 1 / cbrt(3)
963+
@test integ[y] 2 / cbrt(3)
964+
@test integ.ps[p] == 1.0
965+
@test integ.ps[q] 3 / cbrt(3)
966+
prob2 = remake(prob; u0 = [y => 3x], p = [q => 2x])
967+
integ2 = init(prob2)
968+
@test integ2[x] cbrt(3 / 28)
969+
@test integ2[y] 3cbrt(3 / 28)
970+
@test integ2.ps[p] == 1.0
971+
@test integ2.ps[q] 2cbrt(3 / 28)
972+
end

0 commit comments

Comments
 (0)