Skip to content

Commit 847df02

Browse files
authored
Merge pull request #70 from augustinas1/u0map
Support initial state mapping and streamline ODEProblem
2 parents d2cb55c + 48d39dc commit 847df02

33 files changed

+258
-240
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ TupleTools = "1.2.0"
3434
julia = "1.6"
3535

3636
[extras]
37+
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
3738
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
3839
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3940
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4041

4142
[targets]
42-
test = ["Test", "SafeTestsets", "JumpProcesses"]
43+
test = ["Test", "SafeTestsets", "JumpProcesses", "OrdinaryDiffEqTsit5"]

docs/src/api/momentclosure_api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ bernoulli_moment_eqs
3131
```@docs
3232
moment_closure
3333
ClosedMomentEquations
34+
```
35+
36+
## Solving Moment Equations
37+
38+
```@docs
3439
deterministic_IC
40+
ODEProblem(::MomentEquations, ::Tuple, ::Tuple, ::Tuple)
3541
```
3642

3743
## Basic Accessor Functions
@@ -43,6 +49,7 @@ We also define a few accessor functions that return system information from the
4349
* `ModelingToolkit.get_iv(sys::MomentEquations)`: The independent variable used in the system.
4450
* `ModelingToolkit.get_ps(sys::MomentEquations)`: The parameters of the system.
4551
* `ModelingToolkit.unknowns(sys::MomentEquations)`: The set of unknowns (moments) in the equations.
52+
* `Catalyst.speciesmap(sys::MomentEquations)`: The dictionary mapping the chemical species in a `Catalyst.ReactionSystem` to their index within the corresponding moment equations.
4653
* `MomentClosure.get_closure(sys::ClosedMomentEquations)`: The dictionary of moment closure functions for each higher order moment.
4754

4855
## [Displaying Equations and Closures](@id visualisation_api)

