Skip to content

Commit 5924139

Browse files
committed
refactor
1 parent 27c8856 commit 5924139

File tree

2 files changed

+76
-20
lines changed

2 files changed

+76
-20
lines changed

src/reactionsystem.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,16 @@ function numreactions(network)
823823
nr
824824
end
825825

826+
"""
827+
has_nonreactions(network)
828+
829+
Check if the given `network` has any non-reaction equations such as ODEs or algebraic
830+
equations.
831+
"""
832+
function has_nonreactions(network)
833+
numreactions(network) != length(equations(network))
834+
end
835+
826836
"""
827837
nonreactions(network)
828838

src/reactionsystem_conversions.jl

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,12 @@ function ismassaction(rx, rs; rxvars = get_variables(rx.rate),
242242
if (ivset === nothing)
243243
ivs = Set(get_sivs(rs))
244244
push!(ivs, get_iv(rs))
245-
ivdep = any(var -> var ivs, rxvars)
245+
ivdep = any(in(ivs), rxvars)
246246
else
247-
ivdep = any(var -> var ivset, rxvars)
247+
ivdep = any(in(ivset), rxvars)
248248
end
249249
else
250-
ivdep = any(var -> isequal(get_iv(rs), var), rxvars)
250+
ivdep = any(isequal(get_iv(rs)), rxvars)
251251
end
252252
ivdep && return false
253253
else
@@ -313,23 +313,19 @@ function get_depgraph(rs)
313313
eqeq_dependencies(jdeps, vdeps).fadjlist
314314
end
315315

316-
function assemble_jumps(rs; combinatoric_ratelaws = true)
317-
meqs = MassActionJump[]
318-
ceqs = ConstantRateJump[]
319-
veqs = VariableRateJump[]
320-
unknownset = Set(get_unknowns(rs))
321-
all(isspecies, unknownset) ||
322-
error("Conversion to JumpSystem currently requires all unknowns to be species.")
323-
rxvars = []
324-
325-
isempty(get_rxs(rs)) &&
326-
error("Must give at least one reaction before constructing a JumpSystem.")
327-
316+
function classify_vrjs(rs, physcales)
328317
# first we determine vrjs with an explicit time-dependent rate
329318
rxs = get_rxs(rs)
330319
isvrjvec = falses(length(rxs))
331320
havevrjs = false
321+
rxvars = Set()
332322
for (i, rx) in enumerate(rxs)
323+
if physcales[i] == PhysicalScale.VariableRateJump
324+
isvrjvec[i] = true
325+
havevrjs = true
326+
continue
327+
end
328+
333329
empty!(rxvars)
334330
(rx.rate isa Symbolic) && get_variables!(rxvars, rx.rate)
335331
@inbounds for rxvar in rxvars
@@ -353,6 +349,27 @@ function assemble_jumps(rs; combinatoric_ratelaws = true)
353349
end
354350
end
355351

352+
isvrjvec
353+
end
354+
355+
function assemble_jumps(rs; combinatoric_ratelaws = true, physical_scales = nothing)
356+
meqs = MassActionJump[]
357+
ceqs = ConstantRateJump[]
358+
veqs = VariableRateJump[]
359+
unknownset = Set(get_unknowns(rs))
360+
rxs = get_rxs(rs)
361+
362+
if physical_scales === nothing
363+
physcales = [PhysicalScale.Jump for _ in enumerate(rxs)]
364+
else
365+
physcales = physical_scales
366+
end
367+
jump_scales = (PhysicalScale.Jump, PhysicalScale.VariableRateJump)
368+
(isempty(get_rxs(rs)) || !any(in(jump_scales), physcales)) &&
369+
error("Must have at least one reaction that will be represented as a jump when constructing a JumpSystem.")
370+
isvrjvec = classify_vrjs(rs, physcales)
371+
372+
rxvars = []
356373
for (i, rx) in enumerate(rxs)
357374
empty!(rxvars)
358375
(rx.rate isa Symbolic) && get_variables!(rxvars, rx.rate)
@@ -651,13 +668,34 @@ function Base.convert(::Type{<:SDESystem}, rs::ReactionSystem;
651668
kwargs...)
652669
end
653670

654-
function merge_physical_scales(rxs, physical_scales)
671+
"""
672+
merge_physical_scales(rxs, physical_scales; default = PhysicalScale.Auto)
673+
674+
Merge physical scales for a set of reactions.
675+
676+
# Arguments
677+
- `rxs`, a vector of `Reaction`s.
678+
- `physical_scales`, an iterable of pairs mapping integer reaction indices to
679+
`PhysicalScale`s.
680+
- `default`, the default physical scale to use for reactions that set PhysicalScale.Auto.
681+
"""
682+
function merge_physical_scales(rxs, physical_scales, default)
655683
scales = get_physical_scale.(rxs)
684+
685+
# override metadata attached scales
656686
if physical_scales !== nothing
657-
for (idx, scale) in physical_scales
658-
scales[idx] = scale
687+
for (key, scale) in physical_scales
688+
scales[key] = scale
659689
end
660690
end
691+
692+
# transform any "Auto" scales to the default
693+
for (idx, scale) in enumerate(scales)
694+
if scale == PhysicalScale.Auto
695+
scales[idx] = default
696+
end
697+
end
698+
661699
scales
662700
end
663701

@@ -693,8 +731,16 @@ function Base.convert(::Type{<:JumpSystem}, rs::ReactionSystem; name = nameof(rs
693731
flatrs = Catalyst.flatten(rs)
694732
error_if_constraints(JumpSystem, flatrs)
695733

696-
physical_scales = merge_physical_scales(reactions(rs), physical_scales)
697-
eqs = assemble_jumps(flatrs; combinatoric_ratelaws)
734+
physical_scales = merge_physical_scales(reactions(rs), physical_scales,
735+
PhysicalScale.Jump)
736+
admissible_scales = (PhysicalScale.ODE, PhysicalScale.Jump,
737+
PhysicalScale.VariableRateJump)
738+
unique_scales = unique(physical_scales)
739+
(unique_scales admissible_scales) ||
740+
error("Physical scales must currently be one of $admissible_scales for hybrid systems.")
741+
hasodes = (PhysicalScale.ODE in unique_scales) || has_nonreactions(flatrs)
742+
743+
eqs = assemble_jumps(flatrs; combinatoric_ratelaws, physical_scales)
698744

699745
# handle BC species
700746
sts, ispcs = get_indep_sts(flatrs)

0 commit comments

Comments
 (0)