Skip to content

Commit 935a32e

Browse files
authored
Merge pull request #920 from isaacsas/composing-reactionsystems
Composing `ReactionSystem`s
2 parents 7f2c743 + 3db777f commit 935a32e

File tree

4 files changed

+156
-41
lines changed

4 files changed

+156
-41
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
*.jl.*.cov
33
*.jl.mem
44
Manifest.toml
5+
.vscode
6+
.vscode/*

src/systems/reaction/reactionsystem.jl

Lines changed: 67 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,14 @@ function Reaction(rate, subs, prods; kwargs...)
9292
Reaction(rate, subs, prods, sstoich, pstoich; kwargs...)
9393
end
9494

95+
function namespace_equation(rx::Reaction, name, iv)
96+
Reaction(namespace_expr(rx.rate, name, iv),
97+
namespace_expr(rx.substrates, name, iv),
98+
namespace_expr(rx.products, name, iv),
99+
rx.substoich, rx.prodstoich,
100+
[namespace_expr(n[1],name,iv) => n[2] for n in rx.netstoich], rx.only_use_rate)
101+
end
102+
95103
# calculates the net stoichiometry of a reaction as a vector of pairs (sub,substoich)
96104
function get_netstoich(subs, prods, sstoich, pstoich)
97105
# stoichiometry as a Dictionary
@@ -134,7 +142,7 @@ struct ReactionSystem <: AbstractSystem
134142
"""The name of the system"""
135143
name::Symbol
136144
"""systems: The internal systems"""
137-
systems::Vector{ReactionSystem}
145+
systems::Vector
138146

139147
function ReactionSystem(eqs, iv, states, ps, observed, name, systems)
140148
new(eqs, value(iv), value.(states), value.(ps), observed, name, systems)
@@ -143,13 +151,31 @@ end
143151

144152
function ReactionSystem(eqs, iv, species, params;
145153
observed = [],
146-
systems = ReactionSystem[],
154+
systems = [],
147155
name = gensym(:ReactionSystem))
148156

149-
isempty(species) && error("ReactionSystems require at least one species.")
157+
#isempty(species) && error("ReactionSystems require at least one species.")
150158
ReactionSystem(eqs, iv, species, params, observed, name, systems)
151159
end
152160

161+
function ReactionSystem(iv; kwargs...)
162+
ReactionSystem(Reaction[], iv, [], []; kwargs...)
163+
end
164+
165+
function equations(sys::ModelingToolkit.ReactionSystem)
166+
eqs = get_eqs(sys)
167+
systems = get_systems(sys)
168+
if isempty(systems)
169+
return eqs
170+
else
171+
eqs = [eqs;
172+
reduce(vcat,
173+
namespace_equations.(get_systems(sys));
174+
init=[])]
175+
return eqs
176+
end
177+
end
178+
153179
"""
154180
oderatelaw(rx; combinatoric_ratelaw=true)
155181
@@ -187,11 +213,11 @@ function oderatelaw(rx; combinatoric_ratelaw=true)
187213
end
188214

189215
function assemble_oderhs(rs; combinatoric_ratelaws=true)
190-
sts = states(rs)
216+
sts = get_states(rs)
191217
species_to_idx = Dict((x => i for (i,x) in enumerate(sts)))
192218
rhsvec = Any[0 for i in eachindex(sts)]
193219

194-
for rx in equations(rs)
220+
for rx in get_eqs(rs)
195221
rl = oderatelaw(rx; combinatoric_ratelaw=combinatoric_ratelaws)
196222
for (spec,stoich) in rx.netstoich
197223
i = species_to_idx[spec]
@@ -212,16 +238,16 @@ function assemble_drift(rs; combinatoric_ratelaws=true, as_odes=true)
212238
rhsvec = assemble_oderhs(rs; combinatoric_ratelaws=combinatoric_ratelaws)
213239
if as_odes
214240
D = Differential(get_iv(rs))
215-
eqs = [Equation(D(x),rhs) for (x,rhs) in zip(states(rs),rhsvec)]
241+
eqs = [Equation(D(x),rhs) for (x,rhs) in zip(get_states(rs),rhsvec) if (!_iszero(rhs))]
216242
else
217-
eqs = [Equation(0,rhs) for rhs in rhsvec]
243+
eqs = [Equation(0,rhs) for rhs in rhsvec if (!_iszero(rhs))]
218244
end
219245
eqs
220246
end
221247

222248
function assemble_diffusion(rs, noise_scaling; combinatoric_ratelaws=true)
223-
sts = states(rs)
224-
eqs = Matrix{Any}(undef, length(sts), length(equations(rs)))
249+
sts = get_states(rs)
250+
eqs = Matrix{Any}(undef, length(sts), length(get_eqs(rs)))
225251
eqs .= 0
226252
species_to_idx = Dict((x => i for (i,x) in enumerate(sts)))
227253

@@ -302,7 +328,7 @@ explicitly on the independent variable (usually time).
302328
"""
303329
function ismassaction(rx, rs; rxvars = get_variables(rx.rate),
304330
haveivdep = any(var -> isequal(get_iv(rs),var), rxvars),
305-
stateset = Set(states(rs)))
331+
stateset = Set(get_states(rs)))
306332
# if no dependencies must be zero order
307333
(length(rxvars)==0) && return true
308334
haveivdep && return false
@@ -331,7 +357,7 @@ end
331357

332358
function assemble_jumps(rs; combinatoric_ratelaws=true)
333359
meqs = MassActionJump[]; ceqs = ConstantRateJump[]; veqs = VariableRateJump[]
334-
stateset = Set(states(rs))
360+
stateset = Set(get_states(rs))
335361
#rates = []; rstoich = []; nstoich = []
336362
rxvars = []
337363
ivname = nameof(get_iv(rs))
@@ -378,10 +404,11 @@ law, i.e. for `2S -> 0` at rate `k` the ratelaw would be `k*S^2/2!`. If
378404
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
379405
ignored.
380406
"""
381-
function Base.convert(::Type{<:ODESystem}, rs::ReactionSystem; combinatoric_ratelaws=true)
382-
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws)
383-
ODESystem(eqs,get_iv(rs),states(rs),get_ps(rs),name=nameof(rs),
384-
systems=convert.(ODESystem,get_systems(rs)))
407+
function Base.convert(::Type{<:ODESystem}, rs::ReactionSystem;
408+
name=nameof(rs), combinatoric_ratelaws=true, kwargs...)
409+
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws)
410+
systems = map(sys -> (sys isa ODESystem) ? sys : convert(ODESystem, sys), get_systems(rs))
411+
ODESystem(eqs, get_iv(rs), get_states(rs), get_ps(rs), name=name, systems=systems)
385412
end
386413

387414
"""
@@ -397,9 +424,11 @@ law, i.e. for `2S -> 0` at rate `k` the ratelaw would be `k*S^2/2!`. If
397424
`combinatoric_ratelaws=false` then the ratelaw is `k*S^2`, i.e. the scaling factor is
398425
ignored.
399426
"""
400-
function Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem; combinatoric_ratelaws=true)
401-
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws, as_odes=false)
402-
NonlinearSystem(eqs,states(rs),get_ps(rs),name=nameof(rs),systems=convert.(NonlinearSystem,get_systems(rs)))
427+
function Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem;
428+
name=nameof(rs), combinatoric_ratelaws=true, kwargs...)
429+
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws, as_odes=false)
430+
systems = convert.(NonlinearSystem, get_systems(rs))
431+
NonlinearSystem(eqs, get_states(rs), get_ps(rs), name=name, systems=systems)
403432
end
404433

405434
"""
@@ -423,7 +452,8 @@ Finally, a `Vector{Operation}` can be provided (the length must be equal to the
423452
Here the noise for each reaction is scaled by the corresponding parameter in the input vector.
424453
This input may contain repeat parameters.
425454
"""
426-
function Base.convert(::Type{<:SDESystem},rs::ReactionSystem, combinatoric_ratelaws=true; noise_scaling=nothing)
455+
function Base.convert(::Type{<:SDESystem}, rs::ReactionSystem;
456+
noise_scaling=nothing, name=nameof(rs), combinatoric_ratelaws=true, kwargs...)
427457

428458
if noise_scaling isa Vector
429459
(length(noise_scaling)!=length(equations(rs))) &&
@@ -434,19 +464,14 @@ function Base.convert(::Type{<:SDESystem},rs::ReactionSystem, combinatoric_ratel
434464
noise_scaling = fill(value(noise_scaling),length(equations(rs)))
435465
end
436466

437-
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws)
438-
467+
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws)
439468
noiseeqs = assemble_diffusion(rs,noise_scaling;
440469
combinatoric_ratelaws=combinatoric_ratelaws)
441-
442-
SDESystem(eqs,
443-
noiseeqs,
444-
get_iv(rs),
445-
states(rs),
446-
(noise_scaling===nothing) ?
447-
get_ps(rs) :
448-
union(get_ps(rs),toparam.(noise_scaling)),
449-
name=nameof(rs),systems=convert.(SDESystem,get_systems(rs)))
470+
systems = convert.(SDESystem, get_systems(rs))
471+
SDESystem(eqs, noiseeqs, get_iv(rs), get_states(rs),
472+
(noise_scaling===nothing) ? get_ps(rs) : union(get_ps(rs), toparam.(noise_scaling)),
473+
name=name,
474+
systems=systems)
450475
end
451476

452477
"""
@@ -462,10 +487,11 @@ Notes:
462487
the ratelaw is `k*S*(S-1)`, i.e. the rate law is not normalized by the scaling
463488
factor.
464489
"""
465-
function Base.convert(::Type{<:JumpSystem},rs::ReactionSystem; combinatoric_ratelaws=true)
466-
eqs = assemble_jumps(rs; combinatoric_ratelaws=combinatoric_ratelaws)
467-
JumpSystem(eqs,get_iv(rs),states(rs),get_ps(rs),name=nameof(rs),
468-
systems=convert.(JumpSystem,get_systems(rs)))
490+
function Base.convert(::Type{<:JumpSystem},rs::ReactionSystem;
491+
name=nameof(rs), combinatoric_ratelaws=true, kwargs...)
492+
eqs = assemble_jumps(rs; combinatoric_ratelaws=combinatoric_ratelaws)
493+
systems = convert.(JumpSystem, get_systems(rs))
494+
JumpSystem(eqs, get_iv(rs), get_states(rs), get_ps(rs), name=name, systems=systems)
469495
end
470496

471497

@@ -474,35 +500,35 @@ end
474500

475501
# ODEProblem from AbstractReactionNetwork
476502
function DiffEqBase.ODEProblem(rs::ReactionSystem, u0, tspan, p=DiffEqBase.NullParameters(), args...; kwargs...)
477-
return ODEProblem(convert(ODESystem,rs),u0,tspan,p, args...; kwargs...)
503+
return ODEProblem(convert(ODESystem,rs; kwargs...),u0,tspan,p, args...; kwargs...)
478504
end
479505

480506
# NonlinearProblem from AbstractReactionNetwork
481507
function DiffEqBase.NonlinearProblem(rs::ReactionSystem, u0, p=DiffEqBase.NullParameters(), args...; kwargs...)
482-
return NonlinearProblem(convert(NonlinearSystem,rs), u0, p, args...; kwargs...)
508+
return NonlinearProblem(convert(NonlinearSystem,rs; kwargs...), u0, p, args...; kwargs...)
483509
end
484510

485511

486512
# SDEProblem from AbstractReactionNetwork
487513
function DiffEqBase.SDEProblem(rs::ReactionSystem, u0, tspan, p=DiffEqBase.NullParameters(), args...; noise_scaling=nothing, kwargs...)
488-
sde_sys = convert(SDESystem,rs,noise_scaling=noise_scaling)
489-
p_matrix = zeros(length(states(rs)), length(equations(rs)))
514+
sde_sys = convert(SDESystem,rs;noise_scaling=noise_scaling, kwargs...)
515+
p_matrix = zeros(length(get_states(rs)), length(get_eqs(rs)))
490516
return SDEProblem(sde_sys,u0,tspan,p,args...; noise_rate_prototype=p_matrix,kwargs...)
491517
end
492518

493519
# DiscreteProblem from AbstractReactionNetwork
494520
function DiffEqBase.DiscreteProblem(rs::ReactionSystem, u0, tspan::Tuple, p=DiffEqBase.NullParameters(), args...; kwargs...)
495-
return DiscreteProblem(convert(JumpSystem,rs), u0,tspan,p, args...; kwargs...)
521+
return DiscreteProblem(convert(JumpSystem,rs; kwargs...), u0,tspan,p, args...; kwargs...)
496522
end
497523

498524
# JumpProblem from AbstractReactionNetwork
499525
function DiffEqJump.JumpProblem(rs::ReactionSystem, prob, aggregator, args...; kwargs...)
500-
return JumpProblem(convert(JumpSystem,rs), prob, aggregator, args...; kwargs...)
526+
return JumpProblem(convert(JumpSystem,rs; kwargs...), prob, aggregator, args...; kwargs...)
501527
end
502528

503529
# SteadyStateProblem from AbstractReactionNetwork
504530
function DiffEqBase.SteadyStateProblem(rs::ReactionSystem, u0, p=DiffEqBase.NullParameters(), args...; kwargs...)
505-
return SteadyStateProblem(ODEFunction(convert(ODESystem,rs)),u0,p, args...; kwargs...)
531+
return SteadyStateProblem(ODEFunction(convert(ODESystem,rs; kwargs...)),u0,p, args...; kwargs...)
506532
end
507533

508534
# determine which species a reaction depends on

test/reactionsystem_components.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
using ModelingToolkit, LinearAlgebra, OrdinaryDiffEq, Test
2+
MT = ModelingToolkit
3+
4+
# Repressilator model
5+
@parameters t α₀ α K n δ β μ
6+
@variables m(t) P(t) R(t)
7+
rxs = [Reaction(α₀, nothing, [m]),
8+
Reaction/ (1 + (R/K)^n), nothing, [m]),
9+
Reaction(δ, [m], nothing),
10+
Reaction(β, [m], [m,P]),
11+
Reaction(μ, [P], nothing)
12+
]
13+
14+
specs = [m,P,R]
15+
pars = [α₀,α,K,n,δ,β,μ]
16+
@named rs = ReactionSystem(rxs, t, specs, pars)
17+
18+
# using ODESystem components
19+
@named os₁ = convert(ODESystem, rs)
20+
@named os₂ = convert(ODESystem, rs)
21+
@named os₃ = convert(ODESystem, rs)
22+
connections = [os₁.R ~ os₃.P,
23+
os₂.R ~ os₁.P,
24+
os₃.R ~ os₂.P]
25+
@named connected = ODESystem(connections, t, [], [], systems=[os₁,os₂,os₃])
26+
oderepressilator = structural_simplify(connected)
27+
28+
pvals = [os₁.α₀ => 5e-4,
29+
os₁.α => .5,
30+
os₁.K => 40.0,
31+
os₁.n => 2,
32+
os₁.δ => (log(2)/120),
33+
os₁.β => (20*log(2)/120),
34+
os₁.μ => (log(2)/600),
35+
os₂.α₀ => 5e-4,
36+
os₂.α => .5,
37+
os₂.K => 40.0,
38+
os₂.n => 2,
39+
os₂.δ => (log(2)/120),
40+
os₂.β => (20*log(2)/120),
41+
os₂.μ => (log(2)/600),
42+
os₃.α₀ => 5e-4,
43+
os₃.α => .5,
44+
os₃.K => 40.0,
45+
os₃.n => 2,
46+
os₃.δ => (log(2)/120),
47+
os₃.β => (20*log(2)/120),
48+
os₃.μ => (log(2)/600)]
49+
u₀ = [os₁.m => 0.0, os₁.P => 20.0, os₂.m => 0.0, os₂.P => 0.0, os₃.m => 0.0, os₃.P => 0.0]
50+
tspan = (0.0, 100000.0)
51+
oprob = ODEProblem(oderepressilator, u₀, tspan, pvals)
52+
sol = solve(oprob, Tsit5())
53+
54+
# hardcoded network
55+
function repress!(f, y, p, t)
56+
α = p.α; α₀ = p.α₀; β = p.β; δ = p.δ; μ = p.μ; K = p.K; n = p.n
57+
f[1] = α / (1 + (y[6] / K)^n) - δ * y[1] + α₀
58+
f[2] = α / (1 + (y[4] / K)^n) - δ * y[2] + α₀
59+
f[3] = α / (1 + (y[5] / K)^n) - δ * y[3] + α₀
60+
f[4] = β * y[1] - μ * y[4]
61+
f[5] = β * y[2] - μ * y[5]
62+
f[6] = β * y[3] - μ * y[6]
63+
nothing
64+
end
65+
ps = (α₀=5e-4, α=.5, K=40.0, n=2, δ=(log(2)/120), β=(20*log(2)/120), μ=(log(2)/600))
66+
u0 = [0.0,0.0,0.0,20.0,0.0,0.0]
67+
oprob2 = ODEProblem(repress!, u0, tspan, ps)
68+
sol2 = solve(oprob2, Tsit5())
69+
tvs = 0:1:tspan[end]
70+
71+
indexof(sym,syms) = findfirst(isequal(sym),syms)
72+
i = indexof(os₁.P, states(oderepressilator))
73+
@test all(isapprox(u[1],u[2],atol=1e-4) for u in zip(sol(tvs, idxs=2), sol2(tvs, idxs=4)))
74+
75+
# using ReactionSystem components
76+
77+
# @named rs₁ = ReactionSystem(rxs, t, specs, pars)
78+
# @named rs₂ = ReactionSystem(rxs, t, specs, pars)
79+
# @named rs₃ = ReactionSystem(rxs, t, specs, pars)
80+
# connections = [rs₁.R ~ rs₃.P,
81+
# rs₂.R ~ rs₁.P,
82+
# rs₃.R ~ rs₂.P]
83+
# @named csys = ODESystem(connections, t, [], [])
84+
# @named repressilator = ReactionSystem(t; systems=[csys,rs₁,rs₂,rs₃])
85+
# @named oderepressilator2 = convert(ODESystem, repressilator)
86+
# sys2 = structural_simplify(oderepressilator2) # FAILS currently

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using SafeTestsets, Test
1515
@safetestset "NonlinearSystem Test" begin include("nonlinearsystem.jl") end
1616
@safetestset "OptimizationSystem Test" begin include("optimizationsystem.jl") end
1717
@safetestset "ReactionSystem Test" begin include("reactionsystem.jl") end
18+
@safetestset "ReactionSystem Test" begin include("reactionsystem_components.jl") end
1819
@safetestset "JumpSystem Test" begin include("jumpsystem.jl") end
1920
@safetestset "ControlSystem Test" begin include("controlsystem.jl") end
2021
@safetestset "Domain Test" begin include("domains.jl") end

0 commit comments

Comments
 (0)