Skip to content

Commit fa5a235

Browse files
committed
Finish events and add test
1 parent c9d98cc commit fa5a235

File tree

4 files changed

+407
-150
lines changed

4 files changed

+407
-150
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ julia = "1.9"
6060

6161
[extras]
6262
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
63+
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
6364
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
6465
Graphviz_jll = "3c863552-8265-54e4-a6dc-903eb78fde85"
6566
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
@@ -80,4 +81,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8081
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
8182

8283
[targets]
83-
test = ["BifurcationKit", "DomainSets", "Graphviz_jll", "HomotopyContinuation", "NonlinearSolve", "OrdinaryDiffEq", "Plots", "Random", "SafeTestsets", "SciMLBase", "SciMLNLSolve", "StableRNGs", "Statistics", "SteadyStateDiffEq", "StochasticDiffEq", "StructuralIdentifiability", "Test", "Unitful"]
84+
test = ["BifurcationKit", "DiffEqCallbacks", "DomainSets", "Graphviz_jll", "HomotopyContinuation", "NonlinearSolve", "OrdinaryDiffEq", "Plots", "Random", "SafeTestsets", "SciMLBase", "SciMLNLSolve", "StableRNGs", "Statistics", "SteadyStateDiffEq", "StochasticDiffEq", "StructuralIdentifiability", "Test", "Unitful"]

src/reactionsystem.jl

Lines changed: 102 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -628,70 +628,77 @@ function ReactionSystem(eqs, iv, unknowns, ps;
628628
continuous_events = nothing,
629629
discrete_events = nothing,
630630
metadata = nothing)
631-
632-
name === nothing &&
631+
632+
# Error checks
633+
if name === nothing &&
633634
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
635+
end
634636
sysnames = nameof.(systems)
635-
(length(unique(sysnames)) == length(sysnames)) ||
636-
throw(ArgumentError("System names must be unique."))
637+
(length(unique(sysnames)) == length(sysnames)) || throw(ArgumentError("System names must be unique."))
637638

639+
# Handle defaults values provided via optional arguments.
638640
if !(isempty(default_u0) && isempty(default_p))
639-
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
640-
:ReactionSystem, force = true)
641+
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :ReactionSystem, force = true)
641642
end
642643
defaults = MT.todict(defaults)
643644
defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults))
644645

646+
# Extracts independent variables (iv and sivs), dependent variables (species and variables)
647+
# and parameters. Sorts so that species comes before variables in unknowns vector.
645648
iv′ = value(iv)
646649
sivs′ = if spatial_ivs === nothing
647650
Vector{typeof(iv′)}()
648651
else
649652
value.(MT.scalarize(spatial_ivs))
650653
end
651-
unknowns′ = sort!(value.(MT.scalarize(unknowns)), by = !isspecies) # species come first
654+
unknowns′ = sort!(value.(MT.scalarize(unknowns)), by = !isspecies)
652655
spcs = filter(isspecies, unknowns′)
653656
ps′ = value.(MT.scalarize(ps))
654657

658+
# Checks that no (by Catalyst) forbidden symbols are used.
655659
allsyms = Iterators.flatten((ps′, unknowns′))
656-
all(sym -> getname(sym) forbidden_symbols_error, allsyms) ||
660+
if !all(sym -> getname(sym) forbidden_symbols_error, allsyms)
657661
error("Catalyst reserves the symbols $forbidden_symbols_error for internal use. Please do not use these symbols as parameters or unknowns/species.")
662+
end
658663

659-
# sort Reactions before Equations
664+
# Handles reactions and equations. Sorts so that reactions are before equaions in the equations vector.
660665
eqs′ = CatalystEqType[eq for eq in eqs]
661666
sort!(eqs′; by = eqsortby)
662667
rxs = Reaction[rx for rx in eqs if rx isa Reaction]
663668

669+
# Additional error checks.
664670
if any(MT.isparameter, unknowns′)
665671
psts = filter(MT.isparameter, unknowns′)
666672
throw(ArgumentError("Found one or more parameters among the unknowns; this is not allowed. Move: $psts to be parameters."))
667673
end
668-
669674
if any(isconstant, unknowns′)
670675
csts = filter(isconstant, unknowns′)
671676
throw(ArgumentError("Found one or more constant species among the unknowns; this is not allowed. Move: $csts to be parameters."))
672677
end
673-
674-
# if there are BC species, check they are balanced in their reactions
678+
# If there are BC species, check they are balanced in their reactions.
675679
if balanced_bc_check && any(isbc, unknowns′)
676680
for rx in eqs
677-
if rx isa Reaction
678-
isbcbalanced(rx) ||
679-
throw(ErrorException("BC species must be balanced, appearing as a substrate and product with the same stoichiometry. Please fix reaction: $rx"))
681+
if (rx isa Reaction) && !isbcbalanced(rx)
682+
throw(ErrorException("BC species must be balanced, appearing as a substrate and product with the same stoichiometry. Please fix reaction: $rx"))
680683
end
681684
end
682685
end
683686

