Skip to content

Commit 4d4ff85

Browse files
Merge pull request #3226 from AayushSabharwal/as/remake-propagate-guesses
feat: propagate `ODEProblem` guesses to `remake`
2 parents e9fe9a1 + 10cc9c1 commit 4d4ff85

File tree

7 files changed

+183
-82
lines changed

7 files changed

+183
-82
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ REPL = "1"
126126
RecursiveArrayTools = "3.26"
127127
Reexport = "0.2, 1"
128128
RuntimeGeneratedFunctions = "0.5.9"
129-
SciMLBase = "2.57.1"
129+
SciMLBase = "2.64"
130130
SciMLStructures = "1.0"
131131
Serialization = "1"
132132
Setfield = "0.7, 0.8, 1"

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,11 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
13101310
elseif isempty(u0map) && get_initializesystem(sys) === nothing
13111311
isys = structural_simplify(
13121312
generate_initializesystem(
1313-
sys; initialization_eqs, check_units, pmap = parammap); fully_determined)
1313+
sys; initialization_eqs, check_units, pmap = parammap, guesses); fully_determined)
13141314
else
13151315
isys = structural_simplify(
13161316
generate_initializesystem(
1317-
sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
1317+
sys; u0map, initialization_eqs, check_units, pmap = parammap, guesses); fully_determined)
13181318
end
13191319

13201320
ts = get_tearing_state(isys)

src/systems/nonlinear/initializesystem.jl

