Skip to content

Commit 1e743a3

Browse files
committed
construct pdmps
1 parent 98cc3ed commit 1e743a3

File tree

1 file changed

+41
-18
lines changed

1 file changed

+41
-18
lines changed

src/reactionsystem_conversions.jl

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ end
4646
# including non-species variables.
4747
drop_dynamics(s) = isconstant(s) || isbc(s) || (!isspecies(s))
4848

49-
function assemble_oderhs(rs, ispcs; combinatoric_ratelaws = true, remove_conserved = false)
49+
function assemble_oderhs(rs, ispcs; combinatoric_ratelaws = true, remove_conserved = false,
50+
physical_scales = nothing)
5051
nps = get_networkproperties(rs)
5152
species_to_idx = Dict(x => i for (i, x) in enumerate(ispcs))
5253
rhsvec = Any[0 for _ in ispcs]
@@ -56,7 +57,11 @@ function assemble_oderhs(rs, ispcs; combinatoric_ratelaws = true, remove_conserv
5657
Dict()
5758
end
5859

59-
for rx in get_rxs(rs)
60+
for (rxidx,rx) in enumerate(get_rxs(rs))
61+
# check this reaction should be treated as an ODE
62+
!((physical_scales === nothing) ||
63+
(physical_scales[rxidx] == PhysicalScale.ODE)) && continue
64+
6065
rl = oderatelaw(rx; combinatoric_ratelaw = combinatoric_ratelaws)
6166
remove_conserved && (rl = substitute(rl, depspec_submap))
6267
for (spec, stoich) in rx.netstoich
@@ -90,8 +95,10 @@ function assemble_oderhs(rs, ispcs; combinatoric_ratelaws = true, remove_conserv
9095
end
9196

9297
function assemble_drift(rs, ispcs; combinatoric_ratelaws = true, as_odes = true,
93-
include_zero_odes = true, remove_conserved = false)
94-
rhsvec = assemble_oderhs(rs, ispcs; combinatoric_ratelaws, remove_conserved)
98+
include_zero_odes = true, remove_conserved = false, physical_scales = nothing)
99+
100+
rhsvec = assemble_oderhs(rs, ispcs; combinatoric_ratelaws, remove_conserved,
101+
physical_scales)
95102
if as_odes
96103
D = Differential(get_iv(rs))
97104
eqs = [Equation(D(x), rhs)
@@ -371,6 +378,9 @@ function assemble_jumps(rs; combinatoric_ratelaws = true, physical_scales = noth
371378

372379
rxvars = []
373380
for (i, rx) in enumerate(rxs)
381+
# only process reactions that should give jumps
382+
(physcales[i] in jump_scales) || continue
383+
374384
empty!(rxvars)
375385
(rx.rate isa Symbolic) && get_variables!(rxvars, rx.rate)
376386

@@ -384,7 +394,7 @@ function assemble_jumps(rs; combinatoric_ratelaws = true, physical_scales = noth
384394
# don't change species that are constant or BCs
385395
(!drop_dynamics(spec)) && push!(affect, spec ~ spec + stoich)
386396
end
387-
if isvrj
397+
if isvrj
388398
push!(veqs, VariableRateJump(rl, affect))
389399
else
390400
push!(ceqs, ConstantRateJump(rl, affect))
@@ -661,7 +671,7 @@ function Base.convert(::Type{<:SDESystem}, rs::ReactionSystem;
661671
SDESystem(eqs, noiseeqs, get_iv(flatrs), us, ps;
662672
observed = obs,
663673
name,
664-
defaults = defs,
674+
defaults = _merge(defaults, defs),
665675
checks,
666676
continuous_events = MT.get_continuous_events(flatrs),
667677
discrete_events = MT.get_discrete_events(flatrs),
@@ -721,38 +731,51 @@ function Base.convert(::Type{<:JumpSystem}, rs::ReactionSystem; name = nameof(rs
721731
combinatoric_ratelaws = get_combinatoric_ratelaws(rs),
722732
remove_conserved = nothing, checks = false,
723733
default_u0 = Dict(), default_p = Dict(),
724-
defaults = _merge(Dict(default_u0), Dict(default_p)), physical_scales = nothing,
734+
defaults = _merge(Dict(default_u0), Dict(default_p)), include_zero_odes = true,
735+
physical_scales = nothing,
725736
kwargs...)
726737
iscomplete(rs) || error(COMPLETENESS_ERROR)
727738
spatial_convert_err(rs::ReactionSystem, JumpSystem)
728739
(remove_conserved !== nothing) &&
729740
throw(ArgumentError("Catalyst does not support removing conserved species when converting to JumpSystems."))
730741

731742
flatrs = Catalyst.flatten(rs)
732-
error_if_constraints(JumpSystem, flatrs)
733743

734-
physical_scales = merge_physical_scales(reactions(rs), physical_scales,
744+
physical_scales = merge_physical_scales(reactions(flatrs), physical_scales,
735745
PhysicalScale.Jump)
736746
admissible_scales = (PhysicalScale.ODE, PhysicalScale.Jump,
737747
PhysicalScale.VariableRateJump)
738748
unique_scales = unique(physical_scales)
739749
(unique_scales admissible_scales) ||
740750
error("Physical scales must currently be one of $admissible_scales for hybrid systems.")
741-
hasodes = (PhysicalScale.ODE in unique_scales) || has_nonreactions(flatrs)
742751

752+
# basic jump states and equations
743753
eqs = assemble_jumps(flatrs; combinatoric_ratelaws, physical_scales)
754+
ists, ispcs = get_indep_sts(flatrs)
755+
756+
# handle coupled ODEs and BC species
757+
if (PhysicalScale.ODE in unique_scales) || has_nonreactions(flatrs)
758+
odeeqs = assemble_drift(flatrs, ispcs; combinatoric_ratelaws,
759+
remove_conserved = false, include_zero_odes, physical_scales)
760+
append!(eqs, odeeqs)
761+
eqs, us, ps, obs, defs = addconstraints!(eqs, flatrs, ists, ispcs;
762+
remove_conserved = false)
763+
else
764+
any(isbc, get_unknowns(flatrs)) &&
765+
(ists = vcat(ists, filter(isbc, get_unknowns(flatrs))))
766+
us = ists
767+
ps = get_ps(flatrs)
768+
obs = MT.observed(flatrs)
769+
defs = MT.defaults(flatrs)
770+
end
744771

745-
# handle BC species
746-
sts, ispcs = get_indep_sts(flatrs)
747-
any(isbc, get_unknowns(flatrs)) && (sts = vcat(sts, filter(isbc, get_unknowns(flatrs))))
748-
ps = get_ps(flatrs)
749-
750-
JumpSystem(eqs, get_iv(flatrs), sts, ps;
751-
observed = MT.observed(flatrs),
772+
JumpSystem(eqs, get_iv(flatrs), us, ps;
773+
observed = obs,
752774
name,
753-
defaults = _merge(defaults, MT.defaults(flatrs)),
775+
defaults = _merge(defaults, defs),
754776
checks,
755777
discrete_events = MT.discrete_events(flatrs),
778+
continuous_events = MT.continuous_events(flatrs),
756779
kwargs...)
757780
end
758781

0 commit comments

Comments
 (0)