Skip to content

Commit 8973805

Browse files
SimpleAdaptiveTauLeaping is done
1 parent 0eb637a commit 8973805

File tree

2 files changed

+125
-101
lines changed

2 files changed

+125
-101
lines changed

src/simple_regular_solve.jl

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,60 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
6161
interp = DiffEqBase.ConstantInterpolation(t, u))
6262
end
6363

64+
struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm
65+
epsilon::Float64 # Error control parameter
66+
end
67+
68+
SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)
69+
70+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing)
71+
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
72+
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
73+
prob = jump_prob.prob
74+
rng = DEFAULT_RNG
75+
(seed !== nothing) && seed!(rng, seed)
76+
77+
rj = jump_prob.regular_jump
78+
rate = rj.rate
79+
numjumps = rj.numjumps
80+
c = rj.c
81+
u0 = copy(prob.u0)
82+
tspan = prob.tspan
83+
p = prob.p
84+
85+
u = [copy(u0)]
86+
t = [tspan[1]]
87+
rate_cache = zeros(Float64, numjumps)
88+
counts = zeros(Int, numjumps)
89+
du = similar(u0)
90+
t_end = tspan[2]
91+
epsilon = alg.epsilon
92+
93+
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
94+
95+
while t[end] < t_end
96+
u_prev = u[end]
97+
t_prev = t[end]
98+
rate(rate_cache, u_prev, p, t_prev)
99+
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
100+
tau = min(tau, t_end - t_prev)
101+
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
102+
c(du, u_prev, p, t_prev, counts, nothing)
103+
u_new = u_prev + du
104+
if any(u_new .< 0)
105+
tau /= 2
106+
continue
107+
end
108+
push!(u, u_new)
109+
push!(t, t_prev + tau)
110+
end
111+
112+
sol = DiffEqBase.build_solution(prob, alg, t, u,
113+
calculate_error=false,
114+
interp=DiffEqBase.ConstantInterpolation(t, u))
115+
return sol
116+
end
117+
64118
struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
65119
epsilon::Float64 # Error control parameter
66120
nc::Int # Critical reaction threshold
@@ -390,4 +444,4 @@ function EnsembleGPUKernel()
390444
EnsembleGPUKernel(nothing, 0.0)
391445
end
392446

393-
export SimpleTauLeaping, EnsembleGPUKernel, SimpleImplicitTauLeaping
447+
export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping

test/regular_jumps.jl

Lines changed: 70 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -3,110 +3,80 @@ using Test, LinearAlgebra, Statistics
33
using StableRNGs
44
rng = StableRNG(12345)
55

6-
function regular_rate(out, u, p, t)
7-
out[1] = (0.1 / 1000.0) * u[1] * u[2]
8-
out[2] = 0.01u[2]
9-
end
6+
Nsims = 8000
7+
8+
# SIR model with influx
9+
let
10+
β = 0.1 / 1000.0
11+
ν = 0.01
12+
influx_rate = 1.0
13+
p = (β, ν, influx_rate)
14+
15+
regular_rate = (out, u, p, t) -> begin
16+
out[1] = p[1] * u[1] * u[2] # β*S*I (infection)
17+
out[2] = p[2] * u[2] # ν*I (recovery)
18+
out[3] = p[3] # influx_rate
19+
end
20+
21+
regular_c = (dc, u, p, t, counts, mark) -> begin
22+
dc .= 0.0
23+
dc[1] = -counts[1] + counts[3] # S: -infection + influx
24+
dc[2] = counts[1] - counts[2] # I: +infection - recovery
25+
dc[3] = counts[2] # R: +recovery
26+
end
27+
28+
u0 = [999.0, 10.0, 0.0] # S, I, R
29+
tspan = (0.0, 250.0)
30+
31+
prob_disc = DiscreteProblem(u0, tspan, p)
32+
rj = RegularJump(regular_rate, regular_c, 3)
33+
jump_prob = JumpProblem(prob_disc, Direct(), rj)
34+
35+
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
36+
mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims)
1037

11-
const dc = zeros(3, 2)
12-
dc[1, 1] = -1
13-
dc[2, 1] = 1
14-
dc[2, 2] = -1
15-
dc[3, 2] = 1
38+
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
39+
mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims)
1640

17-
function regular_c(du, u, p, t, counts, mark)
18-
mul!(du, dc, counts)
41+
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
1942
end
2043

