Skip to content

Commit 5c11fbc

Browse files
committed
using JumpProblem rng
1 parent 782c430 commit 5c11fbc

File tree

3 files changed

+262
-12
lines changed

3 files changed

+262
-12
lines changed
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# SimpleSplitTauLeaping Implementation Plan
2+
Date: 2025-08-25
3+
4+
## Overview
5+
Implementation of a new tau-leaping integrator, `SimpleSplitTauLeaping`, that uses first-order operator splitting to evaluate one jump at a time. This initial implementation focuses on MassActionJumps only, following the type-stability practices used in Direct() and other constant rate aggregators.
6+
7+
## 1. Algorithm Definition
8+
Add to `src/simple_regular_solve.jl`:
9+
```julia
10+
struct SimpleSplitTauLeaping <: DiffEqBase.DEAlgorithm end
11+
```
12+
13+
## 2. Validation Function
14+
```julia
15+
function validate_massjump_splitting_inputs(jump_prob::JumpProblem, alg)
16+
if !(jump_prob.aggregator isa PureLeaping)
17+
@warn "When using $alg, please pass PureLeaping() as the aggregator..."
18+
end
19+
# Only MassActionJumps allowed
20+
isempty(jump_prob.jump_callback.continuous_callbacks) &&
21+
isempty(jump_prob.jump_callback.discrete_callbacks) &&
22+
isempty(jump_prob.constant_jumps) &&
23+
isempty(jump_prob.variable_jumps) &&
24+
jump_prob.regular_jump === nothing &&
25+
get_num_majumps(jump_prob.massaction_jump) > 0
26+
end
27+
```
28+
29+
## 3. Core Implementation (Allocation-Free)
30+
```julia
31+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleSplitTauLeaping;
32+
seed = nothing,
33+
dt = error("dt is required for SimpleSplitTauLeaping."))
34+
35+
validate_massjump_splitting_inputs(jump_prob, alg) ||
36+
error("SimpleSplitTauLeaping currently only supports MassActionJumps with PureLeaping")
37+
38+
prob = jump_prob.prob
39+
rng = DEFAULT_RNG
40+
(seed !== nothing) && seed!(rng, seed)
41+
42+
# Extract MassActionJumps
43+
ma_jumps = jump_prob.massaction_jump
44+
num_jumps = get_num_majumps(ma_jumps)
45+
46+
# Pre-allocate
47+
u0 = copy(prob.u0)
48+
u_work = similar(u0) # Working state vector
49+
50+
tspan = prob.tspan
51+
p = prob.p
52+
n = Int((tspan[2] - tspan[1]) / dt) + 1
53+
u = Vector{typeof(u0)}(undef, n)
54+
u[1] = u0
55+
t = tspan[1]:dt:tspan[2]
56+
57+
# Main loop - operator splitting with individual jump execution
58+
for i in 2:n
59+
copy!(u_work, u[i-1])
60+
61+
# Split tau-leaping: evaluate and execute each jump separately
62+
@inbounds for j in 1:num_jumps
63+
rate = evalrxrate(u_work, j, ma_jumps)
64+
num_firings = pois_rand(rng, rate * dt)
65+
66+
# Execute this jump num_firings times
67+
for _ in 1:num_firings
68+
executerx!(u_work, j, ma_jumps)
69+
end
70+
end
71+
72+
u[i] = copy(u_work)
73+
end
74+
75+
sol = DiffEqBase.build_solution(prob, alg, t, u,
76+
calculate_error = false,
77+
interp = DiffEqBase.ConstantInterpolation(t, u))
78+
end
79+
```
80+
81+
## 4. Test Implementation
82+
Create comprehensive tests with 5% accuracy target:
83+
84+
```julia
85+
@testset "SimpleSplitTauLeaping MassActionJump Tests" begin
86+
using Statistics, JumpProcesses, DiffEqBase, Random
87+
88+
# Test 1: Birth-death process via MassActionJumps
89+
@testset "Birth-Death MassActionJump" begin
90+
# ∅ → X (birth), X → ∅ (death)
91+
reactant_stoich = [Vector{Pair{Int,Int}}(), # ∅ → X
92+
[1 => 1]] # X → ∅
93+
net_stoich = [[1 => 1], # ∅ → X adds one
94+
[1 => -1]] # X → ∅ removes one
95+
rates = [1.0, 0.1] # birth rate, death rate
96+
97+
ma_jumps = MassActionJump(rates, reactant_stoich, net_stoich)
98+
99+
u0 = [10]
100+
tspan = (0.0, 100.0)
101+
dprob = DiscreteProblem(u0, tspan)
102+
103+
# Run ensembles for statistics
104+
n_traj = 10000
105+
106+
# Direct method reference
107+
jprob_direct = JumpProblem(dprob, Direct(), ma_jumps)
108+
ensembleprob_direct = EnsembleProblem(jprob_direct)
109+
sol_direct = solve(ensembleprob_direct, SSAStepper(),
110+
EnsembleThreads(), trajectories=n_traj)
111+
112+
# SimpleSplitTauLeaping with small dt
113+
jprob_split = JumpProblem(dprob, PureLeaping(), ma_jumps)
114+
115+
function run_split_ensemble(prob, n_traj, dt, seed)
116+
sols = Vector{Any}(undef, n_traj)
117+
for i in 1:n_traj
118+
sols[i] = solve(prob, SimpleSplitTauLeaping(),
119+
dt=dt, seed=seed+i)
120+
end
121+
return sols
122+
end
123+
124+
sol_split = run_split_ensemble(jprob_split, n_traj, 0.001, 12345)
125+
126+
# Extract final values
127+
direct_final = [sol.u[end][1] for sol in sol_direct]
128+
split_final = [sol.u[end][1] for sol in sol_split]
129+
130+
# Test mean and variance (5% relative accuracy)
131+
@test mean(split_final) mean(direct_final) rtol=0.05
132+
@test var(split_final) var(direct_final) rtol=0.05
133+
end
134+
135+
# Test 2: Simple reaction A + B → C
136+
@testset "A + B → C Reaction" begin
137+
reactant_stoich = [[1 => 1, 2 => 1]] # A + B
138+
net_stoich = [[1 => -1, 2 => -1, 3 => 1]] # -A -B +C
139+
rates = [0.001]
140+
141+
ma_jumps = MassActionJump(rates, reactant_stoich, net_stoich)
142+
143+
u0 = [100, 100, 0]
144+
tspan = (0.0, 10.0)
145+
dprob = DiscreteProblem(u0, tspan)
146+
147+
n_traj = 5000
148+
149+
# Direct reference
150+
jprob_direct = JumpProblem(dprob, Direct(), ma_jumps)
151+
ensembleprob_direct = EnsembleProblem(jprob_direct)
152+
sol_direct = solve(ensembleprob_direct, SSAStepper(),
153+
EnsembleThreads(), trajectories=n_traj)
154+
155+
# SimpleSplitTauLeaping
156+
jprob_split = JumpProblem(dprob, PureLeaping(), ma_jumps)
157+
sol_split = run_split_ensemble(jprob_split, n_traj, 0.0001, 54321)
158+
159+
# Compare means of all species at final time
160+
for species in 1:3
161+
direct_vals = [sol.u[end][species] for sol in sol_direct]
162+
split_vals = [sol.u[end][species] for sol in sol_split]
163+
164+
@test mean(split_vals) mean(direct_vals) rtol=0.05
165+
@test var(split_vals) var(direct_vals) rtol=0.05
166+
end
167+
168+
# Check conservation
169+
for sol in sol_split
170+
@test sum(sol.u[end]) == sum(u0)
171+
end
172+
end
173+
174+
# Test 3: Lotka-Volterra predator-prey
175+
@testset "Lotka-Volterra System" begin
176+
# X → 2X (prey birth)
177+
# X + Y → 2Y (predation)
178+
# Y → ∅ (predator death)
179+
reactant_stoich = [[1 => 1], # X
180+
[1 => 1, 2 => 1], # X + Y
181+
[2 => 1]] # Y
182+
net_stoich = [[1 => 1], # X births
183+
[1 => -1, 2 => 1], # X dies, Y births
184+
[2 => -1]] # Y dies
185+
rates = [1.0, 0.001, 1.0]
186+
187+
ma_jumps = MassActionJump(rates, reactant_stoich, net_stoich)
188+
189+
u0 = [100, 100] # Initial prey and predator
190+
tspan = (0.0, 20.0)
191+
dprob = DiscreteProblem(u0, tspan)
192+
193+
n_traj = 5000
194+
195+
# Direct
196+
jprob_direct = JumpProblem(dprob, Direct(), ma_jumps)
197+
ensembleprob_direct = EnsembleProblem(jprob_direct)
198+
sol_direct = solve(ensembleprob_direct, SSAStepper(),
199+
EnsembleThreads(), trajectories=n_traj)
200+
201+
# SimpleSplitTauLeaping with very small dt for accuracy
202+
jprob_split = JumpProblem(dprob, PureLeaping(), ma_jumps)
203+
sol_split = run_split_ensemble(jprob_split, n_traj, 0.0001, 99999)
204+
205+
# Sample at multiple time points
206+
test_times = [5.0, 10.0, 15.0, 20.0]
207+
for test_t in test_times
208+
# Find closest time index
209+
t_idx_direct = findfirst(t -> t >= test_t, sol_direct[1].t)
210+
t_idx_split = findfirst(t -> t >= test_t, sol_split[1].t)
211+
212+
for species in 1:2
213+
direct_vals = [sol.u[t_idx_direct][species] for sol in sol_direct]
214+
split_vals = [sol.u[t_idx_split][species] for sol in sol_split]
215+
216+
@test mean(split_vals) mean(direct_vals) rtol=0.05
217+
end
218+
end
219+
end
220+
end
221+
```
222+
223+
## 5. Key Simplifications
224+
225+
- **MassActionJumps only**: No need for FunctionWrappers or type dispatch
226+
- **Direct rate evaluation**: Use existing `evalrxrate` and `executerx!`
227+
- **Minimal allocations**: Only `u_work` vector for in-place updates
228+
- **Simple validation**: Check only for MassActionJumps presence
229+
230+
## 6. Performance Notes
231+
232+
- Type stable since only one jump type
233+
- Cache-friendly sequential access
234+
- Minimal branching in inner loop
235+
- In-place operations throughout
236+
237+
## 7. Implementation Strategy
238+
239+
1. Add the struct and solve method to `src/simple_regular_solve.jl`
240+
2. Add export in `src/JumpProcesses.jl`
241+
3. Create test file or add to existing test suite
242+
4. Verify 5% accuracy requirement is met across different models
243+
5. Document the method's operator splitting approach
244+
245+
## 8. Future Extensions
246+
247+
Once the basic MassActionJump implementation is working:
248+
- Add support for ConstantRateJumps using FunctionWrappers
249+
- Add support for VariableRateJumps
250+
- Consider adaptive time-stepping
251+
- Optimize for specific reaction network structures

