Skip to content

Commit bf9fec8

Browse files
Merge pull request #3180 from isaacsas/add_odes_to_jumps
JumpSystem cleanup
2 parents 20b29de + e6ebcb9 commit bf9fec8

File tree

3 files changed

+149
-31
lines changed

3 files changed

+149
-31
lines changed

src/systems/abstractsystem.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -923,15 +923,14 @@ One property to note is that if a system is complete, the system will no longer
923923
namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`.
924924
"""
925925
function complete(sys::AbstractSystem; split = true, flatten = true)
926-
if !(sys isa JumpSystem)
927-
newunknowns = OrderedSet()
928-
newparams = OrderedSet()
929-
iv = has_iv(sys) ? get_iv(sys) : nothing
930-
collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1)
931-
# don't update unknowns to not disturb `structural_simplify` order
932-
# `GlobalScope`d unknowns will be picked up and added there
933-
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))
934-
end
926+
newunknowns = OrderedSet()
927+
newparams = OrderedSet()
928+
iv = has_iv(sys) ? get_iv(sys) : nothing
929+
collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1)
930+
# don't update unknowns to not disturb `structural_simplify` order
931+
# `GlobalScope`d unknowns will be picked up and added there
932+
@set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams)))
933+
935934
if flatten
936935
eqs = equations(sys)
937936
if eqs isa AbstractArray && eltype(eqs) <: Equation

src/systems/jumps/jumpsystem.jl

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,40 @@ function JumpSystem(eqs, iv, unknowns, ps;
160160
metadata = nothing,
161161
gui_metadata = nothing,
162162
kwargs...)
163+
164+
# variable processing, similar to ODESystem
163165
name === nothing &&
164166
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
165-
eqs = scalarize.(eqs)
167+
iv′ = value(iv)
168+
us′ = value.(unknowns)
169+
ps′ = value.(ps)
170+
parameter_dependencies, ps′ = process_parameter_dependencies(
171+
parameter_dependencies, ps′)
172+
if !(isempty(default_u0) && isempty(default_p))
173+
Base.depwarn(
174+
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
175+
:JumpSystem, force = true)
176+
end
177+
defaults = Dict{Any, Any}(todict(defaults))
178+
var_to_name = Dict()
179+
process_variables!(var_to_name, defaults, us′)
180+
process_variables!(var_to_name, defaults, ps′)
181+
process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies])
182+
process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies])
183+
#! format: off
184+
defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults) if value(v) !== nothing)
185+
#! format: on
186+
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
187+
166188
sysnames = nameof.(systems)
167189
if length(unique(sysnames)) != length(sysnames)
168190
throw(ArgumentError("System names must be unique."))
169191
end
192+
193+
# equation processing
194+
# this and the treatment of continuous events are the only part
195+
# unique to JumpSystems
196+
eqs = scalarize.(eqs)
170197
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
171198
for eq in eqs
172199
if eq isa MassActionJump
@@ -179,30 +206,42 @@ function JumpSystem(eqs, iv, unknowns, ps;
179206
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, or VariableRateJumps.")
180207
end
181208
end
182-
if !(isempty(default_u0) && isempty(default_p))
183-
Base.depwarn(
184-
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
185-
:JumpSystem, force = true)
186-
end
187-
defaults = todict(defaults)
188-
defaults = Dict(value(k) => value(v)
189-
for (k, v) in pairs(defaults) if value(v) !== nothing)
190209

191-
unknowns, ps = value.(unknowns), value.(ps)
192-
var_to_name = Dict()
193-
process_variables!(var_to_name, defaults, unknowns)
194-
process_variables!(var_to_name, defaults, ps)
195-
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
196210
(continuous_events === nothing) ||
197211
error("JumpSystems currently only support discrete events.")
198212
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
199-
parameter_dependencies, ps = process_parameter_dependencies(parameter_dependencies, ps)
213+
200214
JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
201-
ap, value(iv), unknowns, ps, var_to_name, observed, name, description, systems,
215+
ap, iv′, us′, ps, var_to_name, observed, name, description, systems,
202216
defaults, connector_type, disc_callbacks, parameter_dependencies,
203217
metadata, gui_metadata, checks = checks)
204218
end
205219

220+
##### MTK dispatches for JumpSystems #####
221+
eqtype_supports_collect_vars(j::MassActionJump) = true
222+
function collect_vars!(unknowns, parameters, j::MassActionJump, iv; depth = 0,
223+
op = Differential)
224+
collect_vars!(unknowns, parameters, j.scaled_rates, iv; depth, op)
225+
for field in (j.reactant_stoch, j.net_stoch)
226+
for el in field
227+
collect_vars!(unknowns, parameters, el, iv; depth, op)
228+
end
229+
end
230+
return nothing
231+
end
232+
233+
eqtype_supports_collect_vars(j::Union{ConstantRateJump, VariableRateJump}) = true
234+
function collect_vars!(unknowns, parameters, j::Union{ConstantRateJump, VariableRateJump},
235+
iv; depth = 0, op = Differential)
236+
collect_vars!(unknowns, parameters, j.rate, iv; depth, op)
237+
for eq in j.affect!
238+
(eq isa Equation) && collect_vars!(unknowns, parameters, eq, iv; depth, op)
239+
end
240+
return nothing
241+
end
242+
243+
##########################################
244+
206245
has_massactionjumps(js::JumpSystem) = !isempty(equations(js).x[1])
207246
has_constantratejumps(js::JumpSystem) = !isempty(equations(js).x[2])
208247
has_variableratejumps(js::JumpSystem) = !isempty(equations(js).x[3])
@@ -240,9 +279,8 @@ function assemble_vrj(
240279

241280
outputvars = (value(affect.lhs) for affect in vrj.affect!)
242281
outputidxs = [unknowntoid[var] for var in outputvars]
243-
affect = eval_or_rgf(
244-
generate_affect_function(js, vrj.affect!,
245-
outputidxs); eval_expression, eval_module)
282+
affect = eval_or_rgf(generate_affect_function(js, vrj.affect!, outputidxs);
283+
eval_expression, eval_module)
246284
VariableRateJump(rate, affect)
247285
end
248286

@@ -269,9 +307,8 @@ function assemble_crj(
269307

270308
outputvars = (value(affect.lhs) for affect in crj.affect!)
271309
outputidxs = [unknowntoid[var] for var in outputvars]
272-
affect = eval_or_rgf(
273-
generate_affect_function(js, crj.affect!,
274-
outputidxs); eval_expression, eval_module)
310+
affect = eval_or_rgf(generate_affect_function(js, crj.affect!, outputidxs);
311+
eval_expression, eval_module)
275312
ConstantRateJump(rate, affect)
276313
end
277314

test/jumpsystem.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,85 @@ let
340340

341341
@test all(abs.(cmean .- cmean2) .<= 0.05 .* cmean)
342342
end
343+
344+
# collect_vars! tests for jumps
345+
let
346+
@variables x1(t) x2(t) x3(t) x4(t) x5(t)
347+
@parameters p1 p2 p3 p4 p5
348+
j1 = ConstantRateJump(p1, [x1 ~ x1 + 1])
349+
j2 = MassActionJump(p2, [x2 => 1], [x3 => -1])
350+
j3 = VariableRateJump(p3, [x3 ~ x3 + 1, x4 ~ x4 + 1])
351+
j4 = MassActionJump(p4 * p5, [x1 => 1, x5 => 1], [x1 => -1, x5 => -1, x2 => 1])
352+
us = Set()
353+
ps = Set()
354+
iv = t
355+
356+
MT.collect_vars!(us, ps, j1, iv)
357+
@test issetequal(us, [x1])
358+
@test issetequal(ps, [p1])
359+
360+
empty!(us)
361+
empty!(ps)
362+
MT.collect_vars!(us, ps, j2, iv)
363+
@test issetequal(us, [x2, x3])
364+
@test issetequal(ps, [p2])
365+
366+
empty!(us)
367+
empty!(ps)
368+
MT.collect_vars!(us, ps, j3, iv)
369+
@test issetequal(us, [x3, x4])
370+
@test issetequal(ps, [p3])
371+
372+
empty!(us)
373+
empty!(ps)
374+
MT.collect_vars!(us, ps, j4, iv)
375+
@test issetequal(us, [x1, x5, x2])
376+
@test issetequal(ps, [p4, p5])
377+
end
378+
379+
# scoping tests
380+
let
381+
@variables x1(t) x2(t) x3(t) x4(t) x5(t)
382+
x2 = ParentScope(x2)
383+
x3 = ParentScope(ParentScope(x3))
384+
x4 = DelayParentScope(x4, 2)
385+
x5 = GlobalScope(x5)
386+
@parameters p1 p2 p3 p4 p5
387+
p2 = ParentScope(p2)
388+
p3 = ParentScope(ParentScope(p3))
389+
p4 = DelayParentScope(p4, 2)
390+
p5 = GlobalScope(p5)
391+
392+
j1 = ConstantRateJump(p1, [x1 ~ x1 + 1])
393+
j2 = MassActionJump(p2, [x2 => 1], [x3 => -1])
394+
j3 = VariableRateJump(p3, [x3 ~ x3 + 1, x4 ~ x4 + 1])
395+
j4 = MassActionJump(p4 * p5, [x1 => 1, x5 => 1], [x1 => -1, x5 => -1, x2 => 1])
396+
@named js = JumpSystem([j1, j2, j3, j4], t, [x1, x2, x3, x4, x5], [p1, p2, p3, p4, p5])
397+
398+
us = Set()
399+
ps = Set()
400+
iv = t
401+
MT.collect_scoped_vars!(us, ps, js, iv)
402+
@test issetequal(us, [x2])
403+
@test issetequal(ps, [p2])
404+
405+
empty!.((us, ps))
406+
MT.collect_scoped_vars!(us, ps, js, iv; depth = 0)
407+
@test issetequal(us, [x1])
408+
@test issetequal(ps, [p1])
409+
410+
empty!.((us, ps))
411+
MT.collect_scoped_vars!(us, ps, js, iv; depth = 1)
412+
@test issetequal(us, [x2])
413+
@test issetequal(ps, [p2])
414+
415+
empty!.((us, ps))
416+
MT.collect_scoped_vars!(us, ps, js, iv; depth = 2)
417+
@test issetequal(us, [x3, x4])
418+
@test issetequal(ps, [p3, p4])
419+
420+
empty!.((us, ps))
421+
MT.collect_scoped_vars!(us, ps, js, iv; depth = -1)
422+
@test issetequal(us, [x5])
423+
@test issetequal(ps, [p5])
424+
end

0 commit comments

Comments
 (0)