Skip to content

Commit 2c24809

Browse files
committed
cleanup JumpSystem constructor to match ODESystem better
1 parent f19a5bf commit 2c24809

File tree

1 file changed

+57
-24
lines changed

1 file changed

+57
-24
lines changed

src/systems/jumps/jumpsystem.jl

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem
9696
"""
9797
discrete_events::Vector{SymbolicDiscreteCallback}
9898
"""
99+
A `Vector{SymbolicContinuousCallback}` that model events.
100+
The integrator will use root finding to guarantee that it steps at each zero crossing.
101+
"""
102+
continuous_events::Vector{SymbolicContinuousCallback}
103+
"""
99104
Topologically sorted parameter dependency equations, where all symbols are parameters and
100105
the LHS is a single parameter.
101106
"""
@@ -160,13 +165,31 @@ function JumpSystem(eqs, iv, unknowns, ps;
160165
metadata = nothing,
161166
gui_metadata = nothing,
162167
kwargs...)
168+
169+
# variable processing, similar to ODESystem
163170
name === nothing &&
164171
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
165-
eqs = scalarize.(eqs)
166-
sysnames = nameof.(systems)
167-
if length(unique(sysnames)) != length(sysnames)
168-
throw(ArgumentError("System names must be unique."))
172+
iv′ = value(iv)
173+
us′ = value.(unknowns)
174+
ps′ = value.(ps)
175+
parameter_dependencies, ps = process_parameter_dependencies(parameter_dependencies, ps′)
176+
if !(isempty(default_u0) && isempty(default_p))
177+
Base.depwarn(
178+
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
179+
:JumpSystem, force = true)
169180
end
181+
defaults = todict(defaults)
182+
var_to_name = Dict()
183+
process_variables!(var_to_name, defaults, us′)
184+
process_variables!(var_to_name, defaults, ps′)
185+
process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies])
186+
process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies])
187+
defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults)
188+
if value(v) !== nothing)
189+
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
190+
191+
# equation processing
192+
eqs = scalarize.(eqs)
170193
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
171194
for eq in eqs
172195
if eq isa MassActionJump
@@ -179,30 +202,42 @@ function JumpSystem(eqs, iv, unknowns, ps;
179202
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, or VariableRateJumps.")
180203
end
181204
end
182-
if !(isempty(default_u0) && isempty(default_p))
183-
Base.depwarn(
184-
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
185-
:JumpSystem, force = true)
205+
206+
sysnames = nameof.(systems)
207+
if length(unique(sysnames)) != length(sysnames)
208+
throw(ArgumentError("System names must be unique."))
186209
end
187-
defaults = todict(defaults)
188-
defaults = Dict(value(k) => value(v)
189-
for (k, v) in pairs(defaults) if value(v) !== nothing)
190210

191-
unknowns, ps = value.(unknowns), value.(ps)
192-
var_to_name = Dict()
193-
process_variables!(var_to_name, defaults, unknowns)
194-
process_variables!(var_to_name, defaults, ps)
195-
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
196211
(continuous_events === nothing) ||
197212
error("JumpSystems currently only support discrete events.")
198213
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
199-
parameter_dependencies, ps = process_parameter_dependencies(parameter_dependencies, ps)
214+
200215
JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
201-
ap, value(iv), unknowns, ps, var_to_name, observed, name, description, systems,
216+
ap, iv′, us′, ps, var_to_name, observed, name, description, systems,
202217
defaults, connector_type, disc_callbacks, parameter_dependencies,
203218
metadata, gui_metadata, checks = checks)
204219
end
205220

221+
##### MTK dispatches for JumpSystems #####
222+
function collect_vars!(unknowns, parameters, j::MassActionJump, iv; depth = 0,
223+
op = Differential)
224+
for field in (j.scaled_rates, j.reactant_stoch, j.net_stoch)
225+
collect_vars!(unknowns, parameters, field, iv; depth, op)
226+
end
227+
return nothing
228+
end
229+
230+
function collect_vars!(unknowns, parameters, j::Union{ConstantRateJump,VariableRateJump},
231+
iv; depth = 0, op = Differential)
232+
collect_vars!(unknowns, parameters, j.condition, iv; depth, op)
233+
for eq in j.affect
234+
(eq isa Equation) && collect_vars!(unknowns, parameters, eq, iv; depth, op)
235+
end
236+
return nothing
237+
end
238+
239+
##########################################
240+
206241
has_massactionjumps(js::JumpSystem) = !isempty(equations(js).x[1])
207242
has_constantratejumps(js::JumpSystem) = !isempty(equations(js).x[2])
208243
has_variableratejumps(js::JumpSystem) = !isempty(equations(js).x[3])
@@ -240,9 +275,8 @@ function assemble_vrj(
240275

241276
outputvars = (value(affect.lhs) for affect in vrj.affect!)
242277
outputidxs = [unknowntoid[var] for var in outputvars]
243-
affect = eval_or_rgf(
244-
generate_affect_function(js, vrj.affect!,
245-
outputidxs); eval_expression, eval_module)
278+
affect = eval_or_rgf(generate_affect_function(js, vrj.affect!, outputidxs);
279+
eval_expression, eval_module)
246280
VariableRateJump(rate, affect)
247281
end
248282

@@ -269,9 +303,8 @@ function assemble_crj(
269303

270304
outputvars = (value(affect.lhs) for affect in crj.affect!)
271305
outputidxs = [unknowntoid[var] for var in outputvars]
272-
affect = eval_or_rgf(
273-
generate_affect_function(js, crj.affect!,
274-
outputidxs); eval_expression, eval_module)
306+
affect = eval_or_rgf(generate_affect_function(js, crj.affect!, outputidxs);
307+
eval_expression, eval_module)
275308
ConstantRateJump(rate, affect)
276309
end
277310

0 commit comments

Comments
 (0)