docs/src/tutorials/LMA_example.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,12 @@ Note that the results agree with Eqs. (1) and (2) (after a corresponding substit
7171
```julia
7272
using OrdinaryDiffEq, Sundials, Plots
7373

74-
# [g, p] ordered as in `speciesmap(rn_nonlinear)`
75-
u₀ = [1.0, 0.001]
76-
p = Dict(:σ_b => 0.004, :σ_u => 0.25, :ρ_b => 25.0, :ρ_u => 60.0)
74+
u0map = [:g => 1.0, :p => 0.001]
75+
pmap = Dict(:σ_b => 0.004, :σ_u => 0.25, :ρ_b => 25.0, :ρ_u => 60.0)
7776
tspan = (0., 15.)
7877
dt = 0.1
7978

80-
u₀map = deterministic_IC(u₀, LMA_eqs)
81-
oprob_LMA = ODEProblem(LMA_eqs, u₀map, tspan, pmap)
79+
oprob_LMA = ODEProblem(LMA_eqs, u0map, tspan, pmap)
8280
sol_LMA = solve(oprob_LMA, CVODE_BDF(), saveat=dt)
8381

8482
plot(sol_LMA, idxs=[2], label="LMA", ylabel="⟨p⟩", xlabel="time", fmt="svg")

docs/src/tutorials/P53_system_example.md

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ rn = @reaction_network begin
5959
end
6060

6161
# parameters
62-
p = [:k₁ => 90, :k₂ => 0.002, :k₃ => 1.7, :k₄ => 1.1, :k₅ => 0.93, :k₆ => 0.96, :k₇ => 0.01]
62+
pmap = [:k₁ => 90, :k₂ => 0.002, :k₃ => 1.7, :k₄ => 1.1, :k₅ => 0.93, :k₆ => 0.96, :k₇ => 0.01]
6363

6464
# initial molecule numbers [x, y₀, y]
65-
u₀ = [70, 30, 60]
65+
u0map = [:x=> 70, :y₀ => 30, :y => 60]
6666
```
6767

6868
Let's first simulate the reaction network using SSA in order to have a reference point of the real system dynamics. We choose a relatively long simulation time span in order to clearly see how the molecule numbers converge to their steady-state values and opt for $5 \times 10^4$ SSA realisations:
@@ -73,7 +73,7 @@ tspan = (0., 200.)
7373
# constructing the discrete jump problem using DifferentialEquations
7474
jsys = convert(JumpSystem, rn, combinatoric_ratelaws=false)
7575
jsys = complete(jsys)
76-
dprob = DiscreteProblem(jsys, u₀, tspan, p)
76+
dprob = DiscreteProblem(jsys, u0map, tspan, pmap)
7777

7878
jprob = JumpProblem(jsys, dprob, Direct(), save_positions=(false, false))
7979
ensembleprob = EnsembleProblem(jprob)
@@ -123,10 +123,7 @@ for q in 3:6
123123
eqs = generate_central_moment_eqs(rn, 2, q, combinatoric_ratelaws=false)
124124
for (closure, plt) in zip(closures, plts)
125125
closed_eqs = moment_closure(eqs, closure)
126-
127-
u₀map = deterministic_IC(u₀, closed_eqs)
128-
oprob = ODEProblem(closed_eqs, u₀map, tspan, p)
129-
126+
oprob = ODEProblem(closed_eqs, u0map, tspan, pmap)
130127
sol = solve(oprob, Tsit5(), saveat=0.1)
131128
plt = plot!(plt, sol, idxs=[1], lw=3, label = "q = "*string(q))
132129
end
@@ -171,12 +168,9 @@ for q in [4,6]
171168
eqs = generate_central_moment_eqs(rn, 2, q, combinatoric_ratelaws=false)
172169
for closure in closures
173170
closed_eqs = moment_closure(eqs, closure)
174-
175-
u₀map = deterministic_IC(u₀, closed_eqs)
176-
oprob = ODEProblem(closed_eqs, u₀map, tspan, p)
171+
oprob = ODEProblem(closed_eqs, u0map, tspan, pmap)
177172
sol = solve(oprob, Tsit5(), saveat=0.1)
178-
179-
# index of M₂₀₀ can be checked with `u₀map` or `closed_eqs.odes.states`
173+
# index of M₂₀₀ can be checked with `unknowns(closed_eqs)`
180174
plt = plot!(plt, sol, idxs=[4], lw=3, label = closure*" q = "*string(q))
181175
end
182176
end
@@ -202,11 +196,8 @@ for (q, plt_m, plt_v) in zip(q_vals, plt_means, plt_vars)
202196

203197
eqs = generate_central_moment_eqs(rn, 3, q, combinatoric_ratelaws=false)
204198
for closure in closures
205-
206199
closed_eqs = moment_closure(eqs, closure)
207-
208-
u₀map = deterministic_IC(u₀, closed_eqs)
209-
oprob = ODEProblem(closed_eqs, u₀map, tspan, p)
200+
oprob = ODEProblem(closed_eqs, u0map, tspan, pmap)
210201

211202
sol = solve(oprob, Tsit5(), saveat=0.1)
212203
plt_m = plot!(plt_m, sol, idxs=[1], label = closure)
@@ -276,9 +267,7 @@ oprobs = Dict()
276267

277268
for closure in closures
278269
closed_eqs = moment_closure(eqs, closure)
279-
280-
u₀map = deterministic_IC(u₀, closed_eqs)
281-
oprobs[closure] = ODEProblem(closed_eqs, u₀map, tspan, p)
270+
oprobs[closure] = ODEProblem(closed_eqs, u0map, tspan, pmap)
282271
sol = solve(oprobs[closure], Tsit5(), saveat=0.1)
283272

284273
plt = plot!(plt, sol, idxs=[1], label = closure)

docs/src/tutorials/common_issues.md

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ rn = @reaction_network begin
1313
(c₃*Ω, c₄), 0 X
1414
end
1515

16-
p = [:c₁ => 0.9, :c₂ => 2, :c₃ => 1, :c₄ => 1, => 100]
17-
u₀ = [1, 1]
16+
pmap = [:c₁ => 0.9, :c₂ => 2, :c₃ => 1, :c₄ => 1, => 100]
17+
u0map = [:X => 1, :Y => 1]
1818
tspan = (0., 100.)
1919

2020
raw_eqs = generate_raw_moment_eqs(rn, 2, combinatoric_ratelaws=false)
@@ -23,8 +23,7 @@ As we have seen earlier, second-order moment expansion using normal closure appr
2323
```julia
2424
closed_raw_eqs = moment_closure(raw_eqs, "zero")
2525

26-
u₀map = deterministic_IC(u₀, closed_raw_eqs)
27-
oprob = ODEProblem(closed_raw_eqs, u₀map, tspan, p)
26+
oprob = ODEProblem(closed_raw_eqs, u0map, tspan, pmap)
2827
sol = solve(oprob, Tsit5(), saveat=0.1)
2928

3029
plot(sol, idxs=[1,2], lw=2)
@@ -37,8 +36,7 @@ Let's apply log-normal closure next:
3736
```julia
3837
closed_raw_eqs = moment_closure(raw_eqs, "log-normal")
3938

40-
u₀map = deterministic_IC(u₀, closed_raw_eqs)
41-
oprob = ODEProblem(closed_raw_eqs, u₀map, tspan, p)
39+
oprob = ODEProblem(closed_raw_eqs, u0map, tspan, pmap)
4240
sol = solve(oprob, Tsit5(), saveat=0.1)
4341

4442
plot(sol, idxs=[1,2], lw=2, legend=:bottomright)
@@ -52,8 +50,7 @@ Normal closure is also quite fragile. This can be seen by simply including the c
5250
raw_eqs = generate_raw_moment_eqs(rn, 2, combinatoric_ratelaws=true)
5351
closed_raw_eqs = moment_closure(raw_eqs, "normal")
5452

55-
u₀map = deterministic_IC(u₀, closed_raw_eqs)
56-
oprob = ODEProblem(closed_raw_eqs, u₀map, tspan, p)
53+
oprob = ODEProblem(closed_raw_eqs, u0map, tspan, pmap)
5754
sol = solve(oprob, Tsit5(), saveat=0.1)
5855

5956
plot(sol, idxs=[1,2], lw=2)
@@ -65,8 +62,7 @@ Nevertheless, this can be improved upon by increasing the order of moment expans
6562
raw_eqs = generate_raw_moment_eqs(rn, 3, combinatoric_ratelaws=true)
6663
closed_raw_eqs = moment_closure(raw_eqs, "normal")
6764

68-
u₀map = deterministic_IC(u₀, closed_raw_eqs)
69-
oprob = ODEProblem(closed_raw_eqs, u₀map, tspan, p)
65+
oprob = ODEProblem(closed_raw_eqs, u0map, tspan, pmap)
7066
sol = solve(oprob, Tsit5(), saveat=0.1)
7167

7268
plot(sol, idxs=[1,2], lw=2, legend=:bottomright)
@@ -78,8 +74,7 @@ Some dampening in the system is now visible. Increasing the expansion order to `
7874
raw_eqs = generate_raw_moment_eqs(rn, 4, combinatoric_ratelaws=true)
7975
closed_raw_eqs = moment_closure(raw_eqs, "normal")
8076

81-
u₀map = deterministic_IC(u₀, closed_raw_eqs)
82-
oprob = ODEProblem(closed_raw_eqs, u₀map, tspan, p)
77+
oprob = ODEProblem(closed_raw_eqs, u0map, tspan, pmap)
8378
sol = solve(oprob, Tsit5(), saveat=0.1)
8479

8580
plot(sol, idxs=[1,2], lw=2)
@@ -91,8 +86,7 @@ For dessert, we consider unphysical divergent trajectories—a frequent problem
9186
raw_eqs = generate_raw_moment_eqs(rn, 2, combinatoric_ratelaws=true)
9287
closed_raw_eqs = moment_closure(raw_eqs, "log-normal")
9388

94-
u₀map = deterministic_IC(u₀, closed_raw_eqs)
95-
oprob = ODEProblem(closed_raw_eqs, u₀map, tspan, p)
89+
oprob = ODEProblem(closed_raw_eqs, u0map, tspan, pmap)
9690
sol = solve(oprob, Rodas4P(), saveat=0.1)
9791

9892
plot(sol, idxs=[1,2], lw=2)
@@ -104,8 +98,7 @@ In contrast to normal closure, increasing the expansion order makes the problem
10498
raw_eqs = generate_raw_moment_eqs(rn, 3, combinatoric_ratelaws=true)
10599
closed_raw_eqs = moment_closure(raw_eqs, "log-normal")
106100

107-
u₀map = deterministic_IC(u₀, closed_raw_eqs)
108-
oprob = ODEProblem(closed_raw_eqs, u₀map, tspan, p)
101+
oprob = ODEProblem(closed_raw_eqs, u0map, tspan, pmap)
109102
sol = solve(oprob, Rodas4P(), saveat=0.1)
110103

111104
plot(sol, idxs=[1,2], lw=2)

docs/src/tutorials/derivative_matching_example.md

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ rn = @reaction_network begin
1818
end
1919

2020
# parameter values
21-
p = [:c₁ => 1.0, :c₂ => 1.0]
21+
pmap = [:c₁ => 1.0, :c₂ => 1.0]
2222
# initial conditions
23-
u0 = [20, 10]
23+
u0map = [:x₁ => 20, :x₂ => 10]
2424
# time interval to solve on
2525
tspan = (0., 0.5)
2626
```
@@ -46,8 +46,7 @@ Note that all closure functions are consistent with the ones shown in Table II o
4646
```julia
4747
using OrdinaryDiffEqTsit5
4848

49-
u0map = deterministic_IC(u0, dm2_eqs) # assuming deterministic initial conditions
50-
oprob = ODEProblem(dm2_eqs, u0map, tspan, p)
49+
oprob = ODEProblem(dm2_eqs, u0map, tspan, pmap)
5150
dm2_sol = solve(oprob, Tsit5(), saveat=0.01)
5251
```
5352
Now the question is how can we extract the time evolution of the cumulant $\kappa_{03}$. Firstly, note that using the standard moment relationships it can be expressed in terms of raw moments as:
@@ -108,8 +107,7 @@ unknowns(dm3_eqs.odes)
108107
```
109108
and solve the moment equations, computing the required cumulant:
110109
```julia
111-
u0map = deterministic_IC(u0, dm3_eqs)
112-
oprob = ODEProblem(dm3_eqs, u0map, tspan, p)
110+
oprob = ODEProblem(dm3_eqs, u0map, tspan, pmap)
113111
dm3_sol = solve(oprob, Tsit5(), saveat=0.01, abstol=1e-8, reltol=1e-8)
114112

115113
μ₀₁ = dm3_sol[2,:]
@@ -122,8 +120,7 @@ Note that we could have also obtained $\kappa_{03}$ estimate in an easier way by
122120
central_eqs3 = generate_central_moment_eqs(rn, 3)
123121
dm3_central_eqs = moment_closure(central_eqs3, "derivative matching")
124122

125-
u0map = deterministic_IC(u0, dm3_central_eqs)
126-
oprob = ODEProblem(dm3_central_eqs, u0map, tspan, p)
123+
oprob = ODEProblem(dm3_central_eqs, u0map, tspan, pmap)
127124
dm3_central_sol = solve(oprob, Tsit5(), saveat=0.01, abstol=1e-8, reltol=1e-8)
128125

129126
# check that the two estimates are equivalent
@@ -136,7 +133,7 @@ The last ingredient we need for a proper comparison between the second and third
136133
```julia
137134
using JumpProcesses
138135

139-
dprob = DiscreteProblem(rn, u0, tspan, p)
136+
dprob = DiscreteProblem(rn, u0map, tspan, pmap)
140137
jprob = JumpProblem(rn, dprob, Direct(), save_positions=(false, false))
141138

142139
ensembleprob = EnsembleProblem(jprob)

docs/src/tutorials/geometric_reactions+conditional_closures.md

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ pmap = [:k_on => k_on_val,
142142
:γ_p => γ_p_val,
143143
:b => mean_b]
144144

145-
# initial gene state and protein number, order [g, p]
146-
u₀ = [1, 1]
145+
# initial gene state and protein number
146+
u0map = [:g => 1, :P => 1]
147147

148148
# time interval to solve on
149149
tspan = (0., 6.0)
@@ -155,7 +155,7 @@ jsys = convert(JumpSystem, rn; combinatoric_ratelaws=false)
155155
jsys = complete(jsys)
156156

157157
# create a discrete problem setting the simulation parameters
158-
dprob = DiscreteProblem(u₀, tspan, pmap)
158+
dprob = DiscreteProblem(u0map, tspan, pmap)
159159

160160
# create a JumpProblem compatible with ReactionSystemMod
161161
jprob = JumpProblem(rn, dprob, Direct(), save_positions=(false, false))
@@ -174,9 +174,6 @@ We continue to solve the moment equations for each closure:
174174
plt_m = plot() # plot mean protein number
175175
plt_std = plot() # plot ssd of protein number
176176

177-
# construct the initial molecule number mapping
178-
u₀map = deterministic_IC(u₀, dm_eqs)
179-
180177
# solve moment ODEs for each closure and plot the results
181178
for closure in ["normal", "derivative matching",
182179
"conditional gaussian", "conditional derivative matching"]
@@ -185,7 +182,7 @@ for closure in ["normal", "derivative matching",
185182
closed_eqs = moment_closure(eqs, closure, binary_vars)
186183

187184
# solve the system of moment ODEs
188-
oprob = ODEProblem(closed_eqs, u₀map, tspan, pmap)
185+
oprob = ODEProblem(closed_eqs, u0map, tspan, pmap)
189186
sol = solve(oprob, AutoTsit5(Rosenbrock23()), saveat=0.01)
190187

191188
# μ₀₁ is 2nd and μ₀₂ is 4th element in sol
@@ -273,7 +270,7 @@ pmap = [:kx_on => kx_on_val,
273270
:b_y => mean_b_y]
274271

275272
# initial gene state and protein number, order [g_x, g_y, x, y]
276-
u₀ = [1, 1, 1, 1]
273+
u0map = [:g_x => 1, :g_y => 1, :x => 1, :y => 1]
277274

278275
# time interval to solve on
279276
tspan = (0., 12.0)
@@ -284,7 +281,7 @@ We can run SSA as follows:
284281
```julia
285282
jsys = convert(JumpSystem, rn, combinatoric_ratelaws=false)
286283
jsys = complete(jsys)
287-
dprob = DiscreteProblem(jsys, u₀, tspan, pmap)
284+
dprob = DiscreteProblem(jsys, u0map, tspan, pmap)
288285
jprob = JumpProblem(jsys, dprob, Direct(), save_positions=(false, false))
289286

290287
ensembleprob = EnsembleProblem(jprob)
@@ -302,9 +299,7 @@ plt_std = plot() # plot ssd of activator protein number
302299
for closure in ["derivative matching", "conditional derivative matching"]
303300

304301
closed_eqs = moment_closure(eqs, closure, binary_vars)
305-
306-
u₀map = deterministic_IC(u₀, closed_eqs)
307-
oprob = ODEProblem(closed_eqs, u₀map, tspan, pmap)
302+
oprob = ODEProblem(closed_eqs, u0map, tspan, pmap)
308303
sol = solve(oprob, Tsit5(), saveat=0.1)
309304

310305
# μ₀₀₀₁ is the 4th and μ₀₀₀₂ is the 12th element in sol (can check with closed_eqs.odes.states)

docs/src/tutorials/parameter_estimation_SDE.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ Now we approach the same model identification problem via MAs in the hope of cut
109109
using MomentClosure
110110

111111
LV_moments = moment_closure(generate_raw_moment_eqs(LV, 2), "log-normal")
112-
u0map = deterministic_IC(last.(u0), LV_moments)
113-
prob_MA = ODEProblem(LV_moments, u0map, (0.0, Tf), zeros(5))
112+
prob_MA = ODEProblem(LV_moments, u0, (0.0, Tf), zeros(5))
114113
psetter_MA! = setp(prob_MA, (γ1, γ2, γ3, γ4, γ5))
115114

116115
function obj_MA(p)

docs/src/tutorials/time-dependent_propensities.md

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,13 @@ closed_raw_eqs = moment_closure(raw_eqs, "normal")
4848
p = [:c₁ => 0.9, :c₂ => 2., :c₃ => 1., :c₄ => 1., => 5., => 1., => 40.]
4949

5050
# initial molecule numbers of species [X, Y]
51-
u₀ = [1., 1.]
52-
53-
# deterministic initial conditions
54-
u₀map = deterministic_IC(u₀, closed_raw_eqs)
51+
u0map = [1., 1.]
5552

5653
# time interval to solve one on
5754
tspan = (0., 100.)
5855

5956
# convert the closed raw moment equations into a DifferentialEquations ODEProblem
60-
oprob = ODEProblem(closed_raw_eqs, u₀map, tspan, p)
57+
oprob = ODEProblem(closed_raw_eqs, u0map, tspan, pmap)
6158

6259
# solve using Tsit5 solver
6360
sol = solve(oprob, Tsit5(), saveat=0.2)
@@ -69,7 +66,7 @@ It would be great to compare our results to the true dynamics. Using Differentia
6966
```julia
7067
using JumpProcesses
7168

72-
jinputs = JumpInputs(rn, u₀, tspan, p, combinatoric_ratelaws=false)
69+
jinputs = JumpInputs(rn, u0map, tspan, pmap, combinatoric_ratelaws=false)
7370
jprob = JumpProblem(jinputs, Direct())
7471
```
7572
Note that now we have to provide an ODE solver to `solve` in order to integrate over the time-dependent propensities.

0 commit comments

Comments
 (0)