diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index c53c900d27..e565bb27e1 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -923,15 +923,14 @@ One property to note is that if a system is complete, the system will no longer namespace its subsystems or variables, i.e. `isequal(complete(sys).v.i, v.i)`. """ function complete(sys::AbstractSystem; split = true, flatten = true) - if !(sys isa JumpSystem) - newunknowns = OrderedSet() - newparams = OrderedSet() - iv = has_iv(sys) ? get_iv(sys) : nothing - collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1) - # don't update unknowns to not disturb `structural_simplify` order - # `GlobalScope`d unknowns will be picked up and added there - @set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams))) - end + newunknowns = OrderedSet() + newparams = OrderedSet() + iv = has_iv(sys) ? get_iv(sys) : nothing + collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1) + # don't update unknowns to not disturb `structural_simplify` order + # `GlobalScope`d unknowns will be picked up and added there + @set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams))) + if flatten eqs = equations(sys) if eqs isa AbstractArray && eltype(eqs) <: Equation diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index 6409609a10..9fa073b4b0 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -160,13 +160,40 @@ function JumpSystem(eqs, iv, unknowns, ps; metadata = nothing, gui_metadata = nothing, kwargs...) + + # variable processing, similar to ODESystem name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) - eqs = scalarize.(eqs) + iv′ = value(iv) + us′ = value.(unknowns) + ps′ = value.(ps) + parameter_dependencies, ps′ = process_parameter_dependencies( + parameter_dependencies, ps′) + if !(isempty(default_u0) && isempty(default_p)) + Base.depwarn( + "`default_u0` and `default_p` are deprecated. Use `defaults` instead.", + :JumpSystem, force = true) + end + defaults = Dict{Any, Any}(todict(defaults)) + var_to_name = Dict() + process_variables!(var_to_name, defaults, us′) + process_variables!(var_to_name, defaults, ps′) + process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies]) + process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies]) + #! format: off + defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults) if value(v) !== nothing) + #! format: on + isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed)) + sysnames = nameof.(systems) if length(unique(sysnames)) != length(sysnames) throw(ArgumentError("System names must be unique.")) end + + # equation processing + # this and the treatment of continuous events are the only part + # unique to JumpSystems + eqs = scalarize.(eqs) ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[]) for eq in eqs if eq isa MassActionJump @@ -179,30 +206,42 @@ function JumpSystem(eqs, iv, unknowns, ps; error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, or VariableRateJumps.") end end - if !(isempty(default_u0) && isempty(default_p)) - Base.depwarn( - "`default_u0` and `default_p` are deprecated. Use `defaults` instead.", - :JumpSystem, force = true) - end - defaults = todict(defaults) - defaults = Dict(value(k) => value(v) - for (k, v) in pairs(defaults) if value(v) !== nothing) - unknowns, ps = value.(unknowns), value.(ps) - var_to_name = Dict() - process_variables!(var_to_name, defaults, unknowns) - process_variables!(var_to_name, defaults, ps) - isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed)) (continuous_events === nothing) || error("JumpSystems currently only support discrete events.") disc_callbacks = SymbolicDiscreteCallbacks(discrete_events) - parameter_dependencies, ps = process_parameter_dependencies(parameter_dependencies, ps) + JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), - ap, value(iv), unknowns, ps, var_to_name, observed, name, description, systems, + ap, iv′, us′, ps′, var_to_name, observed, name, description, systems, defaults, connector_type, disc_callbacks, parameter_dependencies, metadata, gui_metadata, checks = checks) end +##### MTK dispatches for JumpSystems ##### +eqtype_supports_collect_vars(j::MassActionJump) = true +function collect_vars!(unknowns, parameters, j::MassActionJump, iv; depth = 0, + op = Differential) + collect_vars!(unknowns, parameters, j.scaled_rates, iv; depth, op) + for field in (j.reactant_stoch, j.net_stoch) + for el in field + collect_vars!(unknowns, parameters, el, iv; depth, op) + end + end + return nothing +end + +eqtype_supports_collect_vars(j::Union{ConstantRateJump, VariableRateJump}) = true +function collect_vars!(unknowns, parameters, j::Union{ConstantRateJump, VariableRateJump}, + iv; depth = 0, op = Differential) + collect_vars!(unknowns, parameters, j.rate, iv; depth, op) + for eq in j.affect! + (eq isa Equation) && collect_vars!(unknowns, parameters, eq, iv; depth, op) + end + return nothing +end + +########################################## + has_massactionjumps(js::JumpSystem) = !isempty(equations(js).x[1]) has_constantratejumps(js::JumpSystem) = !isempty(equations(js).x[2]) has_variableratejumps(js::JumpSystem) = !isempty(equations(js).x[3]) @@ -240,9 +279,8 @@ function assemble_vrj( outputvars = (value(affect.lhs) for affect in vrj.affect!) outputidxs = [unknowntoid[var] for var in outputvars] - affect = eval_or_rgf( - generate_affect_function(js, vrj.affect!, - outputidxs); eval_expression, eval_module) + affect = eval_or_rgf(generate_affect_function(js, vrj.affect!, outputidxs); + eval_expression, eval_module) VariableRateJump(rate, affect) end @@ -269,9 +307,8 @@ function assemble_crj( outputvars = (value(affect.lhs) for affect in crj.affect!) outputidxs = [unknowntoid[var] for var in outputvars] - affect = eval_or_rgf( - generate_affect_function(js, crj.affect!, - outputidxs); eval_expression, eval_module) + affect = eval_or_rgf(generate_affect_function(js, crj.affect!, outputidxs); + eval_expression, eval_module) ConstantRateJump(rate, affect) end diff --git a/test/jumpsystem.jl b/test/jumpsystem.jl index 24b9abbc8a..30cb636226 100644 --- a/test/jumpsystem.jl +++ b/test/jumpsystem.jl @@ -340,3 +340,85 @@ let @test all(abs.(cmean .- cmean2) .<= 0.05 .* cmean) end + +# collect_vars! tests for jumps +let + @variables x1(t) x2(t) x3(t) x4(t) x5(t) + @parameters p1 p2 p3 p4 p5 + j1 = ConstantRateJump(p1, [x1 ~ x1 + 1]) + j2 = MassActionJump(p2, [x2 => 1], [x3 => -1]) + j3 = VariableRateJump(p3, [x3 ~ x3 + 1, x4 ~ x4 + 1]) + j4 = MassActionJump(p4 * p5, [x1 => 1, x5 => 1], [x1 => -1, x5 => -1, x2 => 1]) + us = Set() + ps = Set() + iv = t + + MT.collect_vars!(us, ps, j1, iv) + @test issetequal(us, [x1]) + @test issetequal(ps, [p1]) + + empty!(us) + empty!(ps) + MT.collect_vars!(us, ps, j2, iv) + @test issetequal(us, [x2, x3]) + @test issetequal(ps, [p2]) + + empty!(us) + empty!(ps) + MT.collect_vars!(us, ps, j3, iv) + @test issetequal(us, [x3, x4]) + @test issetequal(ps, [p3]) + + empty!(us) + empty!(ps) + MT.collect_vars!(us, ps, j4, iv) + @test issetequal(us, [x1, x5, x2]) + @test issetequal(ps, [p4, p5]) +end + +# scoping tests +let + @variables x1(t) x2(t) x3(t) x4(t) x5(t) + x2 = ParentScope(x2) + x3 = ParentScope(ParentScope(x3)) + x4 = DelayParentScope(x4, 2) + x5 = GlobalScope(x5) + @parameters p1 p2 p3 p4 p5 + p2 = ParentScope(p2) + p3 = ParentScope(ParentScope(p3)) + p4 = DelayParentScope(p4, 2) + p5 = GlobalScope(p5) + + j1 = ConstantRateJump(p1, [x1 ~ x1 + 1]) + j2 = MassActionJump(p2, [x2 => 1], [x3 => -1]) + j3 = VariableRateJump(p3, [x3 ~ x3 + 1, x4 ~ x4 + 1]) + j4 = MassActionJump(p4 * p5, [x1 => 1, x5 => 1], [x1 => -1, x5 => -1, x2 => 1]) + @named js = JumpSystem([j1, j2, j3, j4], t, [x1, x2, x3, x4, x5], [p1, p2, p3, p4, p5]) + + us = Set() + ps = Set() + iv = t + MT.collect_scoped_vars!(us, ps, js, iv) + @test issetequal(us, [x2]) + @test issetequal(ps, [p2]) + + empty!.((us, ps)) + MT.collect_scoped_vars!(us, ps, js, iv; depth = 0) + @test issetequal(us, [x1]) + @test issetequal(ps, [p1]) + + empty!.((us, ps)) + MT.collect_scoped_vars!(us, ps, js, iv; depth = 1) + @test issetequal(us, [x2]) + @test issetequal(ps, [p2]) + + empty!.((us, ps)) + MT.collect_scoped_vars!(us, ps, js, iv; depth = 2) + @test issetequal(us, [x3, x4]) + @test issetequal(ps, [p3, p4]) + + empty!.((us, ps)) + MT.collect_scoped_vars!(us, ps, js, iv; depth = -1) + @test issetequal(us, [x5]) + @test issetequal(ps, [p5]) +end