Skip to content

Commit 2df7477

Browse files
using maj for adaptive tauleaping
1 parent 38a5c3d commit 2df7477

File tree

2 files changed

+161
-52
lines changed

2 files changed

+161
-52
lines changed

src/simple_regular_solve.jl

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ end
6868

6969
SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)
7070

71+
function compute_hor(nu)
72+
hor = zeros(Int, size(nu, 2))
73+
for j in 1:size(nu, 2)
74+
hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1
75+
end
76+
return hor
77+
end
78+
7179
function compute_gi(u, nu, hor, i)
7280
max_order = 1.0
7381
for j in 1:size(nu, 2)
@@ -101,16 +109,21 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
101109
seed = nothing,
102110
dtmin = 1e-10,
103111
saveat = nothing)
104-
validate_pure_leaping_inputs(jump_prob, alg) ||
105-
error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only non-RegularJumps.")
112+
if jump_prob.massaction_jump === nothing
113+
error("SimpleAdaptiveTauLeaping requires a JumpProblem with a MassActionJump.")
114+
end
106115
prob = jump_prob.prob
107116
rng = DEFAULT_RNG
108117
(seed !== nothing) && seed!(rng, seed)
109118

110-
rj = jump_prob.regular_jump
111-
rate = rj.rate
112-
numjumps = rj.numjumps
113-
c = rj.c
119+
maj = jump_prob.massaction_jump
120+
numjumps = get_num_majumps(maj)
121+
# Extract rates
122+
rate = (out, u, p, t) -> begin
123+
for j in 1:get_num_majumps(maj)
124+
out[j] = evalrxrate(u, j, maj)
125+
end
126+
end
114127
u0 = copy(prob.u0)
115128
tspan = prob.tspan
116129
p = prob.p
@@ -123,33 +136,21 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
123136
t_end = tspan[2]
124137
epsilon = alg.epsilon
125138

126-
# Compute initial stoichiometry and HOR
139+
# Extract stoichiometry once from MassActionJump
127140
nu = zeros(Int, length(u0), numjumps)
128-
counts_temp = zeros(Int, numjumps)
129141
for j in 1:numjumps
130-
fill!(counts_temp, 0)
131-
counts_temp[j] = 1
132-
c(du, u0, p, t[1], counts_temp, nothing)
133-
nu[:, j] = du
134-
end
135-
hor = zeros(Int, size(nu, 2))
136-
for j in 1:size(nu, 2)
137-
hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1
142+
for (spec_idx, stoich) in maj.net_stoch[j]
143+
nu[spec_idx, j] = stoich
144+
end
138145
end
146+
hor = compute_hor(nu)
139147

140148
saveat_times = isnothing(saveat) ? Vector{typeof(t)}() : saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat)
141149
save_idx = 1
142150

143151
while t[end] < t_end
144152
u_prev = u[end]
145153
t_prev = t[end]
146-
# Recompute stoichiometry
147-
for j in 1:numjumps
148-
fill!(counts_temp, 0)
149-
counts_temp[j] = 1
150-
c(du, u_prev, p, t_prev, counts_temp, nothing)
151-
nu[:, j] = du
152-
end
153154
rate(rate_cache, u_prev, p, t_prev)
154155
tau = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin)
155156
tau = min(tau, t_end - t_prev)
@@ -159,7 +160,12 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
159160
end
160161
end
161162
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
162-
c(du, u_prev, p, t_prev, counts, nothing)
163+
du .= 0
164+
for j in 1:numjumps
165+
for (spec_idx, stoich) in maj.net_stoch[j]
166+
du[spec_idx] += stoich * counts[j]
167+
end
168+
end
163169
u_new = u_prev + du
164170
if any(<(0), u_new)
165171
# Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)

test/regular_jumps.jl

Lines changed: 131 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Nsims = 1000
1515
# ConstantRateJump formulation for SSAStepper
1616
rate1(u, p, t) = p[1] * u[1] * u[2] # β*S*I (infection)
1717
rate2(u, p, t) = p[2] * u[2] # ν*I (recovery)
18-
rate3(u, p, t) = p[3] # influx_rate
18+
rate3(u, p, t) = p[3] # influx_rate (S influx)
1919
affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing)
2020
affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing)
2121
affect3!(integrator) = (integrator.u[1] += 1; nothing)
@@ -24,41 +24,49 @@ Nsims = 1000
2424
u0 = [999.0, 10.0, 0.0] # S, I, R
2525
tspan = (0.0, 250.0)
2626
prob_disc = DiscreteProblem(u0, tspan, p)
27-
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng)
27+
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=rng)
2828