src/simple_regular_solve.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@ function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg)
1515
end
1616

1717
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
18-
seed = nothing,
19-
dt = error("dt is required for SimpleTauLeaping."))
18+
seed = nothing, dt = error("dt is required for SimpleTauLeaping."))
2019
validate_pure_leaping_inputs(jump_prob, alg) ||
2120
error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only non-RegularJumps.")
22-
prob = jump_prob.prob
23-
rng = DEFAULT_RNG
21+
22+
@unpack prob, rng = jump_prob
2423
(seed !== nothing) && seed!(rng, seed)
2524

2625
rj = jump_prob.regular_jump

test/regular_jumps.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
3939
maj = MassActionJump(rates, reactant_stoich, net_stoich)
4040

4141
# Test PureLeaping JumpProblem creation
42-
jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj))
42+
jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj); rng)
4343
@test jp_pure.aggregator isa PureLeaping
4444
@test jp_pure.discrete_jump_aggregation === nothing
4545
@test jp_pure.massaction_jump !== nothing
@@ -50,7 +50,7 @@ sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
5050
affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1)
5151
crj = ConstantRateJump(rate, affect!)
5252

53-
jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj))
53+
jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj); rng)
5454
@test jp_pure_crj.aggregator isa PureLeaping
5555
@test jp_pure_crj.discrete_jump_aggregation === nothing
5656
@test length(jp_pure_crj.constant_jumps) == 1
@@ -60,7 +60,7 @@ sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
6060
vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1)
6161
vrj = VariableRateJump(vrate, vaffect!)
6262