687+
# Adds all unknowns/parameters to the `var_to_name` vector.
688+
# Adds their (potential) default values to the defaults vector.
684689
var_to_name = Dict()
685690
MT.process_variables!(var_to_name, defaults, unknowns′)
686691
MT.process_variables!(var_to_name, defaults, ps′)
687692
MT.collect_var_to_name!(var_to_name, eq.lhs for eq in observed)
688693

689-
nps = if networkproperties === nothing
690-
NetworkProperties{Int, get_speciestype(iv′, unknowns′, systems)}()
694+
# Computes network properties.
695+
if networkproperties === nothing
696+
nps = NetworkProperties{Int, get_speciestype(iv′, unknowns′, systems)}()
691697
else
692-
networkproperties
698+
nps = networkproperties
693699
end
694700

701+
# Creates the continious and discrete callbacks.
695702
ccallbacks = MT.SymbolicContinuousCallbacks(continuous_events)
696703
dcallbacks = MT.SymbolicDiscreteCallbacks(discrete_events)
697704

@@ -705,77 +712,125 @@ function ReactionSystem(rxs::Vector, iv = Catalyst.DEFAULT_IV; kwargs...)
705712
end
706713

707714
# search the symbolic expression for parameters or unknowns
708-
# and save in ps and sts respectively. vars is used to cache results
709-
function findvars!(ps, sts, exprtosearch, ivs, vars)
715+
# and save in ps and us respectively. vars is used to cache results
716+
function findvars!(ps, us, exprtosearch, ivs, vars)
710717
MT.get_variables!(vars, exprtosearch)
711718
for var in vars
712719
(var ivs) && continue
713720
if MT.isparameter(var)
714721
push!(ps, var)
715722
else
716-
push!(sts, var)
723+
push!(us, var)
717724
end
718725
end
719726
empty!(vars)
720727
end
721728

722-
# Only used internally by the @reaction_network macro. Permits giving an initial order to
723-
# the parameters, and then adds additional ones found in the reaction. Name could be
724-
# changed.
725-
function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, sts_in, ps_in;
726-
spatial_ivs = nothing, kwargs...)
729+
# Called internally (whether DSL-based or programmtic model creation is used).
730+
# Creates a sorted reactions + equations vector, also ensuring reaction is first in this vector.
731+
# Extracts potential species, variables, and parameters from the input (if not provided as part of
732+
# the model creation) and creates the corresponding vectors.
733+
# While species are ordered before variables in the unknowns vector, this ordering is not imposed here,
734+
# but carried out at a later stage.
735+
function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in; spatial_ivs = nothing,
736+
continuous_events = [], discrete_events = [], kwargs...)
737+
738+
# Creates a combined iv vector (iv and sivs). This is used later in the function (so that
739+
# independent variables can be exluded when encountered quantities are added to `us` and `ps`).
727740
t = value(iv)
728741
ivs = Set([t])
729742
if (spatial_ivs !== nothing)
730743
for siv in (MT.scalarize(spatial_ivs))
731744
push!(ivs, value(siv))
732745
end
733746
end
734-
sts = OrderedSet{eltype(sts_in)}(sts_in)
747+
748+
# Initialises the new unknowns and parameter vectors.
749+
# Preallocates the `vars` set, which is used by `findvars!`
750+
us = OrderedSet{eltype(us_in)}(us_in)
735751
ps = OrderedSet{eltype(ps_in)}(ps_in)
736752
vars = OrderedSet()
737753

754+
# Extracts the reactions and equations from the combined reactions + equations input vector.
738755
all(eq -> eq isa Union{Reaction, Equation}, rxs_and_eqs)
739756
rxs = Reaction[eq for eq in rxs_and_eqs if eq isa Reaction]
740757
eqs = Equation[eq for eq in rxs_and_eqs if eq isa Equation]
741758

