Skip to content

Commit b5226f7

Browse files
removed adaptive tau leap
1 parent 741431f commit b5226f7

File tree

2 files changed

+62
-101
lines changed

2 files changed

+62
-101
lines changed

src/simple_regular_solve.jl

Lines changed: 1 addition & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -61,60 +61,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
6161
interp = DiffEqBase.ConstantInterpolation(t, u))
6262
end
6363

64-
struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm
65-
epsilon::Float64 # Error control parameter
66-
end
67-
68-
SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)
69-
70-
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing)
71-
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
72-
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
73-
prob = jump_prob.prob
74-
rng = DEFAULT_RNG
75-
(seed !== nothing) && seed!(rng, seed)
76-
77-
rj = jump_prob.regular_jump
78-
rate = rj.rate
79-
numjumps = rj.numjumps
80-
c = rj.c
81-
u0 = copy(prob.u0)
82-
tspan = prob.tspan
83-
p = prob.p
84-
85-
u = [copy(u0)]
86-
t = [tspan[1]]
87-
rate_cache = zeros(Float64, numjumps)
88-
counts = zeros(Int, numjumps)
89-
du = similar(u0)
90-
t_end = tspan[2]
91-
epsilon = alg.epsilon
92-
93-
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
94-
95-
while t[end] < t_end
96-
u_prev = u[end]
97-
t_prev = t[end]
98-
rate(rate_cache, u_prev, p, t_prev)
99-
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
100-
tau = min(tau, t_end - t_prev)
101-
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
102-
c(du, u_prev, p, t_prev, counts, nothing)
103-
u_new = u_prev + du
104-
if any(u_new .< 0)
105-
tau /= 2
106-
continue
107-
end
108-
push!(u, u_new)
109-
push!(t, t_prev + tau)
110-
end
111-
112-
sol = DiffEqBase.build_solution(prob, alg, t, u,
113-
calculate_error=false,
114-
interp=DiffEqBase.ConstantInterpolation(t, u))
115-
return sol
116-
end
117-
11864
# SimpleImplicitTauLeaping implementation
11965
struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
12066
epsilon::Float64 # Error control parameter
@@ -317,4 +263,4 @@ function EnsembleGPUKernel()
317263
EnsembleGPUKernel(nothing, 0.0)
318264
end
319265

320-
export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping
266+
export SimpleTauLeaping, EnsembleGPUKernel, SimpleImplicitTauLeaping

test/regular_jumps.jl

Lines changed: 61 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,78 +5,93 @@ rng = StableRNG(12345)
55

66
Nsims = 8000
77

8-
# SIR model with influx
9-
let
8+
@testset "SIR Model Correctness" begin
109
β = 0.1 / 1000.0
1110
ν = 0.01
1211
influx_rate = 1.0
1312
p = (β, ν, influx_rate)
1413

14+
rate1(u, p, t) = p[1] * u[1] * u[2]
15+
rate2(u, p, t) = p[2] * u[2]
16+
rate3(u, p, t) = p[3]
17+
affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing)
18+
affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing)
19+
affect3!(integrator) = (integrator.u[1] += 1; nothing)
20+
jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!))
21+
22+
u0 = [999, 10, 0] # Integer initial conditions
23+
tspan = (0.0, 250.0)
24+
prob_disc = DiscreteProblem(u0, tspan, p)
25+
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=StableRNG(12345))
26+
27+
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims)
28+
1529
regular_rate = (out, u, p, t) -> begin
16-
out[1] = p[1] * u[1] * u[2] # β*S*I (infection)
17-
out[2] = p[2] * u[2] # ν*I (recovery)
18-
out[3] = p[3] # influx_rate
30+
out[1] = p[1] * u[1] * u[2]
31+
out[2] = p[2] * u[2]
32+
out[3] = p[3]
1933
end
20-
2134
regular_c = (dc, u, p, t, counts, mark) -> begin
22-
dc .= 0.0
23-
dc[1] = -counts[1] + counts[3] # S: -infection + influx
24-
dc[2] = counts[1] - counts[2] # I: +infection - recovery
25-
dc[3] = counts[2] # R: +recovery
35+
dc .= 0
36+
dc[1] = -counts[1] + counts[3]
37+
dc[2] = counts[1] - counts[2]
38+
dc[3] = counts[2]
2639
end
27-
28-
u0 = [999.0, 10.0, 0.0] # S, I, R
29-
tspan = (0.0, 250.0)
30-
31-
prob_disc = DiscreteProblem(u0, tspan, p)
3240
rj = RegularJump(regular_rate, regular_c, 3)
33-
jump_prob = JumpProblem(prob_disc, Direct(), rj)
41+
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
3442