63-
jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj))
63+
jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj); rng)
6464
@test jp_pure_vrj.aggregator isa PureLeaping
6565
@test jp_pure_vrj.discrete_jump_aggregation === nothing
6666
@test length(jp_pure_vrj.variable_jumps) == 1
@@ -80,15 +80,15 @@ sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
8080

8181
regj = RegularJump(rj_rate, rj_c, 1)
8282

83-
jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj))
83+
jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj); rng)
8484
@test jp_pure_regj.aggregator isa PureLeaping
8585
@test jp_pure_regj.discrete_jump_aggregation === nothing
8686
@test jp_pure_regj.regular_jump !== nothing
8787

8888
# Test mixed jump types
8989
mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,),
9090
variable_jumps = (vrj,), regular_jumps = regj)
91-
jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps)
91+
jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps; rng)
9292
@test jp_pure_mixed.aggregator isa PureLeaping
9393
@test jp_pure_mixed.discrete_jump_aggregation === nothing
9494
@test jp_pure_mixed.massaction_jump !== nothing
@@ -99,14 +99,14 @@ sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
9999
# Test spatial system error
100100
spatial_sys = CartesianGrid((2, 2))
101101
hopping_consts = [1.0]
102-
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj);
102+
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng,
103103
spatial_system = spatial_sys)
104-
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj);
104+
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng,
105105
hopping_constants = hopping_consts)
106106

107107
# Test MassActionJump with parameter mapping
108108
maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2])
109-
jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params))
109+
jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params); rng)
110110
scaled_rates = [p[1], p[2]/2]
111111
@test jp_params.massaction_jump.scaled_rates == scaled_rates
112112
end

0 commit comments

Comments
 (0)