742-
# add species / parameters that are substrates / products first
743-
for rx in rxs, reactants in (rx.substrates, rx.products)
744-
for spec in reactants
745-
MT.isparameter(spec) ? push!(ps, spec) : push!(sts, spec)
746-
end
747-
end
748-
759+
# Loops through all reactions, adding encountered quantities to the unknown and parameter vectors.
749760
for rx in rxs
750-
findvars!(ps, sts, rx.rate, ivs, vars)
751-
for s in rx.substoich
752-
(s isa Symbolic) && findvars!(ps, sts, s, ivs, vars)
761+
# Loops through all reaction substrates and products, extracting these.
762+
for reactants in (rx.substrates, rx.products), spec in reactants
763+
MT.isparameter(spec) ? push!(ps, spec) : push!(us, spec)
753764
end
754-
for p in rx.prodstoich
755-
(p isa Symbolic) && findvars!(ps, sts, p, ivs, vars)
765+
766+
# Adds all quantitites encountered in the reaction's rate.
767+
findvars!(ps, us, rx.rate, ivs, vars)
768+
769+
# Extracts all quantitites encountered within stoichiometries.
770+
for stoichiometry in (rx.substoich, rx.prodstoich), sym in stoichiometry
771+
(sym isa Symbolic) && findvars!(ps, us, sym, ivs, vars)
756772
end
757-
end
758773

759-
stsv = collect(sts)
760-
psv = collect(ps)
774+
# Will appear here: add stuff from nosie scaling.
775+
end
761776

777+
# Extracts any species, variables, and parameters that occur in (non-reaction) equations.
778+
# Creates the new reactions + equations vector, `fulleqs` (sorted reactions first, equations next).
762779
if !isempty(eqs)
763780
osys = ODESystem(eqs, iv; name = gensym())
764781
fulleqs = CatalystEqType[rxs; equations(osys)]
765-
union!(stsv, unknowns(osys))
766-
union!(psv, parameters(osys))
782+
union!(us, unknowns(osys))
783+
union!(ps, parameters(osys))
767784
else
768785
fulleqs = rxs
769-
end
786+
end
770787

771-
ReactionSystem(fulleqs, t, stsv, psv; spatial_ivs, kwargs...)
788+
# Loops through all events, adding encountered quantities to the unknwon and parameter vectors.
789+
find_event_vars!(ps, us, continuous_events, ivs, vars)
790+
find_event_vars!(ps, us, discrete_events, ivs, vars)
791+
792+
# Converts the found unknowns and parameters to vectors.
793+
usv = collect(us)
794+
psv = collect(ps)
795+
796+
# Passes the processed input into the next `ReactionSystem` call.
797+
ReactionSystem(fulleqs, t, usv, psv; spatial_ivs, continuous_events, discrete_events, kwargs...)
772798
end
773799

774800
function ReactionSystem(iv; kwargs...)
775801
ReactionSystem(Reaction[], iv, [], []; kwargs...)
776802
end
777803