2929
# Solve with SSAStepper
30-
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims)
30+
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)
3131

32-
# RegularJump formulation for TauLeaping methods
32+
# RegularJump formulation for SimpleTauLeaping
3333
regular_rate = (out, u, p, t) -> begin
3434
out[1] = p[1] * u[1] * u[2]
3535
out[2] = p[2] * u[2]
3636
out[3] = p[3]
3737
end
3838
regular_c = (dc, u, p, t, counts, mark) -> begin
39-
dc .= 0.0
39+
dc .= 0
4040
dc[1] = -counts[1] + counts[3]
4141
dc[2] = counts[1] - counts[2]
4242
dc[3] = counts[2]
4343
end
4444
rj = RegularJump(regular_rate, regular_c, 3)
45-
jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng = rng)
45+
jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng=rng)
4646

47-
# Solve with SimpleTauLeaping (dt=0.1)
47+
# Solve with SimpleTauLeaping
4848
sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1)
49-
49+
50+
# MassActionJump formulation for SimpleAdaptiveTauLeaping
51+
reactant_stoich = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]]
52+
net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [1=>1]]
53+
param_idxs = [1, 2, 3]
54+
maj = MassActionJump(reactant_stoich, net_stoich; param_idxs=param_idxs)
55+
jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng=rng)
56+
5057
# Solve with SimpleAdaptiveTauLeaping
51-
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat = 1.0)
58+
sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)
5259

53-
# Compute mean trajectories at t = 0, 1, ..., 250
60+
# Compute mean infected (I) trajectories
5461
t_points = 0:1.0:250.0
55-
mean_direct_S = [mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points]
56-
mean_simple_S = [mean(sol_simple[i](t)[2] for i in 1:Nsims) for t in t_points]
57-
mean_adaptive_S = [mean(sol_adaptive[i](t)[2] for i in 1:Nsims) for t in t_points]
62+
mean_direct_I = [mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points]
63+
mean_simple_I = [mean(sol_simple[i](t)[2] for i in 1:Nsims) for t in t_points]
64+
mean_adaptive_I = [mean(sol_adaptive[i](t)[2] for i in 1:Nsims) for t in t_points]
5865

66+
# Test mean infected trajectories
5967
for i in 1:251
60-
@test isapprox(mean_direct_S[i], mean_simple_S[i], rtol=0.10)
61-
@test isapprox(mean_direct_S[i], mean_adaptive_S[i], rtol=0.10)
68+
@test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.10)
69+
@test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.10)
6270
end
6371
end
6472

@@ -81,12 +89,12 @@ end
8189
u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R
8290
tspan = (0.0, 250.0)
8391
prob_disc = DiscreteProblem(u0, tspan, p)
84-
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng)
92+
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=rng)
8593

8694
# Solve with SSAStepper
87-
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims)
95+
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)
8896

89-
# RegularJump formulation for TauLeaping methods
97+
# RegularJump formulation for SimpleTauLeaping
9098
regular_rate = (out, u, p, t) -> begin
9199
out[1] = p[1] * u[1] * u[3]
92100
out[2] = p[2] * u[2]
@@ -100,22 +108,117 @@ end
100108
dc[4] = counts[3]
101109
end
102110
rj = RegularJump(regular_rate, regular_c, 3)
103-
jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng = rng)
111+
jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng=rng)
104112

105-
# Solve with SimpleTauLeaping (dt=0.1)
113+
# Solve with SimpleTauLeaping
106114
sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1)
107-
115+
116+
# MassActionJump formulation for SimpleAdaptiveTauLeaping
117+
reactant_stoich = [[1=>1, 3=>1], [2=>1], [3=>1]]
118+
net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [3=>-1, 4=>1]]
119+
param_idxs = [1, 2, 3]
120+
maj = MassActionJump(reactant_stoich, net_stoich; param_idxs=param_idxs)
121+
jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng=rng)
122+
108123
# Solve with SimpleAdaptiveTauLeaping
109-
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat = 1.0)
124+
sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)
110125

111-
# Compute mean trajectories at t = 0, 1, ..., 250
126+
# Compute mean infected (I) trajectories
112127
t_points = 0:1.0:250.0
113-
mean_direct_S = [mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points]
114-
mean_simple_S = [mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_points]
115-
mean_adaptive_S = [mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points]
128+
mean_direct_I = [mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points]
129+
mean_simple_I = [mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_points]
130+
mean_adaptive_I = [mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points]
116131