Lines changed: 64 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ function generate_initializesystem(sys::ODESystem;
3030
# 1) process dummy derivatives and u0map into initialization system
3131
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
3232
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
33-
guesses = merge(get_guesses(sys), todict(guesses))
33+
additional_guesses = anydict(guesses)
34+
guesses = merge(get_guesses(sys), additional_guesses)
3435
schedule = getfield(sys, :schedule)
3536
if !isnothing(schedule)
3637
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
@@ -178,7 +179,7 @@ function generate_initializesystem(sys::ODESystem;
178179
for k in keys(defs)
179180
defs[k] = substitute(defs[k], paramsubs)
180181
end
181-
meta = InitializationSystemMetadata(Dict{Any, Any}(u0map), Dict{Any, Any}(pmap))
182+
meta = InitializationSystemMetadata(anydict(u0map), anydict(pmap), additional_guesses)
182183
return NonlinearSystem(eqs_ics,
183184
vars,
184185
pars;
@@ -193,6 +194,7 @@ end
193194
struct InitializationSystemMetadata
194195
u0map::Dict{Any, Any}
195196
pmap::Dict{Any, Any}
197+
additional_guesses::Dict{Any, Any}
196198
end
197199

198200
function is_parameter_solvable(p, pmap, defs, guesses)
@@ -208,17 +210,16 @@ function is_parameter_solvable(p, pmap, defs, guesses)
208210
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
209211
end
210212

211-
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
213+
function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, newu0, newp)
212214
if u0 === missing && p === missing
213-
return odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap,
214-
odefn.initializeprobpmap
215+
return odefn.initialization_data
215216
end
216217
if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair)
217218
oldinitprob = odefn.initializeprob
218-
if oldinitprob === nothing || !SciMLBase.has_sys(oldinitprob.f) ||
219-
!(oldinitprob.f.sys isa NonlinearSystem)
220-
return oldinitprob, odefn.update_initializeprob!, odefn.initializeprobmap,
221-
odefn.initializeprobpmap
219+
oldinitprob === nothing && return nothing
220+
if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem)
221+
return SciMLBase.OverrideInitData(oldinitprob, odefn.update_initializeprob!,
222+
odefn.initializeprobmap, odefn.initializeprobpmap)
222223
end
223224
pidxs = ParameterIndex[]
224225
pvals = []
@@ -260,78 +261,69 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
260261
oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals)
261262
end
262263
initprob = remake(oldinitprob; u0 = newu0, p = newp)
263-
return initprob, odefn.update_initializeprob!, odefn.initializeprobmap,
264-
odefn.initializeprobpmap
264+
return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!,
265+
odefn.initializeprobmap, odefn.initializeprobpmap)
265266
end
266-
if u0 === missing || isempty(u0)
267-
u0 = Dict()
268-
elseif !(eltype(u0) <: Pair)
269-
u0 = Dict(unknowns(sys) .=> u0)
270-
end
271-
if p === missing
272-
p = Dict()
273-
end
274-
if t0 === nothing
275-
t0 = 0.0
276-
end
277-
u0 = todict(u0)
267+
dvs = unknowns(sys)
268+
ps = parameters(sys)
269+
u0map = to_varmap(u0, dvs)
270+
symbols_to_symbolics!(sys, u0map)
271+
pmap = to_varmap(p, ps)
272+
symbols_to_symbolics!(sys, pmap)
273+
guesses = Dict()
278274
defs = defaults(sys)
279-
varmap = merge(defs, u0)
280-
for k in collect(keys(varmap))
281-
if varmap[k] === nothing
282-
delete!(varmap, k)
275+
if SciMLBase.has_initializeprob(odefn)
276+
oldsys = odefn.initializeprob.f.sys
277+
meta = get_metadata(oldsys)
278+
if meta isa InitializationSystemMetadata
279+
u0map = merge(meta.u0map, u0map)
280+
pmap = merge(meta.pmap, pmap)
281+
merge!(guesses, meta.additional_guesses)
283282
end
284-
end
285-
varmap = canonicalize_varmap(varmap)
286-
missingvars = setdiff(unknowns(sys), collect(keys(varmap)))
287-
setobserved = filter(keys(varmap)) do var
288-
has_observed_with_lhs(sys, var) || has_observed_with_lhs(sys, default_toterm(var))
289-
end
290-
p = todict(p)
291-
guesses = ModelingToolkit.guesses(sys)
292-
solvablepars = [par
293-
for par in parameters(sys)
294-
if is_parameter_solvable(par, p, defs, guesses)]
295-
pvarmap = merge(defs, p)
296-
setparobserved = filter(keys(pvarmap)) do var
297-
has_parameter_dependency_with_lhs(sys, var)
298-
end
299-
if (((!isempty(missingvars) || !isempty(solvablepars) ||
300-
!isempty(setobserved) || !isempty(setparobserved)) &&
301-
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
302-
!isempty(initialization_equations(sys)))
303-
if SciMLBase.has_initializeprob(odefn)
304-
oldsys = odefn.initializeprob.f.sys
305-
meta = get_metadata(oldsys)
306-
if meta isa InitializationSystemMetadata
307-
u0 = merge(meta.u0map, u0)
308-
p = merge(meta.pmap, p)
283+
else
284+
# there is no initializeprob, so the original problem construction
285+
# had no solvable parameters and had the differential variables
286+
# specified in `u0map`.
287+
if u0 === missing
288+
# the user didn't pass `u0` to `remake`, so they want to retain
289+
# existing values. Fill the differential variables in `u0map`,
290+
# initialization will either be elided or solve for the algebraic
291+
# variables
292+
diff_idxs = isdiffeq.(equations(sys))
293+
for i in eachindex(dvs)
294+
diff_idxs[i] || continue
295+
u0map[dvs[i]] = newu0[i]
309296
end
310297
end
311-
for k in collect(keys(u0))
312-
if u0[k] === nothing
313-
delete!(u0, k)
298+
if p === missing
299+
# the user didn't pass `p` to `remake`, so they want to retain
300+
# existing values. Fill all parameters in `pmap` so that none of
301+
# them are solvable.
302+
for p in ps
303+
pmap[p] = getp(sys, p)(newp)
314304
end
315305
end
316-
for k in collect(keys(p))
317-
if p[k] === nothing
318-
delete!(p, k)
319-
end
306+
# all non-solvable parameters need values regardless
307+
for p in ps
308+
haskey(pmap, p) && continue
309+
is_parameter_solvable(p, pmap, defs, guesses) && continue
310+
pmap[p] = getp(sys, p)(newp)
320311
end
321-
322-
initprob = InitializationProblem(sys, t0, u0, p)
323-
initprobmap = getu(initprob, unknowns(sys))
324-
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]
325-
getpunknowns = getu(initprob, punknowns)
326-
setpunknowns = setp(sys, punknowns)
327-
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
328-
reqd_syms = parameter_symbols(initprob)
329-
update_initializeprob! = UpdateInitializeprob(
330-
getu(sys, reqd_syms), setu(initprob, reqd_syms))
331-
return initprob, update_initializeprob!, initprobmap, initprobpmap
332-
else
333-
return nothing, nothing, nothing, nothing
334312
end
313+
if t0 === nothing
314+
t0 = 0.0
315+
end
316+
filter_missing_values!(u0map)
317+
filter_missing_values!(pmap)
318+
f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0)
319+
kws = f.kwargs
320+
initprob = get(kws, :initializeprob, nothing)
321+
if initprob === nothing
322+
return nothing
323+
end
324+
return SciMLBase.OverrideInitData(initprob, get(kws, :update_initializeprob!, nothing),
325+
get(kws, :initializeprobmap, nothing),
326+
get(kws, :initializeprobpmap, nothing))
335327
end
336328

