Skip to content

Commit 2beb33a

Browse files
fixup! feat: support arbitrary systems in generate_initializesystem
1 parent da8afb4 commit 2beb33a

File tree

1 file changed

+42
-35
lines changed

1 file changed

+42
-35
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,16 @@ function generate_initializesystem(sys::AbstractSystem;
2222
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
2323
additional_guesses = anydict(guesses)
2424
guesses = merge(get_guesses(sys), additional_guesses)
25+
idxs_diff = isdiffeq.(eqs)
2526

27+
# 1) Use algebraic equations of time-dependent systems as initialization constraints
2628
if has_iv(sys)
27-
idxs_diff = isdiffeq.(eqs)
2829
idxs_alge = .!idxs_diff
30+
append!(eqs_ics, eqs[idxs_alge]) # start equation list with algebraic equations
31+
end
32+
33+
if has_schedule(sys) && (schedule = get_schedule(sys); !isnothing(schedule))
34+
# 2) process dummy derivatives and u0map into initialization system
2935

3036
# prepare map for dummy derivative substitution
3137
eqs_diff = eqs[idxs_diff]
@@ -35,40 +41,41 @@ function generate_initializesystem(sys::AbstractSystem;
3541
Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs)
3642
)
3743

38-
# 1) process dummy derivatives and u0map into initialization system
39-
append!(eqs_ics, eqs[idxs_alge]) # start equation list with algebraic equations
40-
if has_schedule(sys) && (schedule = get_schedule(sys); !isnothing(schedule))
41-
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
42-
# set dummy derivatives to default_dd_guess unless specified
43-
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
44-
end
45-
function process_u0map_with_dummysubs(y, x)
46-
y = get(schedule.dummy_sub, y, y)
47-
y = fixpoint_sub(y, diffmap)
48-
if y vars_set
49-
# variables specified in u0 overrides defaults
50-
push!(defs, y => x)
51-
elseif y isa Symbolics.Arr
52-
# TODO: don't scalarize arrays
53-
merge!(defs, Dict(scalarize(y .=> x)))
54-
elseif y isa Symbolics.BasicSymbolic
55-
# y is a derivative expression expanded; add it to the initialization equations
56-
push!(eqs_ics, y ~ x)
57-
else
58-
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
59-
end
44+
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
45+
# set dummy derivatives to default_dd_guess unless specified
46+
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
47+
end
48+
function process_u0map_with_dummysubs(y, x)
49+
y = get(schedule.dummy_sub, y, y)
50+
y = fixpoint_sub(y, diffmap)
51+
if y vars_set
52+
# variables specified in u0 overrides defaults
53+
push!(defs, y => x)
54+
elseif y isa Symbolics.Arr
55+
# TODO: don't scalarize arrays
56+
merge!(defs, Dict(scalarize(y .=> x)))
57+
elseif y isa Symbolics.BasicSymbolic
58+
# y is a derivative expression expanded; add it to the initialization equations
59+
push!(eqs_ics, y ~ x)
60+
else
61+
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
6062
end
61-
for (y, x) in u0map
62-
if Symbolics.isarraysymbolic(y)
63-
process_u0map_with_dummysubs.(collect(y), collect(x))
64-
else
65-
process_u0map_with_dummysubs(y, x)
66-
end
63+
end
64+
for (y, x) in u0map
65+
if Symbolics.isarraysymbolic(y)
66+
process_u0map_with_dummysubs.(collect(y), collect(x))
67+
else
68+
process_u0map_with_dummysubs(y, x)
6769
end
6870
end
71+
else
72+
# 2) System doesn't have a schedule, so dummy derivatives don't exist/aren't handled (SDESystem)
73+
for (k, v) in u0map
74+
defs[k] = v
75+
end
6976
end
7077

71-
# 2) process other variables
78+
# 3) process other variables
7279
for var in vars
7380
if var keys(defs)
7481
push!(eqs_ics, var ~ defs[var])
@@ -79,7 +86,7 @@ function generate_initializesystem(sys::AbstractSystem;
7986
end
8087
end
8188

82-
# 3) process explicitly provided initialization equations
89+
# 4) process explicitly provided initialization equations
8390
if !algebraic_only
8491
initialization_eqs = [get_initialization_eqs(sys); initialization_eqs]
8592
for eq in initialization_eqs
@@ -88,7 +95,7 @@ function generate_initializesystem(sys::AbstractSystem;
8895
end
8996
end
9097

91-
# 4) process parameters as initialization unknowns
98+
# 5) process parameters as initialization unknowns
9299
paramsubs = Dict()
93100
if pmap isa SciMLBase.NullParameters
94101
pmap = Dict()
@@ -143,7 +150,7 @@ function generate_initializesystem(sys::AbstractSystem;
143150
end
144151
end
145152

146-
# 5) parameter dependencies become equations, their LHS become unknowns
153+
# 6) parameter dependencies become equations, their LHS become unknowns
147154
# non-numeric dependent parameters stay as parameter dependencies
148155
new_parameter_deps = Equation[]
149156
for eq in parameter_dependencies(sys)
@@ -158,7 +165,7 @@ function generate_initializesystem(sys::AbstractSystem;
158165
push!(defs, varp => guessval)
159166
end
160167

161-
# 6) handle values provided for dependent parameters similar to values for observed variables
168+
# 7) handle values provided for dependent parameters similar to values for observed variables
162169
for (k, v) in merge(defaults(sys), pmap)
163170
if is_variable_floatingpoint(k) && has_parameter_dependency_with_lhs(sys, k)
164171
push!(eqs_ics, paramsubs[k] ~ v)
@@ -171,7 +178,7 @@ function generate_initializesystem(sys::AbstractSystem;
171178
[p for p in parameters(sys) if !haskey(paramsubs, p)]
172179
)
173180

174-
# 7) use observed equations for guesses of observed variables if not provided
181+
# 8) use observed equations for guesses of observed variables if not provided
175182
for eq in trueobs
176183
haskey(defs, eq.lhs) && continue
177184
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue

0 commit comments

Comments
 (0)