|
| 1 | +# SimpleSplitTauLeaping Implementation Plan |
| 2 | +Date: 2025-08-25 |
| 3 | + |
| 4 | +## Overview |
| 5 | +Implementation of a new tau-leaping integrator, `SimpleSplitTauLeaping`, that uses first-order operator splitting to evaluate one jump at a time. This initial implementation focuses on MassActionJumps only, following the type-stability practices used in Direct() and other constant rate aggregators. |
| 6 | + |
| 7 | +## 1. Algorithm Definition |
| 8 | +Add to `src/simple_regular_solve.jl`: |
| 9 | +```julia |
| 10 | +struct SimpleSplitTauLeaping <: DiffEqBase.DEAlgorithm end |
| 11 | +``` |
| 12 | + |
| 13 | +## 2. Validation Function |
| 14 | +```julia |
| 15 | +function validate_massjump_splitting_inputs(jump_prob::JumpProblem, alg) |
| 16 | + if !(jump_prob.aggregator isa PureLeaping) |
| 17 | + @warn "When using $alg, please pass PureLeaping() as the aggregator..." |
| 18 | + end |
| 19 | + # Only MassActionJumps allowed |
| 20 | + isempty(jump_prob.jump_callback.continuous_callbacks) && |
| 21 | + isempty(jump_prob.jump_callback.discrete_callbacks) && |
| 22 | + isempty(jump_prob.constant_jumps) && |
| 23 | + isempty(jump_prob.variable_jumps) && |
| 24 | + jump_prob.regular_jump === nothing && |
| 25 | + get_num_majumps(jump_prob.massaction_jump) > 0 |
| 26 | +end |
| 27 | +``` |
| 28 | + |
| 29 | +## 3. Core Implementation (Allocation-Free) |
| 30 | +```julia |
| 31 | +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleSplitTauLeaping; |
| 32 | + seed = nothing, |
| 33 | + dt = error("dt is required for SimpleSplitTauLeaping.")) |
| 34 | + |
| 35 | + validate_massjump_splitting_inputs(jump_prob, alg) || |
| 36 | + error("SimpleSplitTauLeaping currently only supports MassActionJumps with PureLeaping") |
| 37 | + |
| 38 | + prob = jump_prob.prob |
| 39 | + rng = DEFAULT_RNG |
| 40 | + (seed !== nothing) && seed!(rng, seed) |
| 41 | + |
| 42 | + # Extract MassActionJumps |
| 43 | + ma_jumps = jump_prob.massaction_jump |
| 44 | + num_jumps = get_num_majumps(ma_jumps) |
| 45 | + |
| 46 | + # Pre-allocate |
| 47 | + u0 = copy(prob.u0) |
| 48 | + u_work = similar(u0) # Working state vector |
| 49 | + |
| 50 | + tspan = prob.tspan |
| 51 | + p = prob.p |
| 52 | + n = Int((tspan[2] - tspan[1]) / dt) + 1 |
| 53 | + u = Vector{typeof(u0)}(undef, n) |
| 54 | + u[1] = u0 |
| 55 | + t = tspan[1]:dt:tspan[2] |
| 56 | + |
| 57 | + # Main loop - operator splitting with individual jump execution |
| 58 | + for i in 2:n |
| 59 | + copy!(u_work, u[i-1]) |
| 60 | + |
| 61 | + # Split tau-leaping: evaluate and execute each jump separately |
| 62 | + @inbounds for j in 1:num_jumps |
| 63 | + rate = evalrxrate(u_work, j, ma_jumps) |
| 64 | + num_firings = pois_rand(rng, rate * dt) |
| 65 | + |
| 66 | + # Execute this jump num_firings times |
| 67 | + for _ in 1:num_firings |
| 68 | + executerx!(u_work, j, ma_jumps) |
| 69 | + end |
| 70 | + end |
| 71 | + |
| 72 | + u[i] = copy(u_work) |
| 73 | + end |
| 74 | + |
| 75 | + sol = DiffEqBase.build_solution(prob, alg, t, u, |
| 76 | + calculate_error = false, |
| 77 | + interp = DiffEqBase.ConstantInterpolation(t, u)) |
| 78 | +end |
| 79 | +``` |
| 80 | + |
| 81 | +## 4. Test Implementation |
| 82 | +Create comprehensive tests with 5% accuracy target: |
| 83 | + |
| 84 | +```julia |
| 85 | +@testset "SimpleSplitTauLeaping MassActionJump Tests" begin |
| 86 | + using Statistics, JumpProcesses, DiffEqBase, Random |
| 87 | + |
| 88 | + # Test 1: Birth-death process via MassActionJumps |
| 89 | + @testset "Birth-Death MassActionJump" begin |
| 90 | + # ∅ → X (birth), X → ∅ (death) |
| 91 | + reactant_stoich = [Vector{Pair{Int,Int}}(), # ∅ → X |
| 92 | + [1 => 1]] # X → ∅ |
| 93 | + net_stoich = [[1 => 1], # ∅ → X adds one |
| 94 | + [1 => -1]] # X → ∅ removes one |
| 95 | + rates = [1.0, 0.1] # birth rate, death rate |
| 96 | + |
| 97 | + ma_jumps = MassActionJump(rates, reactant_stoich, net_stoich) |
| 98 | + |
| 99 | + u0 = [10] |
| 100 | + tspan = (0.0, 100.0) |
| 101 | + dprob = DiscreteProblem(u0, tspan) |
| 102 | + |
| 103 | + # Run ensembles for statistics |
| 104 | + n_traj = 10000 |
| 105 | + |
| 106 | + # Direct method reference |
| 107 | + jprob_direct = JumpProblem(dprob, Direct(), ma_jumps) |
| 108 | + ensembleprob_direct = EnsembleProblem(jprob_direct) |
| 109 | + sol_direct = solve(ensembleprob_direct, SSAStepper(), |
| 110 | + EnsembleThreads(), trajectories=n_traj) |
| 111 | + |
| 112 | + # SimpleSplitTauLeaping with small dt |
| 113 | + jprob_split = JumpProblem(dprob, PureLeaping(), ma_jumps) |
| 114 | + |
| 115 | + function run_split_ensemble(prob, n_traj, dt, seed) |
| 116 | + sols = Vector{Any}(undef, n_traj) |
| 117 | + for i in 1:n_traj |
| 118 | + sols[i] = solve(prob, SimpleSplitTauLeaping(), |
| 119 | + dt=dt, seed=seed+i) |
| 120 | + end |
| 121 | + return sols |
| 122 | + end |
| 123 | + |
| 124 | + sol_split = run_split_ensemble(jprob_split, n_traj, 0.001, 12345) |
| 125 | + |
| 126 | + # Extract final values |
| 127 | + direct_final = [sol.u[end][1] for sol in sol_direct] |
| 128 | + split_final = [sol.u[end][1] for sol in sol_split] |
| 129 | + |
| 130 | + # Test mean and variance (5% relative accuracy) |
| 131 | + @test mean(split_final) ≈ mean(direct_final) rtol=0.05 |
| 132 | + @test var(split_final) ≈ var(direct_final) rtol=0.05 |
| 133 | + end |
| 134 | + |
| 135 | + # Test 2: Simple reaction A + B → C |
| 136 | + @testset "A + B → C Reaction" begin |
| 137 | + reactant_stoich = [[1 => 1, 2 => 1]] # A + B |
| 138 | + net_stoich = [[1 => -1, 2 => -1, 3 => 1]] # -A -B +C |
| 139 | + rates = [0.001] |
| 140 | + |
| 141 | + ma_jumps = MassActionJump(rates, reactant_stoich, net_stoich) |
| 142 | + |
| 143 | + u0 = [100, 100, 0] |
| 144 | + tspan = (0.0, 10.0) |
| 145 | + dprob = DiscreteProblem(u0, tspan) |
| 146 | + |
| 147 | + n_traj = 5000 |
| 148 | + |
| 149 | + # Direct reference |
| 150 | + jprob_direct = JumpProblem(dprob, Direct(), ma_jumps) |
| 151 | + ensembleprob_direct = EnsembleProblem(jprob_direct) |
| 152 | + sol_direct = solve(ensembleprob_direct, SSAStepper(), |
| 153 | + EnsembleThreads(), trajectories=n_traj) |
| 154 | + |
| 155 | + # SimpleSplitTauLeaping |
| 156 | + jprob_split = JumpProblem(dprob, PureLeaping(), ma_jumps) |
| 157 | + sol_split = run_split_ensemble(jprob_split, n_traj, 0.0001, 54321) |
| 158 | + |
| 159 | + # Compare means of all species at final time |
| 160 | + for species in 1:3 |
| 161 | + direct_vals = [sol.u[end][species] for sol in sol_direct] |
| 162 | + split_vals = [sol.u[end][species] for sol in sol_split] |
| 163 | + |
| 164 | + @test mean(split_vals) ≈ mean(direct_vals) rtol=0.05 |
| 165 | + @test var(split_vals) ≈ var(direct_vals) rtol=0.05 |
| 166 | + end |
| 167 | + |
| 168 | + # Check conservation |
| 169 | + for sol in sol_split |
| 170 | + @test sum(sol.u[end]) == sum(u0) |
| 171 | + end |
| 172 | + end |
| 173 | + |
| 174 | + # Test 3: Lotka-Volterra predator-prey |
| 175 | + @testset "Lotka-Volterra System" begin |
| 176 | + # X → 2X (prey birth) |
| 177 | + # X + Y → 2Y (predation) |
| 178 | + # Y → ∅ (predator death) |
| 179 | + reactant_stoich = [[1 => 1], # X |
| 180 | + [1 => 1, 2 => 1], # X + Y |
| 181 | + [2 => 1]] # Y |
| 182 | + net_stoich = [[1 => 1], # X births |
| 183 | + [1 => -1, 2 => 1], # X dies, Y births |
| 184 | + [2 => -1]] # Y dies |
| 185 | + rates = [1.0, 0.001, 1.0] |
| 186 | + |
| 187 | + ma_jumps = MassActionJump(rates, reactant_stoich, net_stoich) |
| 188 | + |
| 189 | + u0 = [100, 100] # Initial prey and predator |
| 190 | + tspan = (0.0, 20.0) |
| 191 | + dprob = DiscreteProblem(u0, tspan) |
| 192 | + |
| 193 | + n_traj = 5000 |
| 194 | + |
| 195 | + # Direct |
| 196 | + jprob_direct = JumpProblem(dprob, Direct(), ma_jumps) |
| 197 | + ensembleprob_direct = EnsembleProblem(jprob_direct) |
| 198 | + sol_direct = solve(ensembleprob_direct, SSAStepper(), |
| 199 | + EnsembleThreads(), trajectories=n_traj) |
| 200 | + |
| 201 | + # SimpleSplitTauLeaping with very small dt for accuracy |
| 202 | + jprob_split = JumpProblem(dprob, PureLeaping(), ma_jumps) |
| 203 | + sol_split = run_split_ensemble(jprob_split, n_traj, 0.0001, 99999) |
| 204 | + |
| 205 | + # Sample at multiple time points |
| 206 | + test_times = [5.0, 10.0, 15.0, 20.0] |
| 207 | + for test_t in test_times |
| 208 | + # Find closest time index |
| 209 | + t_idx_direct = findfirst(t -> t >= test_t, sol_direct[1].t) |
| 210 | + t_idx_split = findfirst(t -> t >= test_t, sol_split[1].t) |
| 211 | + |
| 212 | + for species in 1:2 |
| 213 | + direct_vals = [sol.u[t_idx_direct][species] for sol in sol_direct] |
| 214 | + split_vals = [sol.u[t_idx_split][species] for sol in sol_split] |
| 215 | + |
| 216 | + @test mean(split_vals) ≈ mean(direct_vals) rtol=0.05 |
| 217 | + end |
| 218 | + end |
| 219 | + end |
| 220 | +end |
| 221 | +``` |
| 222 | + |
| 223 | +## 5. Key Simplifications |
| 224 | + |
| 225 | +- **MassActionJumps only**: No need for FunctionWrappers or type dispatch |
| 226 | +- **Direct rate evaluation**: Use existing `evalrxrate` and `executerx!` |
| 227 | +- **Minimal allocations**: Only `u_work` vector for in-place updates |
| 228 | +- **Simple validation**: Check only for MassActionJumps presence |
| 229 | + |
| 230 | +## 6. Performance Notes |
| 231 | + |
| 232 | +- Type stable since only one jump type |
| 233 | +- Cache-friendly sequential access |
| 234 | +- Minimal branching in inner loop |
| 235 | +- In-place operations throughout |
| 236 | + |
| 237 | +## 7. Implementation Strategy |
| 238 | + |
| 239 | +1. Add the struct and solve method to `src/simple_regular_solve.jl` |
| 240 | +2. Add export in `src/JumpProcesses.jl` |
| 241 | +3. Create test file or add to existing test suite |
| 242 | +4. Verify 5% accuracy requirement is met across different models |
| 243 | +5. Document the method's operator splitting approach |
| 244 | + |
| 245 | +## 8. Future Extensions |
| 246 | + |
| 247 | +Once the basic MassActionJump implementation is working: |
| 248 | +- Add support for ConstantRateJumps using FunctionWrappers |
| 249 | +- Add support for VariableRateJumps |
| 250 | +- Consider adaptive time-stepping |
| 251 | +- Optimize for specific reaction network structures |
0 commit comments