Skip to content

Commit af8cd67

Browse files
feat: support arbitrary systems in generate_initializesystem
1 parent 219aee3 commit af8cd67

File tree

1 file changed

+38
-27
lines changed

1 file changed

+38
-27
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ $(TYPEDSIGNATURES)
33
44
Generate `NonlinearSystem` which initializes an ODE problem from specified initial conditions of an `ODESystem`.
55
"""
6-
function generate_initializesystem(sys::ODESystem;
6+
function generate_initializesystem(sys::AbstractSystem;
77
u0map = Dict(),
88
pmap = Dict(),
99
initialization_eqs = [],
@@ -12,28 +12,36 @@ function generate_initializesystem(sys::ODESystem;
1212
algebraic_only = false,
1313
check_units = true, check_defguess = false,
1414
name = nameof(sys), extra_metadata = (;), kwargs...)
15-
trueobs, eqs = unhack_observed(observed(sys), equations(sys))
15+
eqs = equations(sys)
16+
eqs = filter(x -> x isa Equation, eqs)
17+
trueobs, eqs = unhack_observed(observed(sys), eqs)
1618
vars = unique([unknowns(sys); getfield.(trueobs, :lhs)])
1719
vars_set = Set(vars) # for efficient in-lookup
1820

19-
idxs_diff = isdiffeq.(eqs)
20-
idxs_alge = .!idxs_diff
21-
22-
# prepare map for dummy derivative substitution
23-
eqs_diff = eqs[idxs_diff]
24-
D = Differential(get_iv(sys))
25-
diffmap = merge(
26-
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
27-
Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs)
28-
)
29-
30-
# 1) process dummy derivatives and u0map into initialization system
31-
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
21+
eqs_ics = Equation[]
3222
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
3323
additional_guesses = anydict(guesses)
3424
guesses = merge(get_guesses(sys), additional_guesses)
35-
schedule = getfield(sys, :schedule)
36-
if !isnothing(schedule)
25+
idxs_diff = isdiffeq.(eqs)
26+
27+
# 1) Use algebraic equations of time-dependent systems as initialization constraints
28+
if has_iv(sys)
29+
idxs_alge = .!idxs_diff
30+
append!(eqs_ics, eqs[idxs_alge]) # start equation list with algebraic equations
31+
32+
eqs_diff = eqs[idxs_diff]
33+
D = Differential(get_iv(sys))
34+
diffmap = merge(
35+
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
36+
Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs)
37+
)
38+
else
39+
diffmap = Dict()
40+
end
41+
42+
if has_schedule(sys) && (schedule = get_schedule(sys); !isnothing(schedule))
43+
# 2) process dummy derivatives and u0map into initialization system
44+
# prepare map for dummy derivative substitution
3745
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
3846
# set dummy derivatives to default_dd_guess unless specified
3947
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
@@ -61,9 +69,14 @@ function generate_initializesystem(sys::ODESystem;
6169
process_u0map_with_dummysubs(y, x)
6270
end
6371
end
72+
else
73+
# 2) System doesn't have a schedule, so dummy derivatives don't exist/aren't handled (SDESystem)
74+
for (k, v) in u0map
75+
defs[k] = v
76+
end
6477
end
6578

66-
# 2) process other variables
79+
# 3) process other variables
6780
for var in vars
6881
if var keys(defs)
6982
push!(eqs_ics, var ~ defs[var])
@@ -74,7 +87,7 @@ function generate_initializesystem(sys::ODESystem;
7487
end
7588
end
7689

77-
# 3) process explicitly provided initialization equations
90+
# 4) process explicitly provided initialization equations
7891
if !algebraic_only
7992
initialization_eqs = [get_initialization_eqs(sys); initialization_eqs]
8093
for eq in initialization_eqs
@@ -83,7 +96,7 @@ function generate_initializesystem(sys::ODESystem;
8396
end
8497
end
8598

86-
# 4) process parameters as initialization unknowns
99+
# 5) process parameters as initialization unknowns
87100
paramsubs = Dict()
88101
if pmap isa SciMLBase.NullParameters
89102
pmap = Dict()
@@ -138,7 +151,7 @@ function generate_initializesystem(sys::ODESystem;
138151
end
139152
end
140153

141-
# 5) parameter dependencies become equations, their LHS become unknowns
154+
# 6) parameter dependencies become equations, their LHS become unknowns
142155
# non-numeric dependent parameters stay as parameter dependencies
143156
new_parameter_deps = Equation[]
144157
for eq in parameter_dependencies(sys)
@@ -153,20 +166,18 @@ function generate_initializesystem(sys::ODESystem;
153166
push!(defs, varp => guessval)
154167
end
155168

156-
# 6) handle values provided for dependent parameters similar to values for observed variables
169+
# 7) handle values provided for dependent parameters similar to values for observed variables
157170
for (k, v) in merge(defaults(sys), pmap)
158171
if is_variable_floatingpoint(k) && has_parameter_dependency_with_lhs(sys, k)
159172
push!(eqs_ics, paramsubs[k] ~ v)
160173
end
161174
end
162175

163176
# parameters do not include ones that became initialization unknowns
164-
pars = vcat(
165-
[get_iv(sys)], # include independent variable as pseudo-parameter
166-
[p for p in parameters(sys) if !haskey(paramsubs, p)]
167-
)
177+
pars = Vector{SymbolicParam}(filter(p -> !haskey(paramsubs, p), parameters(sys)))
178+
is_time_dependent(sys) && push!(pars, get_iv(sys))
168179

169-
# 7) use observed equations for guesses of observed variables if not provided
180+
# 8) use observed equations for guesses of observed variables if not provided
170181
for eq in trueobs
171182
haskey(defs, eq.lhs) && continue
172183
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue

0 commit comments

Comments
 (0)