Skip to content

Commit 0d8dda1

Browse files
committed
updates
1 parent 98decab commit 0d8dda1

File tree

1 file changed

+51
-30
lines changed

1 file changed

+51
-30
lines changed

src/reactionsystem_conversions.jl

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ Notes:
2222
`combinatoric_ratelaw=false` then the ratelaw is `k*S^2`, i.e. the scaling
2323
factor is ignored.
2424
"""
25-
function oderatelaw(rx; combinatoric_ratelaw = true)
25+
function oderatelaw(rx; combinatoric_ratelaw = true, expand_catalyst_funs = true)
2626
@unpack rate, substrates, substoich, only_use_rate = rx
2727
rl = rate
28+
expand_catalyst_funs && (rl = expand_registered_functions(rl))
2829

2930
# if the stoichiometric coefficients are not integers error if asking to scale rates
3031
!all(s -> s isa Union{Integer, Symbolic}, substoich) &&
@@ -47,7 +48,7 @@ end
4748
drop_dynamics(s) = isconstant(s) || isbc(s) || (!isspecies(s))
4849

4950
function assemble_oderhs(rs, ispcs; combinatoric_ratelaws = true, remove_conserved = false,
50-
physical_scales = nothing)
51+
physical_scales = nothing, expand_catalyst_funs = true)
5152
nps = get_networkproperties(rs)
5253
species_to_idx = Dict(x => i for (i, x) in enumerate(ispcs))
5354
rhsvec = Any[0 for _ in ispcs]
@@ -62,7 +63,8 @@ function assemble_oderhs(rs, ispcs; combinatoric_ratelaws = true, remove_conserv
6263
!((physical_scales === nothing) ||
6364
(physical_scales[rxidx] == PhysicalScale.ODE)) && continue
6465

65-
rl = oderatelaw(rx; combinatoric_ratelaw = combinatoric_ratelaws)
66+
rl = oderatelaw(rx; combinatoric_ratelaw = combinatoric_ratelaws,
67+
expand_catalyst_funs)
6668
remove_conserved && (rl = substitute(rl, depspec_submap))
6769
for (spec, stoich) in rx.netstoich
6870
# dependent species don't get an ODE, so are skipped
@@ -95,10 +97,10 @@ function assemble_oderhs(rs, ispcs; combinatoric_ratelaws = true, remove_conserv
9597
end
9698

9799
function assemble_drift(rs, ispcs; combinatoric_ratelaws = true, as_odes = true,
98-
include_zero_odes = true, remove_conserved = false, physical_scales = nothing)
99-
100+
include_zero_odes = true, remove_conserved = false, physical_scales = nothing,
101+
expand_catalyst_funs = true)
100102
rhsvec = assemble_oderhs(rs, ispcs; combinatoric_ratelaws, remove_conserved,
101-
physical_scales)
103+
physical_scales, expand_catalyst_funs)
102104
if as_odes
103105
D = Differential(get_iv(rs))
104106
eqs = [Equation(D(x), rhs)
@@ -111,7 +113,7 @@ end
111113

112114
# this doesn't work with constraint equations currently
113115
function assemble_diffusion(rs, sts, ispcs; combinatoric_ratelaws = true,
114-
remove_conserved = false)
116+
remove_conserved = falsem , expand_catalyst_funs = true)
115117
# as BC species should ultimately get an equation, we include them in the noise matrix
116118
num_bcsts = count(isbc, get_unknowns(rs))
117119

@@ -127,7 +129,9 @@ function assemble_diffusion(rs, sts, ispcs; combinatoric_ratelaws = true,
127129
end
128130

129131
for (j, rx) in enumerate(get_rxs(rs))
130-
rlsqrt = sqrt(abs(oderatelaw(rx; combinatoric_ratelaw = combinatoric_ratelaws)))
132+
rl = oderatelaw(rx; combinatoric_ratelaw = combinatoric_ratelaws,
133+
expand_catalyst_funs)
134+
rlsqrt = sqrt(abs(rl))
131135
hasnoisescaling(rx) && (rlsqrt *= getnoisescaling(rx))
132136
remove_conserved && (rlsqrt = substitute(rlsqrt, depspec_submap))
133137

@@ -176,9 +180,12 @@ Notes:
176180
the ratelaw is `k*S*(S-1)`, i.e. the rate law is not normalized by the scaling
177181
factor.
178182
"""
179-
function jumpratelaw(rx; combinatoric_ratelaw = true)
183+
function jumpratelaw(rx; combinatoric_ratelaw = true, expand_catalyst_funs = true)
180184
@unpack rate, substrates, substoich, only_use_rate = rx
185+
181186
rl = rate
187+
expand_catalyst_funs && (rl = expand_registered_functions(rl))
188+
182189
if !only_use_rate
183190
coef = eltype(substoich) <: Number ? one(eltype(substoich)) : 1
184191
for (i, stoich) in enumerate(substoich)
@@ -360,7 +367,8 @@ function classify_vrjs(rs, physcales)
360367
isvrjvec
361368
end
362369