21-
rj = RegularJump(regular_rate, regular_c, 2)
22-
jumps = JumpSet(rj)
23-
prob = DiscreteProblem([999, 1, 0], (0.0, 250.0))
24-
jump_prob = JumpProblem(prob, PureLeaping(), rj; rng)
25-
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
26-
27-
# Test PureLeaping aggregator functionality
28-
@testset "PureLeaping Aggregator Tests" begin
29-
# Test with MassActionJump
30-
u0 = [10, 5, 0]
31-
tspan = (0.0, 10.0)
32-
p = [0.1, 0.2]
33-
prob = DiscreteProblem(u0, tspan, p)
34-
35-
# Create MassActionJump
36-
reactant_stoich = [[1 => 1], [1 => 2]]
37-
net_stoich = [[1 => -1, 2 => 1], [1 => -2, 3 => 1]]
38-
rates = [0.1, 0.05]
39-
maj = MassActionJump(rates, reactant_stoich, net_stoich)
40-
41-
# Test PureLeaping JumpProblem creation
42-
jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj); rng)
43-
@test jp_pure.aggregator isa PureLeaping
44-
@test jp_pure.discrete_jump_aggregation === nothing
45-
@test jp_pure.massaction_jump !== nothing
46-
@test length(jp_pure.jump_callback.discrete_callbacks) == 0
47-
48-
# Test with ConstantRateJump
49-
rate(u, p, t) = p[1] * u[1]
50-
affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1)
51-
crj = ConstantRateJump(rate, affect!)
52-
53-
jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj); rng)
54-
@test jp_pure_crj.aggregator isa PureLeaping
55-
@test jp_pure_crj.discrete_jump_aggregation === nothing
56-
@test length(jp_pure_crj.constant_jumps) == 1
57-
58-
# Test with VariableRateJump
59-
vrate(u, p, t) = t * p[1] * u[1]
60-
vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1)
61-
vrj = VariableRateJump(vrate, vaffect!)
62-
63-
jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj); rng)
64-
@test jp_pure_vrj.aggregator isa PureLeaping
65-
@test jp_pure_vrj.discrete_jump_aggregation === nothing
66-
@test length(jp_pure_vrj.variable_jumps) == 1
67-
68-
# Test with RegularJump
69-
function rj_rate(out, u, p, t)
70-
out[1] = p[1] * u[1]
44+
45+
# SEIR model with exposed compartment
46+
let
47+
β = 0.3 / 1000.0
48+
σ = 0.2
49+
ν = 0.01
50+
p = (β, σ, ν)
51+
52+
regular_rate = (out, u, p, t) -> begin
53+
out[1] = p[1] * u[1] * u[3] # β*S*I (infection)
54+
out[2] = p[2] * u[2] # σ*E (progression)
55+
out[3] = p[3] * u[3] # ν*I (recovery)
7156
end
72-
73-
rj_dc = zeros(3, 1)
74-
rj_dc[1, 1] = -1
75-
rj_dc[3, 1] = 1
76-
77-
function rj_c(du, u, p, t, counts, mark)
78-
mul!(du, rj_dc, counts)
57+
58+
regular_c = (dc, u, p, t, counts, mark) -> begin
59+
dc .= 0.0
60+
dc[1] = -counts[1] # S: -infection
61+
dc[2] = counts[1] - counts[2] # E: +infection - progression
62+
dc[3] = counts[2] - counts[3] # I: +progression - recovery
63+
dc[4] = counts[3] # R: +recovery
7964
end
80-
81-
regj = RegularJump(rj_rate, rj_c, 1)
82-
83-
jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj); rng)
84-
@test jp_pure_regj.aggregator isa PureLeaping
85-
@test jp_pure_regj.discrete_jump_aggregation === nothing
86-
@test jp_pure_regj.regular_jump !== nothing
87-
88-
# Test mixed jump types
89-
mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,),
90-
variable_jumps = (vrj,), regular_jumps = regj)
91-
jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps; rng)
92-
@test jp_pure_mixed.aggregator isa PureLeaping
93-
@test jp_pure_mixed.discrete_jump_aggregation === nothing
94-
@test jp_pure_mixed.massaction_jump !== nothing
95-
@test length(jp_pure_mixed.constant_jumps) == 1
96-
@test length(jp_pure_mixed.variable_jumps) == 1
97-
@test jp_pure_mixed.regular_jump !== nothing
98-
99-
# Test spatial system error
100-
spatial_sys = CartesianGrid((2, 2))
101-
hopping_consts = [1.0]
102-
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng,
103-
spatial_system = spatial_sys)
104-
@test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng,
105-
hopping_constants = hopping_consts)
106-
107-
# Test MassActionJump with parameter mapping
108-
maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2])
109-
jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params); rng)
110-
scaled_rates = [p[1], p[2]/2]
111-
@test jp_params.massaction_jump.scaled_rates == scaled_rates
65+
66+
# Initial state
67+
u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R
68+
tspan = (0.0, 250.0)
69+
70+
# Create JumpProblem
71+
prob_disc = DiscreteProblem(u0, tspan, p)
72+
rj = RegularJump(regular_rate, regular_c, 3)
73+
jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
74+
75+
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
76+
mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims)
77+
78+
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
79+
mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims)
80+
81+
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
11282
end

0 commit comments

Comments
 (0)