132+
# Test mean infected trajectories
117133
for i in 1:251
118-
@test isapprox(mean_direct_S[i], mean_simple_S[i], rtol=0.10)
119-
@test isapprox(mean_direct_S[i], mean_adaptive_S[i], rtol=0.10)
134+
@test isapprox(mean_direct_I[i], mean_simple_I[i], rtol=0.10)
135+
@test isapprox(mean_direct_I[i], mean_adaptive_I[i], rtol=0.10)
136+
end
137+
end
138+
139+
# Test PureLeaping aggregator functionality
140+
@testset "PureLeaping Aggregator Tests" begin
141+
# Test with MassActionJump
142+
u0 = [10, 5, 0]
143+
tspan = (0.0, 10.0)
144+
p = [0.1, 0.2]
145+
prob = DiscreteProblem(u0, tspan, p)
146+
147+
# Create MassActionJump
148+
reactant_stoich = [[1 => 1], [1 => 2]]
149+
net_stoich = [[1 => -1, 2 => 1], [1 => -2, 3 => 1]]
150+
rates = [0.1, 0.05]
151+
maj = MassActionJump(rates, reactant_stoich, net_stoich)
152+
153+
# Test PureLeaping JumpProblem creation
154+
jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj))
155+
@test jp_pure.aggregator isa PureLeaping
156+
@test jp_pure.discrete_jump_aggregation === nothing
157+
@test jp_pure.massaction_jump !== nothing
158+
@test length(jp_pure.jump_callback.discrete_callbacks) == 0
159+
160+
# Test with ConstantRateJump
161+
rate(u, p, t) = p[1] * u[1]
162+
affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1)
163+
crj = ConstantRateJump(rate, affect!)
164+
165+
jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj))
166+
@test jp_pure_crj.aggregator isa PureLeaping
167+
@test jp_pure_crj.discrete_jump_aggregation === nothing
168+
@test length(jp_pure_crj.constant_jumps) == 1
169+
170+
# Test with VariableRateJump
171+
vrate(u, p, t) = t * p[1] * u[1]
172+
vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1)
173+
vrj = VariableRateJump(vrate, vaffect!)
174+
175+
jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj))
176+
@test jp_pure_vrj.aggregator isa PureLeaping
177+
@test jp_pure_vrj.discrete_jump_aggregation === nothing
178+
@test length(jp_pure_vrj.variable_jumps) == 1
179+
180+
# Test with RegularJump
181+
function rj_rate(out, u, p, t)
182+
out[1] = p[1] * u[1]
183+
end
184+
185+
rj_dc = zeros(3, 1)
186+
rj_dc[1, 1] = -1
187+
rj_dc[3, 1] = 1
188+
189+
function rj_c(du, u, p, t, counts, mark)
190+
mul!(du, rj_dc, counts)
120191
end
192+
193+
regj = RegularJump(rj_rate, rj_c, 1)
194+
195+
jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj))
196+
@test jp_pure_regj.aggregator isa PureLeaping
197+
@test jp_pure_regj.discrete_jump_aggregation === nothing
198+
@test jp_pure_regj.regular_jump !== nothing
199+
200+
# Test mixed jump types
201+
mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,),
202+
variable_jumps = (vrj,), regular_jumps = regj)
203+
jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps)
204+
@test jp_pure_mixed.aggregator isa PureLeaping
205+
@test jp_pure_mixed.discrete_jump_aggregation === nothing
206+
@test jp_pure_mixed.massaction_jump !== nothing
207+
@test length(jp_pure_mixed.constant_jumps) == 1
208+
@test length(jp_pure_mixed.variable_jumps) == 1
209+
@test jp_pure_mixed.regular_jump !== nothing
210+
211+
# Test spatial system error
212+
spatial_sys = CartesianGrid((2, 2))
213+
hopping_consts = [1.0]
214+
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj);
215+
spatial_system = spatial_sys)
216+
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj);
217+
hopping_constants = hopping_consts)
218+
219+
# Test MassActionJump with parameter mapping
220+
maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2])
221+
jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params))
222+
scaled_rates = [p[1], p[2]/2]
223+
@test jp_params.massaction_jump.scaled_rates == scaled_rates
121224
end

0 commit comments

Comments
 (0)