Skip to content

Commit 0bbba16

Browse files
fix: only solve parameter initialization for NonlinearSystem
1 parent 4c86290 commit 0bbba16

File tree

4 files changed

+180
-205
lines changed

4 files changed

+180
-205
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1335,7 +1335,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
13351335

13361336
# TODO: throw on uninitialized arrays
13371337
filter!(x -> !(x isa Symbolics.Arr), uninit)
1338-
if !isempty(uninit)
1338+
if is_time_dependent(sys) && !isempty(uninit)
13391339
allow_incomplete || throw(IncompleteInitializationError(uninit))
13401340
# for incomplete initialization, we will add the missing variables as parameters.
13411341
# they will be updated by `update_initializeprob!` and `initializeprobmap` will

src/systems/nonlinear/initializesystem.jl

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -40,51 +40,53 @@ function generate_initializesystem(sys::AbstractSystem;
4040
diffmap = Dict()
4141
end
4242

43-
if has_schedule(sys) && (schedule = get_schedule(sys); !isnothing(schedule))
44-
# 2) process dummy derivatives and u0map into initialization system
45-
# prepare map for dummy derivative substitution
46-
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
47-
# set dummy derivatives to default_dd_guess unless specified
48-
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
49-
end
50-
function process_u0map_with_dummysubs(y, x)
51-
y = get(schedule.dummy_sub, y, y)
52-
y = fixpoint_sub(y, diffmap)
53-
if y vars_set
54-
# variables specified in u0 overrides defaults
55-
push!(defs, y => x)
56-
elseif y isa Symbolics.Arr
57-
# TODO: don't scalarize arrays
58-
merge!(defs, Dict(scalarize(y .=> x)))
59-
elseif y isa Symbolics.BasicSymbolic
60-
# y is a derivative expression expanded; add it to the initialization equations
61-
push!(eqs_ics, y ~ x)
62-
else
63-
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
43+
if is_time_dependent(sys)
44+
if has_schedule(sys) && (schedule = get_schedule(sys); !isnothing(schedule))
45+
# 2) process dummy derivatives and u0map into initialization system
46+
# prepare map for dummy derivative substitution
47+
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
48+
# set dummy derivatives to default_dd_guess unless specified
49+
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
6450
end
65-
end
66-
for (y, x) in u0map
67-
if Symbolics.isarraysymbolic(y)
68-
process_u0map_with_dummysubs.(collect(y), collect(x))
69-
else
70-
process_u0map_with_dummysubs(y, x)
51+
function process_u0map_with_dummysubs(y, x)
52+
y = get(schedule.dummy_sub, y, y)
53+
y = fixpoint_sub(y, diffmap)
54+
if y vars_set
55+
# variables specified in u0 overrides defaults
56+
push!(defs, y => x)
57+
elseif y isa Symbolics.Arr
58+
# TODO: don't scalarize arrays
59+
merge!(defs, Dict(scalarize(y .=> x)))
60+
elseif y isa Symbolics.BasicSymbolic
61+
# y is a derivative expression expanded; add it to the initialization equations
62+
push!(eqs_ics, y ~ x)
63+
else
64+
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
65+
end
66+
end
67+
for (y, x) in u0map
68+
if Symbolics.isarraysymbolic(y)
69+
process_u0map_with_dummysubs.(collect(y), collect(x))
70+
else
71+
process_u0map_with_dummysubs(y, x)
72+
end
73+
end
74+
else
75+
# 2) System doesn't have a schedule, so dummy derivatives don't exist/aren't handled (SDESystem)
76+
for (k, v) in u0map
77+
defs[k] = v
7178
end
7279
end
73-
else
74-
# 2) System doesn't have a schedule, so dummy derivatives don't exist/aren't handled (SDESystem)
75-
for (k, v) in u0map
76-
defs[k] = v
77-
end
78-
end
7980

80-
# 3) process other variables
81-
for var in vars
82-
if var keys(defs)
83-
push!(eqs_ics, var ~ defs[var])
84-
elseif var keys(guesses)
85-
push!(defs, var => guesses[var])
86-
elseif check_defguess
87-
error("Invalid setup: variable $(var) has no default value or initial guess")
81+
# 3) process other variables
82+
for var in vars
83+
if var keys(defs)
84+
push!(eqs_ics, var ~ defs[var])
85+
elseif var keys(guesses)
86+
push!(defs, var => guesses[var])
87+
elseif check_defguess
88+
error("Invalid setup: variable $(var) has no default value or initial guess")
89+
end
8890
end
8991
end
9092

@@ -178,16 +180,24 @@ function generate_initializesystem(sys::AbstractSystem;
178180
pars = Vector{SymbolicParam}(filter(p -> !haskey(paramsubs, p), parameters(sys)))
179181
is_time_dependent(sys) && push!(pars, get_iv(sys))
180182

181-
# 8) use observed equations for guesses of observed variables if not provided
182-
for eq in trueobs
183-
haskey(defs, eq.lhs) && continue
184-
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue
183+
if is_time_dependent(sys)
184+
# 8) use observed equations for guesses of observed variables if not provided
185+
for eq in trueobs
186+
haskey(defs, eq.lhs) && continue
187+
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue
185188

186-
defs[eq.lhs] = eq.rhs
189+
defs[eq.lhs] = eq.rhs
190+
end
191+
append!(eqs_ics, trueobs)
192+
end
193+
194+
eqs_ics = Symbolics.substitute.(eqs_ics, (paramsubs,))
195+
if is_time_dependent(sys)
196+
vars = [vars; collect(values(paramsubs))]
197+
else
198+
vars = collect(values(paramsubs))
187199
end
188200

189-
eqs_ics = Symbolics.substitute.([eqs_ics; trueobs], (paramsubs,))
190-
vars = [vars; collect(values(paramsubs))]
191201
for k in keys(defs)
192202
defs[k] = substitute(defs[k], paramsubs)
193203
end

src/systems/problem_utils.jl

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -546,30 +546,46 @@ function maybe_build_initialization_problem(
546546
initializeprob = ModelingToolkit.InitializationProblem(
547547
sys, t, u0map, pmap; guesses, kwargs...)
548548

549-
all_init_syms = Set(all_symbols(initializeprob))
550-
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
551-
initializeprobmap = getu(initializeprob, solved_unknowns)
549+
if is_time_dependent(sys)
550+
all_init_syms = Set(all_symbols(initializeprob))
551+
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
552+
initializeprobmap = getu(initializeprob, solved_unknowns)
553+
else
554+
initializeprobmap = nothing
555+
end
552556

553557
punknowns = [p
554558
for p in all_variable_symbols(initializeprob)
555559
if is_parameter(sys, p)]
556-
getpunknowns = getu(initializeprob, punknowns)
557-
setpunknowns = setp(sys, punknowns)
558-
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
560+
if isempty(punknowns)
561+
initializeprobpmap = nothing
562+
else
563+
getpunknowns = getu(initializeprob, punknowns)
564+
setpunknowns = setp(sys, punknowns)
565+
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
566+
end
559567

560568
reqd_syms = parameter_symbols(initializeprob)
561-
update_initializeprob! = UpdateInitializeprob(
562-
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
569+
# we still want the `initialization_data` because it helps with `remake`
570+
if initializeprobmap === nothing && initializeprobpmap === nothing
571+
update_initializeprob! = nothing
572+
else
573+
update_initializeprob! = UpdateInitializeprob(
574+
getu(sys, reqd_syms), setu(initializeprob, reqd_syms))
575+
end
576+
563577
for p in punknowns
564578
p = unwrap(p)
565579
stype = symtype(p)
566580
op[p] = get_temporary_value(p)
567581
end
568582

569-
for v in missing_unknowns
570-
op[v] = zero_var(v)
583+
if is_time_dependent(sys)
584+
for v in missing_unknowns
585+
op[v] = zero_var(v)
586+
end
587+
empty!(missing_unknowns)
571588
end
572-
empty!(missing_unknowns)
573589
return (;
574590
initialization_data = SciMLBase.OverrideInitData(
575591
initializeprob, update_initializeprob!, initializeprobmap,

0 commit comments

Comments
 (0)