35-
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
36-
mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims)
43+
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)
3744

38-
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
39-
mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims)
45+
t_points = 0:1.0:250.0
46+
mean_direct_S = [mean(sol_direct[i](t)[1] for i in 1:Nsims) for t in t_points]
47+
mean_implicit_S = [mean(sol_implicit[i](t)[1] for i in 1:Nsims) for t in t_points]
4048

41-
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
49+
max_error_implicit = maximum(abs.(mean_direct_S .- mean_implicit_S))
50+
@test max_error_implicit < 0.01 * mean(mean_direct_S)
4251
end
4352

44-
45-
# SEIR model with exposed compartment
46-
let
53+
@testset "SEIR Model Correctness" begin
4754
β = 0.3 / 1000.0
4855
σ = 0.2
4956
ν = 0.01
5057
p = (β, σ, ν)
5158

59+
rate1(u, p, t) = p[1] * u[1] * u[3]
60+
rate2(u, p, t) = p[2] * u[2]
61+
rate3(u, p, t) = p[3] * u[3]
62+
affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing)
63+
affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing)
64+
affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing)
65+
jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!))
66+
67+
u0 = [999, 0, 10, 0] # Integer initial conditions
68+
tspan = (0.0, 250.0)
69+
prob_disc = DiscreteProblem(u0, tspan, p)
70+
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=StableRNG(12345))
71+
72+
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims)
73+
5274
regular_rate = (out, u, p, t) -> begin
53-
out[1] = p[1] * u[1] * u[3] # β*S*I (infection)
54-
out[2] = p[2] * u[2] # σ*E (progression)
55-
out[3] = p[3] * u[3] # ν*I (recovery)
75+
out[1] = p[1] * u[1] * u[3]
76+
out[2] = p[2] * u[2]
77+
out[3] = p[3] * u[3]
5678
end
57-
5879
regular_c = (dc, u, p, t, counts, mark) -> begin
59-
dc .= 0.0
60-
dc[1] = -counts[1] # S: -infection
61-
dc[2] = counts[1] - counts[2] # E: +infection - progression
62-
dc[3] = counts[2] - counts[3] # I: +progression - recovery
63-
dc[4] = counts[3] # R: +recovery
80+
dc .= 0
81+
dc[1] = -counts[1]
82+
dc[2] = counts[1] - counts[2]
83+
dc[3] = counts[2] - counts[3]
84+
dc[4] = counts[3]
6485
end
65-
66-
# Initial state
67-
u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R
68-
tspan = (0.0, 250.0)
69-
70-
# Create JumpProblem
71-
prob_disc = DiscreteProblem(u0, tspan, p)
7286
rj = RegularJump(regular_rate, regular_c, 3)
73-
jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
87+
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
7488

75-
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
76-
mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims)
89+
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)
7790

78-
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
79-
mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims)
91+
t_points = 0:1.0:250.0
92+
mean_direct_R = [mean(sol_direct[i](t)[4] for i in 1:Nsims) for t in t_points]
93+
mean_implicit_R = [mean(sol_implicit[i](t)[4] for i in 1:Nsims) for t in t_points]
8094

81-
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
95+
max_error_implicit = maximum(abs.(mean_direct_R .- mean_implicit_R))
96+
@test max_error_implicit < 0.01 * mean(mean_direct_R)
8297
end

0 commit comments

Comments
 (0)