Skip to content

Commit 1320fc0

Browse files
refactor: separate out operating point and initializeprob construction
1 parent 24d39af commit 1320fc0

File tree

1 file changed

+94
-60
lines changed

1 file changed

+94
-60
lines changed

src/systems/problem_utils.jl

Lines changed: 94 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,89 @@ function EmptySciMLFunction(args...; kwargs...)
489489
return EmptySciMLFunction{typeof(args), typeof(kwargs)}(args, kwargs)
490490
end
491491

492+
"""
493+
$(TYPEDSIGNATURES)
494+
495+
Construct the operating point of the system from the user-provided `u0map` and `pmap`, system
496+
defaults `defs`, constant equations `cmap` (from `get_cmap(sys)`), unknowns `dvs` and
497+
parameters `ps`. Return the operating point as a dictionary, the list of unknowns for which
498+
no values can be determined, and the list of parameters for which no values can be determined.
499+
"""
500+
function build_operating_point(
501+
u0map::AbstractDict, pmap::AbstractDict, defs::AbstractDict, cmap, dvs, ps)
502+
op = add_toterms(u0map)
503+
missing_unknowns = add_fallbacks!(op, dvs, defs)
504+
for (k, v) in defs
505+
haskey(op, k) && continue
506+
op[k] = v
507+
end
508+
merge!(op, pmap)
509+
missing_pars = add_fallbacks!(op, ps, defs)
510+
for eq in cmap
511+
op[eq.lhs] = eq.rhs
512+
end
513+
return op, missing_unknowns, missing_pars
514+
end
515+
516+
"""
517+
$(TYPEDSIGNATURES)
518+
519+
Build and return the initialization problem and associated data as a `NamedTuple` to be passed
520+
to the `SciMLFunction` constructor. Requires the system `sys`, operating point `op`,
521+
user-provided `u0map` and `pmap`, initial time `t`, system defaults `defs`, user-provided
522+
`guesses`, and list of unknowns which don't have a value in `op`. The keyword `implicit_dae`
523+
denotes whether the `SciMLProblem` being constructed is in implicit DAE form (`DAEProblem`).
524+
All other keyword arguments are forwarded to `InitializationProblem`.
525+
"""
526+
function maybe_build_initialization_problem(
527+
sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs,
528+
guesses, missing_unknowns; implicit_dae = false, kwargs...)
529+
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))
530+
has_observed_u0s = any(
531+
k -> has_observed_with_lhs(sys, k) || has_parameter_dependency_with_lhs(sys, k),
532+
keys(op))
533+
solvablepars = [p
534+
for p in parameters(sys)
535+
if is_parameter_solvable(p, pmap, defs, guesses)]
536+
has_dependent_unknowns = any(unknowns(sys)) do sym
537+
val = get(op, sym, nothing)
538+
val === nothing && return false
539+
return symbolic_type(val) != NotSymbolic() || is_array_of_symbolics(val)
540+
end
541+
if (((implicit_dae || has_observed_u0s || !isempty(missing_unknowns) ||
542+
!isempty(solvablepars) || has_dependent_unknowns) &&
543+
get_tearing_state(sys) !== nothing) ||
544+
!isempty(initialization_equations(sys))) && t !== nothing
545+
initializeprob = ModelingToolkit.InitializationProblem(
546+
sys, t, u0map, pmap; guesses, kwargs...)
547+
initializeprobmap = getu(initializeprob, unknowns(sys))
548+
549+
punknowns = [p
550+
for p in all_variable_symbols(initializeprob)
551+
if is_parameter(sys, p)]
552+
getpunknowns = getu(initializeprob, punknowns)
553+
setpunknowns = setp(sys, punknowns)
554+
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
555+
556+
reqd_syms = parameter_symbols(initializeprob)
557+
update_initializeprob! = UpdateInitializeprob(
558+
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
559+
for p in punknowns
560+
p = unwrap(p)
561+
stype = symtype(p)
562+
op[p] = get_temporary_value(p)
563+
end
564+
565+
for v in missing_unknowns
566+
op[v] = zero_var(v)
567+
end
568+
empty!(missing_unknowns)
569+
return (;
570+
initializeprob, initializeprobmap, initializeprobpmap, update_initializeprob!)
571+
end
572+
return (;)
573+
end
574+
492575
"""
493576
$(TYPEDSIGNATURES)
494577
@@ -576,67 +659,18 @@ function process_SciMLProblem(
576659
cmap, cs = get_cmap(sys)
577660
kwargs = NamedTuple(kwargs)
578661

579-
op = add_toterms(u0map)
580-
missing_unknowns = add_fallbacks!(op, dvs, defs)
581-
for (k, v) in defs
582-
haskey(op, k) && continue
583-
op[k] = v
584-
end
585-
merge!(op, pmap)
586-
missing_pars = add_fallbacks!(op, ps, defs)
587-
for eq in cmap
588-
op[eq.lhs] = eq.rhs
589-
end
590-
if sys isa ODESystem
591-
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))
592-
has_observed_u0s = any(
593-
k -> has_observed_with_lhs(sys, k) || has_parameter_dependency_with_lhs(sys, k),
594-
keys(op))
595-
solvablepars = [p
596-
for p in parameters(sys)
597-
if is_parameter_solvable(p, pmap, defs, guesses)]
598-
has_dependent_unknowns = any(unknowns(sys)) do sym
599-
val = get(op, sym, nothing)
600-
val === nothing && return false
601-
return symbolic_type(val) != NotSymbolic() || is_array_of_symbolics(val)
602-
end
603-
if build_initializeprob &&
604-
(((implicit_dae || has_observed_u0s || !isempty(missing_unknowns) ||
605-
!isempty(solvablepars) || has_dependent_unknowns) &&
606-
get_tearing_state(sys) !== nothing) ||
607-
!isempty(initialization_equations(sys))) && t !== nothing
608-
initializeprob = ModelingToolkit.InitializationProblem(
609-
sys, t, u0map, pmap; guesses, warn_initialize_determined,
610-
initialization_eqs, eval_expression, eval_module, fully_determined,
611-
warn_cyclic_dependency, check_units = check_initialization_units,
612-
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc)
613-
initializeprobmap = getu(initializeprob, unknowns(sys))
614-
615-
punknowns = [p
616-
for p in all_variable_symbols(initializeprob)
617-
if is_parameter(sys, p)]
618-
getpunknowns = getu(initializeprob, punknowns)
619-
setpunknowns = setp(sys, punknowns)
620-
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
621-
622-
reqd_syms = parameter_symbols(initializeprob)
623-
update_initializeprob! = UpdateInitializeprob(
624-
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
625-
for p in punknowns
626-
p = unwrap(p)
627-
stype = symtype(p)
628-
op[p] = get_temporary_value(p)
629-
delete!(missing_pars, p)
630-
end
662+
op, missing_unknowns, missing_pars = build_operating_point(
663+
u0map, pmap, defs, cmap, dvs, ps)
631664

632-
for v in missing_unknowns
633-
op[v] = zero_var(v)
634-
end
635-
empty!(missing_unknowns)
636-
kwargs = merge(kwargs,
637-
(; initializeprob, initializeprobmap,
638-
initializeprobpmap, update_initializeprob!))
639-
end
665+
if sys isa ODESystem && build_initializeprob
666+
kws = maybe_build_initialization_problem(
667+
sys, op, u0map, pmap, t, defs, guesses, missing_unknowns;
668+
implicit_dae, warn_initialize_determined, initialization_eqs,
669+
eval_expression, eval_module, fully_determined,
670+
warn_cyclic_dependency, check_units = check_initialization_units,
671+
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc)
672+
673+
kwargs = merge(kwargs, kws)
640674
end
641675

642676
if t !== nothing && !(constructor <: Union{DDEFunction, SDDEFunction})

0 commit comments

Comments
 (0)