337329
"""

src/systems/optimization/optimizationsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ function OptimizationSystem(objective; constraints = [], kwargs...)
168168
push!(new_ps, p)
169169
end
170170
end
171-
return OptimizationSystem(objective, collect(allunknowns), collect(new_ps); constraints, kwargs...)
171+
return OptimizationSystem(
172+
objective, collect(allunknowns), collect(new_ps); constraints, kwargs...)
172173
end
173174

174175
function flatten(sys::OptimizationSystem)

src/systems/problem_utils.jl

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ const AnyDict = Dict{Any, Any}
44
$(TYPEDSIGNATURES)
55
66
If called without arguments, return `Dict{Any, Any}`. Otherwise, interpret the input
7-
as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`
8-
and `nothing`.
7+
as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`,
8+
`missing` and `nothing`.
99
"""
1010
anydict() = AnyDict()
1111
anydict(::SciMLBase.NullParameters) = AnyDict()
1212
anydict(::Nothing) = AnyDict()
13+
anydict(::Missing) = AnyDict()
1314
anydict(x::AnyDict) = x
1415
anydict(x) = AnyDict(x)
1516

@@ -51,6 +52,42 @@ function add_toterms(varmap::AbstractDict; toterm = default_toterm)
5152
return cp
5253
end
5354

55+
"""
56+
$(TYPEDSIGNATURES)
57+
58+
Turn any `Symbol` keys in `varmap` to the appropriate symbolic variables in `sys`. Any
59+
symbols that cannot be converted are ignored.
60+
"""
61+
function symbols_to_symbolics!(sys::AbstractSystem, varmap::AbstractDict)
62+
if is_split(sys)
63+
ic = get_index_cache(sys)
64+
for k in collect(keys(varmap))
65+
k isa Symbol || continue
66+
newk = get(ic.symbol_to_variable, k, nothing)
67+
newk === nothing && continue
68+
varmap[newk] = varmap[k]
69+
delete!(varmap, k)
70+
end
71+
else
72+
syms = all_symbols(sys)
73+
for k in collect(keys(varmap))
74+
k isa Symbol || continue
75+
idx = findfirst(syms) do sym
76+
hasname(sym) || return false
77+
name = getname(sym)
78+
return name == k
79+
end
80+
idx === nothing && continue
81+
newk = syms[idx]
82+
if iscall(newk) && operation(newk) === getindex
83+
newk = arguments(newk)[1]
84+
end
85+
varmap[newk] = varmap[k]
86+
delete!(varmap, k)
87+
end
88+
end
89+
end
90+
5491
"""
5592
$(TYPEDSIGNATURES)
5693
@@ -388,6 +425,15 @@ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100)
388425
end
389426
end
390427

428+
"""
429+
$(TYPEDSIGNATURES)
430+
431+
Remove keys in `varmap` whose values are `nothing`.
432+
"""
433+
function filter_missing_values!(varmap::AbstractDict)
434+
filter!(kvp -> kvp[2] !== nothing, varmap)
435+
end
436+
391437
struct GetUpdatedMTKParameters{G, S}
392438
# `getu` functor which gets parameters that are unknowns during initialization
393439
getpunknowns::G
@@ -431,12 +477,16 @@ end
431477
$(TYPEDEF)
432478
433479
A simple utility meant to be used as the `constructor` passed to `process_SciMLProblem` in
434-
case constructing a SciMLFunction is not required.
480+
case constructing a SciMLFunction is not required. The arguments passed to it are available
481+
in the `args` field, and the keyword arguments in the `kwargs` field.
435482
"""
436-
struct EmptySciMLFunction end
483+
struct EmptySciMLFunction{A, K}
484+
args::A
485+
kwargs::K
486+
end
437487