778-
804+
# Loops through all events in an supplied event vector, adding all unknowns and parameters found in
805+
# its condition and affect functions to their respective vectors (`ps` and `us`).
806+
function find_event_vars!(ps, us, events::Vector, ivs, vars)
807+
foreach(event -> find_event_vars!(ps, us, event, ivs, vars), events)
808+
end
809+
# For a single event, adds quantitites from its condition and affect expression(s) to `ps` and `us`.
810+
function find_event_vars!(ps, us, event, ivs, vars)
811+
conds, affects = event
812+
# For discrete events, the condition can be a single value (for periodic events).
813+
# If not, it is a vector of conditions and we must check each.
814+
if conds isa Vector
815+
for cond in conds
816+
# For continious events the conditions are equations (with lhs and rhs).
817+
# For discrete events, they are single expressions.
818+
if cond isa Equation
819+
findvars!(ps, us, cond.lhs, ivs, vars)
820+
findvars!(ps, us, cond.rhs, ivs, vars)
821+
else
822+
findvars!(ps, us, cond, ivs, vars)
823+
end
824+
end
825+
else
826+
findvars!(ps, us, conds, ivs, vars)
827+
end
828+
# The affects is always a vector of equations. Here, we handle the lhs and rhs separately.
829+
for affect in affects
830+
findvars!(ps, us, affect.lhs, ivs, vars)
831+
findvars!(ps, us, affect.rhs, ivs, vars)
832+
end
833+
end
779834
"""
780835
remake_ReactionSystem_internal(rs::ReactionSystem;
781836
default_reaction_metadata::Vector{Pair{Symbol, T}} = Vector{Pair{Symbol, Any}}()) where {T}

test/dsl/dsl_options.jl

Lines changed: 0 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -860,93 +860,4 @@ let
860860
@equations X ~ p - S
861861
(P,D), 0 <--> S
862862
end
863-
end
864-
865-
### Events ###
866-
867-
# Compares models with complicated events that are created programmatically/with the DSL.
868-
# Checks that simulations are correct.
869-
# Checks that various simulation inputs works.
870-
# Checks continuous, discrete, preset time, and periodic events.
871-
# Tests event affecting non-species components.
872-
873-
let
874-
# Creates model via DSL.
875-
rn_dsl = @reaction_network rn begin
876-
@parameters thres=1.0 dY_up
877-
@variables Z(t)
878-
@continuous_events begin
879-
[t - 2.5] => [p ~ p + 0.2]
880-
[X - thres, Y - X] => [X ~ X - 0.5, Z ~ Z + 0.1]
881-
end
882-
@discrete_events begin
883-
2.0 => [dX ~ dX + 0.1, dY ~ dY + dY_up]
884-
[1.0, 5.0] => [p ~ p - 0.1]
885-
(Z > Y) => [Z ~ Z - 0.1]
886-
end
887-
888-
(p, dX), 0 <--> X
889-
(p, dY), 0 <--> Y
890-
end
891-
892-
# Creates model programmatically.
893-
@variables t Z(t)
894-
@species X(t) Y(t)
895-
@parameters p dX dY thres=1.0 dY_up
896-
rxs = [
897-
Reaction(p, nothing, [X], nothing, [1])
898-
Reaction(dX, [X], nothing, [1], nothing)
899-
Reaction(p, nothing, [Y], nothing, [1])
900-
Reaction(dY, [Y], nothing, [1], nothing)
901-
]
902-
continuous_events = [
903-
t - 2.5 => p ~ p + 0.2
904-
[X - thres, Y - X] => [X ~ X - 0.5, Z ~ Z + 0.1]
905-
]
906-
discrete_events = [
907-
2.0 => [dX ~ dX + 0.1, dY ~ dY + dY_up]
908-
[1.0, 5.0] => [p ~ p - 0.1]
909-
(Z > Y) => [Z ~ Z - 0.1]
910-
]
911-
rn_prog = ReactionSystem(rxs, t; continuous_events, discrete_events, name=:rn)
912-
913-
# Tests that approaches yield identical results.
914-
@test isequal(rn_dsl, rn_prog)
915-
916-
u0 = [X => 1.0, Y => 0.5, Z => 0.25]
917-
tspan = (0.0, 20.0)
918-
ps = [p => 1.0, dX => 0.5, dY => 0.5, dY_up => 0.1]
919-
920-
sol_dsl = solve(ODEProblem(rn_dsl, u0, tspan, ps), Tsit5())
921-
sol_prog = solve(ODEProblem(rn_prog, u0, tspan, ps), Tsit5())
922-
@test sol_dsl == sol_prog
923-
end
924-
925-
# Compares DLS events to those given as callbacks.
926-
# Checks that events works when given to SDEs.
927-
let
928-
# Creates models.
929-
rn = @reaction_network begin
930-
(p, d), 0 <--> X
931-
end
932-
rn_events = @reaction_network begin
933-
@discrete_events begin
934-
[5.0, 10.0] => [X ~ X + 100.0]
935-
end
936-
@continuous_events begin
937-
[X ~ 90.0] => [X ~ X + 10.0]
938-
end
939-
(p, d), 0 <--> X
940-
end
941-
cb_disc = ModelingToolkit.PresetTimeCallback([5.0, 10.0], int -> (int[:X] += 100.0))
942-
cb_cont = ContinuousCallback((u, t, int) -> (u[1] - 90.0), int -> (int[:X] += 10.0))
943-
944-
# Simulates models.
945-
u0 = [:X => 100.0]
946-
tspan = (0.0, 50.0)
947-
ps = [:p => 100.0, :d => 1.0]
948-
sol = solve(SDEProblem(rn, u0, tspan, ps), ImplicitEM(); seed, callback = CallbackSet(cb_disc, cb_cont))
949-
sol_events = solve(SDEProblem(rn_events, u0, tspan, ps), ImplicitEM(); seed)
950-
951-
@test sol == sol_events
952863
end

0 commit comments

Comments
 (0)