Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 56 additions & 22 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,39 @@ function JumpSystem(eqs, iv, unknowns, ps;
metadata = nothing,
gui_metadata = nothing,
kwargs...)

# variable processing, similar to ODESystem
name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
eqs = scalarize.(eqs)
iv′ = value(iv)
us′ = value.(unknowns)
ps′ = value.(ps)
parameter_dependencies, ps′ = process_parameter_dependencies(
parameter_dependencies, ps′)
if !(isempty(default_u0) && isempty(default_p))
Base.depwarn(
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
:JumpSystem, force = true)
end
defaults = Dict{Any,Any}(todict(defaults))
var_to_name = Dict()
process_variables!(var_to_name, defaults, us′)
process_variables!(var_to_name, defaults, ps′)
process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies])
process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies])
defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults)
if value(v) !== nothing)
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))

sysnames = nameof.(systems)
if length(unique(sysnames)) != length(sysnames)
throw(ArgumentError("System names must be unique."))
end

# equation processing
# this and the treatment of continuous events are the only part
# unique to JumpSystems
eqs = scalarize.(eqs)
ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
for eq in eqs
if eq isa MassActionJump
Expand All @@ -179,30 +205,40 @@ function JumpSystem(eqs, iv, unknowns, ps;
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, or VariableRateJumps.")
end
end
if !(isempty(default_u0) && isempty(default_p))
Base.depwarn(
"`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
:JumpSystem, force = true)
end
defaults = todict(defaults)
defaults = Dict(value(k) => value(v)
for (k, v) in pairs(defaults) if value(v) !== nothing)

unknowns, ps = value.(unknowns), value.(ps)
var_to_name = Dict()
process_variables!(var_to_name, defaults, unknowns)
process_variables!(var_to_name, defaults, ps)
isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed))
(continuous_events === nothing) ||
error("JumpSystems currently only support discrete events.")
disc_callbacks = SymbolicDiscreteCallbacks(discrete_events)
parameter_dependencies, ps = process_parameter_dependencies(parameter_dependencies, ps)

JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
ap, value(iv), unknowns, ps, var_to_name, observed, name, description, systems,
ap, iv′, us′, ps, var_to_name, observed, name, description, systems,
defaults, connector_type, disc_callbacks, parameter_dependencies,
metadata, gui_metadata, checks = checks)
end

##### MTK dispatches for JumpSystems #####
function collect_vars!(unknowns, parameters, j::MassActionJump, iv; depth = 0,
op = Differential)
collect_vars!(unknowns, parameters, j.scaled_rates, iv; depth, op)
for field in (j.reactant_stoch, j.net_stoch)
for el in field
collect_vars!(unknowns, parameters, el, iv; depth, op)
end
end
return nothing
end

function collect_vars!(unknowns, parameters, j::Union{ConstantRateJump,VariableRateJump},
iv; depth = 0, op = Differential)
collect_vars!(unknowns, parameters, j.rate, iv; depth, op)
for eq in j.affect!
(eq isa Equation) && collect_vars!(unknowns, parameters, eq, iv; depth, op)
end
return nothing
end

##########################################

has_massactionjumps(js::JumpSystem) = !isempty(equations(js).x[1])
has_constantratejumps(js::JumpSystem) = !isempty(equations(js).x[2])
has_variableratejumps(js::JumpSystem) = !isempty(equations(js).x[3])
Expand Down Expand Up @@ -240,9 +276,8 @@ function assemble_vrj(

outputvars = (value(affect.lhs) for affect in vrj.affect!)
outputidxs = [unknowntoid[var] for var in outputvars]
affect = eval_or_rgf(
generate_affect_function(js, vrj.affect!,
outputidxs); eval_expression, eval_module)
affect = eval_or_rgf(generate_affect_function(js, vrj.affect!, outputidxs);
eval_expression, eval_module)
VariableRateJump(rate, affect)
end

Expand All @@ -269,9 +304,8 @@ function assemble_crj(

outputvars = (value(affect.lhs) for affect in crj.affect!)
outputidxs = [unknowntoid[var] for var in outputvars]
affect = eval_or_rgf(
generate_affect_function(js, crj.affect!,
outputidxs); eval_expression, eval_module)
affect = eval_or_rgf(generate_affect_function(js, crj.affect!, outputidxs);
eval_expression, eval_module)
ConstantRateJump(rate, affect)
end

Expand Down
36 changes: 36 additions & 0 deletions test/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,39 @@ let

@test all(abs.(cmean .- cmean2) .<= 0.05 .* cmean)
end


# collect_vars! tests for jumps
let
@variables x1(t) x2(t) x3(t) x4(t) x5(t)
@parameters p1 p2 p3 p4 p5
j1 = ConstantRateJump(p1, [x1 ~ x1 + 1])
j2 = MassActionJump(p2, [x2 => 1], [x3 => -1])
j3 = VariableRateJump(p3, [x3 ~ x3 + 1, x4 ~ x4 + 1])
j4 = MassActionJump(p4*p5, [x1 => 1, x5 => 1], [x1 => -1, x5 => -1, x2 => 1])
us = Set()
ps = Set()
iv = t

MT.collect_vars!(us, ps, j1, iv)
@test issetequal(us, [x1])
@test issetequal(ps, [p1])

empty!(us)
empty!(ps)
MT.collect_vars!(us, ps, j2, iv)
@test issetequal(us, [x2, x3])
@test issetequal(ps, [p2])

empty!(us)
empty!(ps)
MT.collect_vars!(us, ps, j3, iv)
@test issetequal(us, [x3, x4])
@test issetequal(ps, [p3])

empty!(us)
empty!(ps)
MT.collect_vars!(us, ps, j4, iv)
@test issetequal(us, [x1, x5, x2])
@test issetequal(ps, [p4, p5])
end
Loading