Skip to content

Commit b2dbf91

Browse files
feat: allow parameters to be unknowns in the initialization system
1 parent 8ce64bf commit b2dbf91

File tree

5 files changed

+207
-17
lines changed

5 files changed

+207
-17
lines changed

src/systems/abstractsystem.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,15 @@ function has_observed_with_lhs(sys, sym)
736736
end
737737
end
738738

739+
function has_parameter_dependency_with_lhs(sys, sym)
740+
has_parameter_dependencies(sys) || return false
741+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
742+
return any(isequal(sym), ic.dependent_pars)
743+
else
744+
return any(isequal(sym), [eq.lhs for eq in parameter_dependencies(sys)])
745+
end
746+
end
747+
739748
function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym)
740749
if is_variable(sys, sym) || is_independent_variable(sys, sym)
741750
push!(ts_idxs, ContinuousTimeseries())

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 94 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
357357
analytic = nothing,
358358
split_idxs = nothing,
359359
initializeprob = nothing,
360+
update_initializeprob! = nothing,
360361
initializeprobmap = nothing,
362+
initializeprobpmap = nothing,
361363
kwargs...) where {iip, specialize}
362364
if !iscomplete(sys)
363365
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
@@ -459,7 +461,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
459461
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
460462
analytic = analytic,
461463
initializeprob = initializeprob,
462-
initializeprobmap = initializeprobmap)
464+
update_initializeprob! = update_initializeprob!,
465+
initializeprobmap = initializeprobmap,
466+
initializeprobpmap = initializeprobpmap)
463467
end
464468

