Skip to content

Commit 219aee3

Browse files
refactor: add guesses to SDESystem, NonlinearSystem, JumpSystem
1 parent 4792360 commit 219aee3

File tree

8 files changed

+187
-74
lines changed

8 files changed

+187
-74
lines changed

src/systems/diffeqs/odesystem.jl

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -256,29 +256,16 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
256256
:ODESystem, force = true)
257257
end
258258
defaults = Dict{Any, Any}(todict(defaults))
259+
guesses = Dict{Any, Any}(todict(guesses))
259260
var_to_name = Dict()
260-
process_variables!(var_to_name, defaults, dvs′)
261-
process_variables!(var_to_name, defaults, ps′)
262-
process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies])
263-
process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies])
261+
process_variables!(var_to_name, defaults, guesses, dvs′)
262+
process_variables!(var_to_name, defaults, guesses, ps′)
263+
process_variables!(
264+
var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
265+
process_variables!(
266+
var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
264267
defaults = Dict{Any, Any}(value(k) => value(v)
265268
for (k, v) in pairs(defaults) if v !== nothing)
266-
267-
sysdvsguesses = [ModelingToolkit.getguess(st) for st in dvs′]
268-
hasaguess = findall(!isnothing, sysdvsguesses)
269-
var_guesses = dvs′[hasaguess] .=> sysdvsguesses[hasaguess]
270-
sysdvsguesses = isempty(var_guesses) ? Dict() : todict(var_guesses)
271-
syspsguesses = [ModelingToolkit.getguess(st) for st in ps′]
272-
hasaguess = findall(!isnothing, syspsguesses)
273-
ps_guesses = ps′[hasaguess] .=> syspsguesses[hasaguess]
274-
syspsguesses = isempty(ps_guesses) ? Dict() : todict(ps_guesses)
275-
syspdepguesses = [ModelingToolkit.getguess(eq.lhs) for eq in parameter_dependencies]
276-
hasaguess = findall(!isnothing, syspdepguesses)
277-
pdep_guesses = [eq.lhs for eq in parameter_dependencies][hasaguess] .=>
278-
syspdepguesses[hasaguess]
279-
syspdepguesses = isempty(pdep_guesses) ? Dict() : todict(pdep_guesses)
280-
281-
guesses = merge(sysdvsguesses, syspsguesses, syspdepguesses, todict(guesses))
282269
guesses = Dict{Any, Any}(value(k) => value(v)
283270
for (k, v) in pairs(guesses) if v !== nothing)
284271

src/systems/diffeqs/sdesystem.jl

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,19 @@ struct SDESystem <: AbstractODESystem
9393
"""
9494
defaults::Dict
9595
"""
96+
The guesses to use as the initial conditions for the
97+
initialization system.
98+
"""
99+
guesses::Dict
100+
"""
101+
The system for performing the initialization.
102+
"""
103+
initializesystem::Union{Nothing, NonlinearSystem}
104+
"""
105+
Extra equations to be enforced during the initialization sequence.
106+
"""
107+
initialization_eqs::Vector{Equation}
108+
"""
96109
Type of the system.
97110
"""
98111
connector_type::Any
@@ -144,9 +157,8 @@ struct SDESystem <: AbstractODESystem
144157
isscheduled::Bool
145158

146159
function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed,
147-
tgrad,
148-
jac,
149-
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
160+
tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults,
161+
guesses, initializesystem, initialization_eqs, connector_type,
150162
cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing,
151163
complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false,
152164
is_dde = false,
@@ -171,9 +183,9 @@ struct SDESystem <: AbstractODESystem
171183
check_units(u, deqs, neqs)
172184
end
173185
new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
174-
ctrl_jac,
175-
Wfact, Wfact_t, name, description, systems,
176-
defaults, connector_type, cevents, devents,
186+
ctrl_jac, Wfact, Wfact_t, name, description, systems,
187+
defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents,
188+
devents,
177189
parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise,
178190
is_dde, isscheduled)
179191
end
@@ -187,6 +199,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
187199
default_u0 = Dict(),
188200
default_p = Dict(),
189201
defaults = _merge(Dict(default_u0), Dict(default_p)),
202+
guesses = Dict(),
203+
initializesystem = nothing,
204+
initialization_eqs = Equation[],
190205
name = nothing,
191206
description = "",
192207
connector_type = nothing,
@@ -207,6 +222,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
207222
dvs′ = value.(dvs)
208223
ps′ = value.(ps)
209224
ctrl′ = value.(controls)
225+
parameter_dependencies, ps′ = process_parameter_dependencies(
226+
parameter_dependencies, ps′)
210227

211228
sysnames = nameof.(systems)
212229
if length(unique(sysnames)) != length(sysnames)
@@ -217,13 +234,21 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
217234
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
218235
:SDESystem, force = true)
219236
end
220-
defaults = todict(defaults)
221-
defaults = Dict(value(k) => value(v)
222-
for (k, v) in pairs(defaults) if value(v) !== nothing)
223237

238+
defaults = Dict{Any, Any}(todict(defaults))
239+
guesses = Dict{Any, Any}(todict(guesses))
224240
var_to_name = Dict()
225-
process_variables!(var_to_name, defaults, dvs′)
226-
process_variables!(var_to_name, defaults, ps′)
241+
process_variables!(var_to_name, defaults, guesses, dvs′)
242+
process_variables!(var_to_name, defaults, guesses, ps′)
243+
process_variables!(
244+
var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
245+
process_variables!(
246+
var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
247+
defaults = Dict{Any, Any}(value(k) => value(v)
248+
for (k, v) in pairs(defaults) if v !== nothing)
249+
guesses = Dict{Any, Any}(value(k) => value(v)
250+
for (k, v) in pairs(guesses) if v !== nothing)
251+
227252
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
228253

229254
tgrad = RefValue(EMPTY_TGRAD)
@@ -233,14 +258,13 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
233258
Wfact_t = RefValue(EMPTY_JAC)
234259
cont_callbacks = SymbolicContinuousCallbacks(continuous_events)
235260
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
236-
parameter_dependencies, ps′ = process_parameter_dependencies(
237-
parameter_dependencies, ps′)
238261
if is_dde === nothing
239262
is_dde = _check_if_dde(deqs, iv′, systems)
240263
end
241264
SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
242265
deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
243-
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type,
266+
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
267+
initializesystem, initialization_eqs, connector_type,
244268
cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata,
245269
complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks)
246270
end

src/systems/discrete_system/discrete_system.jl

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,19 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
5555
"""
5656
defaults::Dict
5757
"""
58+
The guesses to use as the initial conditions for the
59+
initialization system.
60+
"""
61+
guesses::Dict
62+
"""
63+
The system for performing the initialization.
64+
"""
65+
initializesystem::Union{Nothing, NonlinearSystem}
66+
"""
67+
Extra equations to be enforced during the initialization sequence.
68+
"""
69+
initialization_eqs::Vector{Equation}
70+
"""
5871
Inject assignment statements before the evaluation of the RHS function.
5972
"""
6073
preface::Any
@@ -98,9 +111,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
98111
isscheduled::Bool
99112

100113
function DiscreteSystem(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name,
101-
observed,
102-
name, description,
103-
systems, defaults, preface, connector_type, parameter_dependencies = Equation[],
114+
observed, name, description, systems, defaults, guesses, initializesystem,
115+
initialization_eqs, preface, connector_type, parameter_dependencies = Equation[],
104116
metadata = nothing, gui_metadata = nothing,
105117
tearing_state = nothing, substitutions = nothing,
106118
complete = false, index_cache = nothing, parent = nothing,
@@ -116,8 +128,7 @@ struct DiscreteSystem <: AbstractTimeDependentSystem
116128
check_units(u, discreteEqs)
117129
end
118130
new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, observed, name, description,
119-
systems,
120-
defaults,
131+
systems, defaults, guesses, initializesystem, initialization_eqs,
121132
preface, connector_type, parameter_dependencies, metadata, gui_metadata,
122133
tearing_state, substitutions, complete, index_cache, parent, isscheduled)
123134
end
@@ -135,6 +146,9 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
135146
description = "",
136147
default_u0 = Dict(),
137148
default_p = Dict(),
149+
guesses = Dict(),
150+
initializesystem = nothing,
151+
initialization_eqs = Equation[],
138152
defaults = _merge(Dict(default_u0), Dict(default_p)),
139153
preface = nothing,
140154
connector_type = nothing,
@@ -155,13 +169,21 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
155169
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
156170
:DiscreteSystem, force = true)
157171
end
158-
defaults = todict(defaults)
159-
defaults = Dict(value(k) => value(v)
160-
for (k, v) in pairs(defaults) if value(v) !== nothing)
161172

173+
defaults = Dict{Any, Any}(todict(defaults))
174+
guesses = Dict{Any, Any}(todict(guesses))
162175
var_to_name = Dict()
163-
process_variables!(var_to_name, defaults, dvs′)
164-
process_variables!(var_to_name, defaults, ps′)
176+
process_variables!(var_to_name, defaults, guesses, dvs′)
177+
process_variables!(var_to_name, defaults, guesses, ps′)
178+
process_variables!(
179+
var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
180+
process_variables!(
181+
var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
182+
defaults = Dict{Any, Any}(value(k) => value(v)
183+
for (k, v) in pairs(defaults) if v !== nothing)
184+
guesses = Dict{Any, Any}(value(k) => value(v)
185+
for (k, v) in pairs(guesses) if v !== nothing)
186+
165187
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
166188

167189
sysnames = nameof.(systems)
@@ -170,7 +192,8 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
170192
end
171193
DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
172194
eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, description, systems,
173-
defaults, preface, connector_type, parameter_dependencies, metadata, gui_metadata, kwargs...)
195+
defaults, guesses, initializesystem, initialization_eqs, preface, connector_type,
196+
parameter_dependencies, metadata, gui_metadata, kwargs...)
174197
end
175198

176199
function DiscreteSystem(eqs, iv; kwargs...)
@@ -225,6 +248,8 @@ function flatten(sys::DiscreteSystem, noeqs = false)
225248
parameters(sys),
226249
observed = observed(sys),
227250
defaults = defaults(sys),
251+
guesses = guesses(sys),
252+
initialization_eqs = initialization_equations(sys),
228253
name = nameof(sys),
229254
description = description(sys),
230255
metadata = get_metadata(sys),

src/systems/jumps/jumpsystem.jl

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,19 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
8484
"""
8585
defaults::Dict
8686
"""
87+
The guesses to use as the initial conditions for the
88+
initialization system.
89+
"""
90+
guesses::Dict
91+
"""
92+
The system for performing the initialization.
93+
"""
94+
initializesystem::Union{Nothing, NonlinearSystem}
95+
"""
96+
Extra equations to be enforced during the initialization sequence.
97+
"""
98+
initialization_eqs::Vector{Equation}
99+
"""
87100
Type of the system.
88101
"""
89102
connector_type::Any
@@ -125,8 +138,9 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
125138

126139
function JumpSystem{U}(
127140
tag, ap::U, iv, unknowns, ps, var_to_name, observed, name, description,
128-
systems, defaults, connector_type, cevents, devents, parameter_dependencies,
129-
metadata = nothing, gui_metadata = nothing,
141+
systems, defaults, guesses, initializesystem, initialization_eqs, connector_type,
142+
cevents, devents,
143+
parameter_dependencies, metadata = nothing, gui_metadata = nothing,
130144
complete = false, index_cache = nothing, isscheduled = false;
131145
checks::Union{Bool, Int} = true) where {U <: ArrayPartition}
132146
if checks == true || (checks & CheckComponents) > 0
@@ -139,7 +153,8 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
139153
check_units(u, ap, iv)
140154
end
141155
new{U}(tag, ap, iv, unknowns, ps, var_to_name,
142-
observed, name, description, systems, defaults,
156+
observed, name, description, systems, defaults, guesses, initializesystem,
157+
initialization_eqs,
143158
connector_type, cevents, devents, parameter_dependencies, metadata,
144159
gui_metadata, complete, index_cache, isscheduled)
145160
end
@@ -154,6 +169,9 @@ function JumpSystem(eqs, iv, unknowns, ps;
154169
default_u0 = Dict(),
155170
default_p = Dict(),
156171
defaults = _merge(Dict(default_u0), Dict(default_p)),
172+
guesses = Dict(),
173+
initializesystem = nothing,
174+
initialization_eqs = Equation[],
157175
name = nothing,
158176
description = "",
159177
connector_type = nothing,
@@ -179,13 +197,17 @@ function JumpSystem(eqs, iv, unknowns, ps;
179197
:JumpSystem, force = true)
180198
end
181199
defaults = Dict{Any, Any}(todict(defaults))
200+
guesses = Dict{Any, Any}(todict(guesses))
182201
var_to_name = Dict()
183-
process_variables!(var_to_name, defaults, us′)
184-
process_variables!(var_to_name, defaults, ps′)
185-
process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies])
186-
process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies])
202+
process_variables!(var_to_name, defaults, guesses, us′)
203+
process_variables!(var_to_name, defaults, guesses, ps′)
204+
process_variables!(
205+
var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies])
206+
process_variables!(
207+
var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies])
187208
#! format: off
188209
defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults) if value(v) !== nothing)
210+
guesses = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(guesses) if v !== nothing)
189211
#! format: on
190212
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
191213

@@ -219,8 +241,9 @@ function JumpSystem(eqs, iv, unknowns, ps;
219241

220242
JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
221243
ap, iv′, us′, ps′, var_to_name, observed, name, description, systems,
222-
defaults, connector_type, cont_callbacks, disc_callbacks, parameter_dependencies,
223-
metadata, gui_metadata, checks = checks)
244+
defaults, guesses, initializesystem, initialization_eqs, connector_type,
245+
cont_callbacks, disc_callbacks,
246+
parameter_dependencies, metadata, gui_metadata, checks = checks)
224247
end
225248

226249
##### MTK dispatches for JumpSystems #####
@@ -494,7 +517,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
494517
if has_equations(sys)
495518
osys = ODESystem(equations(sys).x[4], get_iv(sys), unknowns(sys), parameters(sys);
496519
observed = observed(sys), name = nameof(sys), description = description(sys),
497-
systems = get_systems(sys), defaults = defaults(sys),
520+
systems = get_systems(sys), defaults = defaults(sys), guesses = guesses(sys),
498521
parameter_dependencies = parameter_dependencies(sys),
499522
metadata = get_metadata(sys), gui_metadata = get_gui_metadata(sys))
500523
osys = complete(osys)

0 commit comments

Comments
 (0)