363-
function assemble_jumps(rs; combinatoric_ratelaws = true, physical_scales = nothing)
370+
function assemble_jumps(rs; combinatoric_ratelaws = true, physical_scales = nothing,
371+
expand_catalyst_funs = true)
364372
meqs = MassActionJump[]
365373
ceqs = ConstantRateJump[]
366374
veqs = VariableRateJump[]
@@ -389,7 +397,8 @@ function assemble_jumps(rs; combinatoric_ratelaws = true, physical_scales = noth
389397
if (!isvrj) && ismassaction(rx, rs; rxvars, haveivdep = false, unknownset)
390398
push!(meqs, makemajump(rx; combinatoric_ratelaw = combinatoric_ratelaws))
391399
else
392-
rl = jumpratelaw(rx; combinatoric_ratelaw = combinatoric_ratelaws)
400+
rl = jumpratelaw(rx; combinatoric_ratelaw = combinatoric_ratelaws,
401+
expand_catalyst_funs)
393402
affect = Vector{Equation}()
394403
for (spec, stoich) in rx.netstoich
395404
# don't change species that are constant or BCs
@@ -409,46 +418,58 @@ end
409418

410419
# merge constraint components with the ReactionSystem components
411420
# also handles removing BC and constant species
412-
function addconstraints!(eqs, rs::ReactionSystem, ists, ispcs; remove_conserved = false)
421+
function addconstraints!(eqs, rs::ReactionSystem, ists, ispcs; remove_conserved = false,
422+
treat_conserved_as_eqs = false)
413423
# if there are BC species, put them after the independent species
414424
rssts = get_unknowns(rs)
415425
sts = any(isbc, rssts) ? vcat(ists, filter(isbc, rssts)) : ists
416426
ps = get_ps(rs)
417-
427+
initeqs = Equation[]
428+
defs = MT.defaults(rs)
429+
obs = MT.observed(rs)
430+
418431
# make dependent species observables and add conservation constants as parameters
419432
if remove_conserved
420433
nps = get_networkproperties(rs)
421434

422435
# add the conservation constants as parameters and set their values
423-
ps = vcat(ps, collect(eq.lhs for eq in nps.constantdefs))
424-
defs = copy(MT.defaults(rs))
425-
for eq in nps.constantdefs
426-
defs[eq.lhs] = eq.rhs
427-
end
436+
ps = copy(ps)
437+
push!(ps, nps.conservedconst)
428438

429-
# add the dependent species as observed
430-
obs = copy(MT.observed(rs))
431-
append!(obs, nps.conservedeqs)
432-
else
433-
defs = MT.defaults(rs)
434-
obs = MT.observed(rs)
439+
if treat_conserved_as_eqs
440+
# add back previously removed dependent species
441+
sts = union(sts, nps.depspecs)
442+
443+
# treat conserved eqs as normal eqs
444+
append!(eqs, conservedequations(rs))
445+
446+
# add initialization equations for conserved parameters
447+
initialmap = Dict(u => Initial(u) for u in species(rs))
448+
conseqs = conservationlaw_constants(rs)
449+
initeqs = [Symbolics.substitute(conseq, initialmap) for conseq in conseqs]
450+
else
451+
# add the dependent species as observed
452+
obs = copy(obs)
453+
append!(obs, conservedequations(rs))
454+
end
435455
end
436456

437457
ceqs = Equation[eq for eq in get_eqs(rs) if eq isa Equation]
438458
if !isempty(ceqs)
439459
if remove_conserved
440460
@info """
441-
Be careful mixing constraints and elimination of conservation laws.
442-
Catalyst does not check that the conserved equations still hold for the
443-
final coupled system of equations. Consider using `remove_conserved =
444-
false` and instead calling ModelingToolkit.structural_simplify to simplify
445-
any generated ODESystem or NonlinearSystem.
461+
Be careful mixing ODEs or algebraic equations and elimination of
462+
conservation laws. Catalyst does not check that the conserved equations
463+
still hold for the final coupled system of equations. Consider using
464+
`remove_conserved = false` and instead calling
465+
ModelingToolkit.structural_simplify to simplify any generated ODESystem or
466+
NonlinearSystem.
446467
"""
447468
end
448469
append!(eqs, ceqs)
449470
end
450471

451-
eqs, sts, ps, obs, defs
472+
eqs, sts, ps, obs, defs, initeqs
452473
end
453474

454475
# used by flattened systems that don't support constraint equations currently

0 commit comments

Comments
 (0)