465469
"""
@@ -789,6 +793,45 @@ function get_u0(
789793
return u0, defs
790794
end
791795

796+
struct GetUpdatedMTKParameters{G, S}
797+
# `getu` functor which gets parameters that are unknowns during initialization
798+
getpunknowns::G
799+
# `setu` functor which returns a modified MTKParameters using those parameters
800+
setpunknowns::S
801+
end
802+
803+
function (f::GetUpdatedMTKParameters)(prob, initializesol)
804+
mtkp = copy(parameter_values(prob))
805+
f.setpunknowns(mtkp, f.getpunknowns(initializesol))
806+
mtkp
807+
end
808+
809+
struct UpdateInitializeprob{G, S}
810+
# `getu` functor which gets all values from prob
811+
getvals::G
812+
# `setu` functor which updates initializeprob with values
813+
setvals::S
814+
end
815+
816+
function (f::UpdateInitializeprob)(initializeprob, prob)
817+
f.setvals(initializeprob, f.getvals(prob))
818+
end
819+
820+
function get_temporary_value(p)
821+
stype = symtype(unwrap(p))
822+
return if stype == Real
823+
zero(Float64)
824+
elseif stype <: AbstractArray{Real}
825+
zeros(Float64, size(p))
826+
elseif stype <: Real
827+
zero(stype)
828+
elseif stype <: AbstractArray
829+
zeros(eltype(stype), size(p))
830+
else
831+
error("Nonnumeric parameter $p with symtype $stype cannot be solved for during initialization")
832+
end
833+
end
834+
792835
function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
793836
implicit_dae = false, du0map = nothing,
794837
version = nothing, tgrad = false,
@@ -829,18 +872,38 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
829872
end
830873

831874
if eltype(parammap) <: Pair
832-
parammap = Dict(unwrap(k) => v for (k, v) in todict(parammap))
875+
parammap = Dict{Any, Any}(unwrap(k) => v for (k, v) in parammap)
833876
elseif parammap isa AbstractArray
834877
if isempty(parammap)
835878
parammap = SciMLBase.NullParameters()
836879
else
837-
parammap = Dict(unwrap.(parameters(sys)) .=> parammap)
880+
parammap = Dict{Any, Any}(unwrap.(parameters(sys)) .=> parammap)
838881
end
839882
end
840-
883+
defs = defaults(sys)
884+
if has_guesses(sys)
885+
guesses = merge(
886+
ModelingToolkit.guesses(sys), isempty(guesses) ? Dict() : todict(guesses))
887+
solvablepars = [p
888+
for p in parameters(sys)
889+
if is_parameter_solvable(p, parammap, defs, guesses)]
890+
891+
pvarmap = if parammap === nothing || parammap == SciMLBase.NullParameters() || !(eltype(parammap) <: Pair) && isempty(parammap)
892+
defs
893+
else
894+
merge(defs, todict(parammap))
895+
end
896+
setparobserved = filter(keys(pvarmap)) do var
897+
has_parameter_dependency_with_lhs(sys, var)
898+
end
899+
else
900+
solvablepars = ()
901+
setparobserved = ()
902+
end
841903
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
842904
if sys isa ODESystem && build_initializeprob &&
843-
(((implicit_dae || !isempty(missingvars) || !isempty(setobserved)) &&
905+
(((implicit_dae || !isempty(missingvars) || !isempty(solvablepars) ||
906+
!isempty(setobserved) || !isempty(setparobserved)) &&
844907
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
845908
!isempty(initialization_equations(sys))) && t !== nothing
846909
if eltype(u0map) <: Number
@@ -854,14 +917,32 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
854917
sys, t, u0map, parammap; guesses, warn_initialize_determined,
855918
initialization_eqs, eval_expression, eval_module, fully_determined, check_units)
856919
initializeprobmap = getu(initializeprob, unknowns(sys))
920+
punknowns = [p
921+
for p in all_variable_symbols(initializeprob) if is_parameter(sys, p)]
922+
getpunknowns = getu(initializeprob, punknowns)
923+
setpunknowns = setp(sys, punknowns)
924+
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
925+
reqd_syms = parameter_symbols(initializeprob)
926+
update_initializeprob! = UpdateInitializeprob(
927+
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
857928

858929
zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0)
930+
if parammap isa SciMLBase.NullParameters
931+
parammap = Dict()
932+
end
933+
for p in punknowns
934+
p = unwrap(p)
935+
stype = symtype(p)
936+
parammap[p] = get_temporary_value(p)
937+
end
859938
trueinit = collect(merge(zerovars, eltype(u0map) <: Pair ? todict(u0map) : u0map))
860939
u0map isa StaticArraysCore.StaticArray &&
861940
(trueinit = SVector{length(trueinit)}(trueinit))
862941
else
863942
initializeprob = nothing
943+
update_initializeprob! = nothing
864944
initializeprobmap = nothing
945+
initializeprobpmap = nothing
865946
trueinit = u0map
866947
end
867948

@@ -909,7 +990,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
909990
sparse = sparse, eval_expression = eval_expression,
910991
eval_module = eval_module,
911992
initializeprob = initializeprob,
993+
update_initializeprob! = update_initializeprob!,
912994
initializeprobmap = initializeprobmap,
995+
initializeprobpmap = initializeprobpmap,
913996
kwargs...)
914997
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
915998
end
@@ -1471,10 +1554,12 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
14711554
isys = get_initializesystem(sys; initialization_eqs, check_units)
14721555
elseif isempty(u0map) && get_initializesystem(sys) === nothing
14731556
isys = structural_simplify(
1474-
generate_initializesystem(sys; initialization_eqs, check_units); fully_determined)
1557+
generate_initializesystem(
1558+
sys; initialization_eqs, check_units, pmap = parammap); fully_determined)
14751559
else
14761560
isys = structural_simplify(
1477-
generate_initializesystem(sys; u0map, initialization_eqs, check_units); fully_determined)
1561+
generate_initializesystem(
1562+
sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
14781563
end
14791564

14801565
uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])
@@ -1498,14 +1583,15 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
14981583
parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
14991584
[get_iv(sys) => t] :
15001585
merge(todict(parammap), Dict(get_iv(sys) => t))
1586+
parammap = Dict(k => v for (k, v) in parammap if v !== missing)
15011587
if isempty(u0map)
15021588
u0map = Dict()
15031589
end
15041590
if isempty(guesses)
15051591
guesses = Dict()
15061592
end
15071593

1508-
u0map = merge(todict(guesses), todict(u0map))
1594+
u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), todict(u0map))
15091595
if neqs == nunknown
15101596
NonlinearProblem(isys, u0map, parammap; kwargs...)
15111597
else

src/systems/nonlinear/initializesystem.jl

Lines changed: 99 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi
55
"""
66
function generate_initializesystem(sys::ODESystem;
77
u0map = Dict(),
8+
pmap = Dict(),
89
initialization_eqs = [],
910
guesses = Dict(),
1011
default_dd_guess = 0.0,
@@ -74,12 +75,103 @@ function generate_initializesystem(sys::ODESystem;
7475
end
7576
end
7677

77-
pars = [parameters(sys); get_iv(sys)] # include independent variable as pseudo-parameter
78-
eqs_ics = [eqs_ics; observed(sys)]
79-
return NonlinearSystem(
80-
eqs_ics, vars, pars;
81-
defaults = defs, parameter_dependencies = parameter_dependencies(sys),
82-
checks = check_units,
83-
name, kwargs...
78+
# 4) process parameters as initialization unknowns
79+
paramsubs = Dict()
80+
if pmap isa SciMLBase.NullParameters
81+
pmap = Dict()
82+
end
83+
pmap = todict(pmap)
84+
for p in parameters(sys)
85+
if is_parameter_solvable(p, pmap, defs, guesses)
86+
# If either of them are `missing` the parameter is an unknown
87+
# But if the parameter is passed a value, use that as an additional
88+
# equation in the system
89+
_val1 = get(pmap, p, nothing)
90+
_val2 = get(defs, p, nothing)
91+
_val3 = get(guesses, p, nothing)
92+
varp = tovar(p)
93+
paramsubs[p] = varp
94+
# Has a default of `missing`, and (either an equation using the value passed to `ODEProblem` or a guess)
95+
if _val2 === missing
96+
if _val1 !== nothing && _val1 !== missing
97+
push!(eqs_ics, varp ~ _val1)
98+
push!(defs, varp => _val1)
99+
elseif _val3 !== nothing
100+
# assuming an equation exists (either via algebraic equations or initialization_eqs)
101+
push!(defs, varp => _val3)
102+
elseif check_defguess
103+
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
104+
end
105+
# `missing` passed to `ODEProblem`, and (either an equation using default or a guess)
106+
elseif _val1 === missing
107+
if _val2 !== nothing && _val2 !== missing
108+
push!(eqs_ics, varp ~ _val2)
109+
push!(defs, varp => _val2)
110+
elseif _val3 !== nothing
111+
push!(defs, varp => _val3)
112+
elseif check_defguess
113+
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
114+
end
115+
# given a symbolic value to ODEProblem
116+
elseif symbolic_type(_val1) != NotSymbolic()
117+
push!(eqs_ics, varp ~ _val1)
118+
push!(defs, varp => _val3)
119+
# No value passed to `ODEProblem`, but a default and a guess are present
120+
# _val2 !== missing is implied by it falling this far in the elseif chain
121+
elseif _val1 === nothing && _val2 !== nothing
122+
push!(eqs_ics, varp ~ _val2)
123+
push!(defs, varp => _val3)
124+
else
125+
# _val1 !== missing and _val1 !== nothing, so a value was provided to ODEProblem
126+
# This would mean `is_parameter_solvable` returned `false`, so we never end up
127+
# here
128+
error("This should never be reached")
129+
end
130+
end
131+
end
132+
133+
# 5) parameter dependencies become equations, their LHS become unknowns
134+
for eq in parameter_dependencies(sys)
135+
varp = tovar(eq.lhs)
136+
paramsubs[eq.lhs] = varp
137+
push!(eqs_ics, eq)
138+
guessval = get(guesses, eq.lhs, eq.rhs)
139+
push!(defs, varp => guessval)
140+
end
141+
142+
# 6) handle values provided for dependent parameters similar to values for observed variables
143+
for (k, v) in merge(defaults(sys), pmap)
144+
if is_variable_floatingpoint(k) && has_parameter_dependency_with_lhs(sys, k)
145+
push!(eqs_ics, paramsubs[k] ~ v)
146+
end
147+
end
148+
149+
# parameters do not include ones that became initialization unknowns
150+
pars = vcat(
151+
[get_iv(sys)], # include independent variable as pseudo-parameter
152+
[p for p in parameters(sys) if !haskey(paramsubs, p)]
84153
)
154+
155+
eqs_ics = Symbolics.substitute.([eqs_ics; observed(sys)], (paramsubs,))
156+
vars = [vars; collect(values(paramsubs))]
157+
for k in keys(defs)
158+
defs[k] = substitute(defs[k], paramsubs)
159+
end
160+
return NonlinearSystem(eqs_ics,
161+
vars,
162+
pars;
163+
defaults = defs,
164+
checks = check_units,
165+
name,
166+
kwargs...)
167+
end
168+
169+
function is_parameter_solvable(p, pmap, defs, guesses)
170+
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
171+
_val2 = get(defs, p, nothing)
172+
_val3 = get(guesses, p, nothing)
173+
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
174+
# the ODEProblem and it has a default and a guess)
175+
return ((_val1 === missing || _val2 === missing) ||
176+
(_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
85177
end

src/systems/parameter_buffer.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,9 @@ function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true
522522
ic = get_index_cache(indp_to_system(indp))
523523
for (idx, val) in zip(idxs, vals)
524524
sym = nothing
525+
if val === missing
526+
val = get_temporary_value(idx)
527+
end
525528
if symbolic_type(idx) == ScalarSymbolic()
526529
sym = idx
527530
idx = parameter_index(ic, sym)

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -494,8 +494,8 @@ function collect_var!(unknowns, parameters, var, iv)
494494
push!(unknowns, var)
495495
end
496496
# Add also any parameters that appear only as defaults in the var
497-
if hasdefault(var)
498-
collect_vars!(unknowns, parameters, getdefault(var), iv)
497+
if hasdefault(var) && (def = getdefault(var)) !== missing
498+
collect_vars!(unknowns, parameters, def, iv)
499499
end
500500
return nothing
501501
end

0 commit comments

Comments
 (0)