438488
function EmptySciMLFunction(args...; kwargs...)
439-
return nothing
489+
return EmptySciMLFunction{typeof(args), typeof(kwargs)}(args, kwargs)
440490
end
441491

442492
"""
@@ -516,8 +566,10 @@ function process_SciMLProblem(
516566
pType = typeof(pmap)
517567
_u0map = u0map
518568
u0map = to_varmap(u0map, dvs)
569+
symbols_to_symbolics!(sys, u0map)
519570
_pmap = pmap
520571
pmap = to_varmap(pmap, ps)
572+
symbols_to_symbolics!(sys, pmap)
521573
defs = add_toterms(recursive_unwrap(defaults(sys)))
522574
cmap, cs = get_cmap(sys)
523575
kwargs = NamedTuple(kwargs)

src/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,6 @@ function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differ
569569
return nothing
570570
end
571571

572-
573572
function collect_var!(unknowns, parameters, var, iv; depth = 0)
574573
isequal(var, iv) && return nothing
575574
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing

test/initializationsystem.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,3 +975,60 @@ end
975975
@test integ.ps[p] 1.0
976976
@test integ.ps[q]cbrt(2) rtol=1e-6
977977
end
978+
979+
@testset "Guesses provided to `ODEProblem` are used in `remake`" begin
980+
@variables x(t) y(t)=2x
981+
@parameters p q=3x
982+
@mtkbuild sys = ODESystem([D(x) ~ x * p + q, x^3 + y^3 ~ 3], t)
983+
prob = ODEProblem(
984+
sys, [], (0.0, 1.0), [p => 1.0]; guesses = [x => 1.0, y => 1.0, q => 1.0])
985+
@test prob[x] == 0.0
986+
@test prob[y] == 0.0
987+
@test prob.ps[p] == 1.0
988+
@test prob.ps[q] == 0.0
989+
integ = init(prob)
990+
@test integ[x] 1 / cbrt(3)
991+
@test integ[y] 2 / cbrt(3)
992+
@test integ.ps[p] == 1.0
993+
@test integ.ps[q] 3 / cbrt(3)
994+
prob2 = remake(prob; u0 = [y => 3x], p = [q => 2x])
995+
integ2 = init(prob2)
996+
@test integ2[x] cbrt(3 / 28)
997+
@test integ2[y] 3cbrt(3 / 28)
998+
@test integ2.ps[p] == 1.0
999+
@test integ2.ps[q] 2cbrt(3 / 28)
1000+
end
1001+
1002+
@testset "Remake problem with no initializeprob" begin
1003+
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
1004+
@parameters p [guess = 1.0] q [guess = 1.0]
1005+
@mtkbuild sys = ODESystem(
1006+
[D(x) ~ p * x + q * y, y ~ 2x], t; parameter_dependencies = [q ~ 2p])
1007+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0])
1008+
@test prob.f.initialization_data === nothing
1009+
prob2 = remake(prob; u0 = [x => 2.0])
1010+
@test prob2[x] == 2.0
1011+
@test prob2.f.initialization_data === nothing
1012+
prob3 = remake(prob; u0 = [y => 2.0])
1013+
@test prob3.f.initialization_data !== nothing
1014+
@test init(prob3)[x] 1.0
1015+
prob4 = remake(prob; p = [p => 1.0])
1016+
@test prob4.f.initialization_data === nothing
1017+
prob5 = remake(prob; p = [p => missing, q => 2.0])
1018+
@test prob5.f.initialization_data !== nothing
1019+
@test init(prob5).ps[p] 1.0
1020+
end
1021+
1022+
@testset "Variables provided as symbols" begin
1023+
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
1024+
@parameters p [guess = 1.0] q [guess = 1.0]
1025+
@mtkbuild sys = ODESystem(
1026+
[D(x) ~ p * x + q * y, y ~ 2x], t; parameter_dependencies = [q ~ 2p])
1027+
prob = ODEProblem(sys, [:x => 1.0], (0.0, 1.0), [p => 1.0])
1028+
@test prob.f.initialization_data === nothing
1029+
prob2 = remake(prob; u0 = [:x => 2.0])
1030+
@test prob2.f.initialization_data === nothing
1031+
prob3 = remake(prob; u0 = [:y => 1.0])
1032+
@test prob3.f.initialization_data !== nothing
1033+
@test init(prob3)[x] 0.5
1034+
end

0 commit comments

Comments
 (0)