Skip to content

Commit 4e7c065

Browse files
feat: support unknown parameters during initialization
1 parent 36a9909 commit 4e7c065

File tree

1 file changed

+54
-15
lines changed

1 file changed

+54
-15
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
324324
split_idxs = nothing,
325325
initializeprob = nothing,
326326
initializeprobmap = nothing,
327+
initializeprob_updatep! = nothing,
327328
kwargs...) where {iip, specialize}
328329
if !iscomplete(sys)
329330
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
@@ -506,7 +507,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
506507
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
507508
analytic = analytic,
508509
initializeprob = initializeprob,
509-
initializeprobmap = initializeprobmap)
510+
initializeprobmap = initializeprobmap,
511+
initializeprob_updatep! = initializeprob_updatep!)
510512
end
511513

512514
"""
@@ -538,6 +540,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
538540
checkbounds = false,
539541
initializeprob = nothing,
540542
initializeprobmap = nothing,
543+
initializeprob_updatep! = nothing,
541544
kwargs...) where {iip}
542545
if !iscomplete(sys)
543546
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
@@ -611,7 +614,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
611614
jac_prototype = jac_prototype,
612615
observed = observedfun,
613616
initializeprob = initializeprob,
614-
initializeprobmap = initializeprobmap)
617+
initializeprobmap = initializeprobmap,
618+
initializeprob_updatep! = initializeprob_updatep!)
615619
end
616620

617621
function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
@@ -862,7 +866,6 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
862866
varmap = canonicalize_varmap(varmap)
863867
varlist = collect(map(unwrap, dvs))
864868
missingvars = setdiff(varlist, collect(keys(varmap)))
865-
866869
# Append zeros to the variables which are determined by the initialization system
867870
# This essentially bypasses the check for if initial conditions are defined for DAEs
868871
# since they will be checked in the initialization problem's construction
@@ -873,11 +876,14 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
873876
parammap = Dict(unwrap(k) => v for (k, v) in todict(parammap))
874877
elseif parammap isa AbstractArray
875878
if isempty(parammap)
876-
parammap = SciMLBase.NullParameters()
879+
parammap = Dict()
877880
else
878881
parammap = Dict(unwrap.(parameters(sys)) .=> parammap)
879882
end
883+
elseif parammap === nothing || parammap isa SciMLBase.NullParameters
884+
parammap = Dict()
880885
end
886+
missingpars = setdiff(parameters(sys), keys(parammap))
881887

882888
if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing
883889
clockedparammap = Dict()
@@ -886,7 +892,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
886892
v = unwrap(v)
887893
is_discrete_domain(v) || continue
888894
op = operation(v)
889-
if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() &&
895+
if !isa(op, Symbolics.Operator) && !isempty(parammap) &&
890896
haskey(parammap, v)
891897
error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).")
892898
end
@@ -909,7 +915,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
909915
# TODO: make it work with clocks
910916
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
911917
if sys isa ODESystem && build_initializeprob &&
912-
(implicit_dae || !isempty(missingvars)) &&
918+
(implicit_dae || !isempty(missingvars) || !isempty(missingpars)) &&
913919
all(isequal(Continuous()), ci.var_domain) &&
914920
ModelingToolkit.get_tearing_state(sys) !== nothing &&
915921
t !== nothing
@@ -921,15 +927,43 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
921927
end
922928
initializeprob = ModelingToolkit.InitializationProblem(
923929
sys, t, u0map, parammap; guesses, warn_initialize_determined)
924-
initializeprobmap = getu(initializeprob, unknowns(sys))
925-
930+
unks = unknowns(sys)
931+
initializeprobmap = isempty(unks) ? (_...) -> nothing :
932+
getu(initializeprob, unknowns(sys))
933+
if any(p -> is_variable(initializeprob, p) || is_observed(initializeprob, p),
934+
parameters(sys))
935+
punknowns = [p
936+
for p in parameters(sys)
937+
if is_variable(initializeprob, p) ||
938+
is_observed(initializeprob, p)]
939+
initializeprob_updatep! = let getter = getu(initializeprob, tovar.(punknowns)),
940+
setter = setp(sys, punknowns)
941+
942+
function (ps, initsol)
943+
setter(ps, getter(initsol))
944+
end
945+
end
946+
else
947+
punknowns = []
948+
initializeprob_updatep! = nothing
949+
end
926950
zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0)
951+
zeropars = Dict()
952+
for p in punknowns
953+
zeropars[p] = if Symbolics.isarraysymbolic(p)
954+
collect(unwrap.(zero(p)))
955+
else
956+
unwrap(zero(p))
957+
end
958+
end
927959
trueinit = collect(merge(zerovars, eltype(u0map) <: Pair ? todict(u0map) : u0map))
928960
u0map isa StaticArraysCore.StaticArray &&
929961
(trueinit = SVector{length(trueinit)}(trueinit))
930962
else
931963
initializeprob = nothing
932964
initializeprobmap = nothing
965+
initializeprob_updatep! = nothing
966+
zeropars = Dict()
933967
trueinit = u0map
934968
end
935969

@@ -940,7 +974,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
940974
parammap == SciMLBase.NullParameters() && isempty(defs)
941975
nothing
942976
else
943-
MTKParameters(sys, parammap, trueinit)
977+
if parammap === nothing || parammap == SciMLBase.NullParameters()
978+
parammap = Dict()
979+
else
980+
parammap = todict(parammap)
981+
end
982+
MTKParameters(sys, merge(parammap, zeropars), trueinit)
944983
end
945984
else
946985
u0, p, defs = get_u0_p(sys,
@@ -975,6 +1014,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
9751014
sparse = sparse, eval_expression = eval_expression,
9761015
initializeprob = initializeprob,
9771016
initializeprobmap = initializeprobmap,
1017+
initializeprob_updatep! = initializeprob_updatep!,
9781018
kwargs...)
9791019
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
9801020
end
@@ -1602,13 +1642,15 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
16021642
if !iscomplete(sys)
16031643
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
16041644
end
1645+
parammap = parammap isa SciMLBase.NullParameters ? Dict() : todict(parammap)
16051646
if isempty(u0map) && get_initializesystem(sys) !== nothing
16061647
isys = get_initializesystem(sys)
16071648
elseif isempty(u0map) && get_initializesystem(sys) === nothing
1608-
isys = structural_simplify(generate_initializesystem(sys); fully_determined = false)
1649+
isys = structural_simplify(
1650+
generate_initializesystem(sys; pmap = parammap); fully_determined = false)
16091651
else
16101652
isys = structural_simplify(
1611-
generate_initializesystem(sys; u0map); fully_determined = false)
1653+
generate_initializesystem(sys; u0map, pmap = parammap); fully_determined = false)
16121654
end
16131655

16141656
uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])
@@ -1628,10 +1670,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
16281670
if warn_initialize_determined && neqs < nunknown
16291671
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
16301672
end
1631-
1632-
parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
1633-
[get_iv(sys) => t] :
1634-
merge(todict(parammap), Dict(get_iv(sys) => t))
1673+
parammap[get_iv(sys)] = t
16351674
if isempty(u0map)
16361675
u0map = Dict()
16371676
end

0 commit comments

Comments
 (0)