Skip to content

Commit 467a02f

Browse files
removed adaptive tau leap
1 parent 8602155 commit 467a02f

File tree

2 files changed

+62
-109
lines changed

2 files changed

+62
-109
lines changed

src/simple_regular_solve.jl

Lines changed: 1 addition & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -50,60 +50,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
5050
interp = DiffEqBase.ConstantInterpolation(t, u))
5151
end
5252

53-
struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm
54-
epsilon::Float64 # Error control parameter
55-
end
56-
57-
SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)
58-
59-
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing)
60-
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
61-
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
62-
prob = jump_prob.prob
63-
rng = DEFAULT_RNG
64-
(seed !== nothing) && seed!(rng, seed)
65-
66-
rj = jump_prob.regular_jump
67-
rate = rj.rate
68-
numjumps = rj.numjumps
69-
c = rj.c
70-
u0 = copy(prob.u0)
71-
tspan = prob.tspan
72-
p = prob.p
73-
74-
u = [copy(u0)]
75-
t = [tspan[1]]
76-
rate_cache = zeros(Float64, numjumps)
77-
counts = zeros(Int, numjumps)
78-
du = similar(u0)
79-
t_end = tspan[2]
80-
epsilon = alg.epsilon
81-
82-
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
83-
84-
while t[end] < t_end
85-
u_prev = u[end]
86-
t_prev = t[end]
87-
rate(rate_cache, u_prev, p, t_prev)
88-
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
89-
tau = min(tau, t_end - t_prev)
90-
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
91-
c(du, u_prev, p, t_prev, counts, nothing)
92-
u_new = u_prev + du
93-
if any(u_new .< 0)
94-
tau /= 2
95-
continue
96-
end
97-
push!(u, u_new)
98-
push!(t, t_prev + tau)
99-
end
100-
101-
sol = DiffEqBase.build_solution(prob, alg, t, u,
102-
calculate_error=false,
103-
interp=DiffEqBase.ConstantInterpolation(t, u))
104-
return sol
105-
end
106-
10753
# SimpleImplicitTauLeaping implementation
10854
struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
10955
epsilon::Float64 # Error control parameter
@@ -306,4 +252,4 @@ function EnsembleGPUKernel()
306252
EnsembleGPUKernel(nothing, 0.0)
307253
end
308254

309-
export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping
255+
export SimpleTauLeaping, EnsembleGPUKernel, SimpleImplicitTauLeaping

test/regular_jumps.jl

Lines changed: 61 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,86 +5,93 @@ rng = StableRNG(12345)
55

66
Nsims = 8000
77

8-
# SIR model with influx
9-
let
8+
@testset "SIR Model Correctness" begin
109
β = 0.1 / 1000.0
1110
ν = 0.01
1211
influx_rate = 1.0
1312
p = (β, ν, influx_rate)
1413

14+
rate1(u, p, t) = p[1] * u[1] * u[2]
15+
rate2(u, p, t) = p[2] * u[2]
16+
rate3(u, p, t) = p[3]
17+
affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing)
18+
affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing)
19+
affect3!(integrator) = (integrator.u[1] += 1; nothing)
20+
jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!))
21+
22+
u0 = [999, 10, 0] # Integer initial conditions
23+
tspan = (0.0, 250.0)
24+
prob_disc = DiscreteProblem(u0, tspan, p)
25+
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=StableRNG(12345))
26+
27+
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims)
28+
1529
regular_rate = (out, u, p, t) -> begin
16-
out[1] = p[1] * u[1] * u[2] # β*S*I (infection)
17-
out[2] = p[2] * u[2] # ν*I (recovery)
18-
out[3] = p[3] # influx_rate
30+
out[1] = p[1] * u[1] * u[2]
31+
out[2] = p[2] * u[2]
32+
out[3] = p[3]
1933
end
20-
2134
regular_c = (dc, u, p, t, counts, mark) -> begin
22-
dc .= 0.0
23-
dc[1] = -counts[1] + counts[3] # S: -infection + influx
24-
dc[2] = counts[1] - counts[2] # I: +infection - recovery
25-
dc[3] = counts[2] # R: +recovery
35+
dc .= 0
36+
dc[1] = -counts[1] + counts[3]
37+
dc[2] = counts[1] - counts[2]
38+
dc[3] = counts[2]
2639
end
27-
28-
u0 = [999.0, 10.0, 0.0] # S, I, R
29-
tspan = (0.0, 250.0)
30-
31-
prob_disc = DiscreteProblem(u0, tspan, p)
3240
rj = RegularJump(regular_rate, regular_c, 3)
33-
jump_prob = JumpProblem(prob_disc, Direct(), rj)
34-
35-
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
36-
mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims)
41+
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
3742

38-
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
39-
mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims)
43+
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)
4044

41-
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
42-
mean_adaptive = mean(sol.u[i][1,end] for i in 1:Nsims)
45+
t_points = 0:1.0:250.0
46+
mean_direct_S = [mean(sol_direct[i](t)[1] for i in 1:Nsims) for t in t_points]
47+
mean_implicit_S = [mean(sol_implicit[i](t)[1] for i in 1:Nsims) for t in t_points]
4348

44-
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
45-
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
49+
max_error_implicit = maximum(abs.(mean_direct_S .- mean_implicit_S))
50+
@test max_error_implicit < 0.01 * mean(mean_direct_S)
4651
end
4752

48-
49-
# SEIR model with exposed compartment
50-
let
53+
@testset "SEIR Model Correctness" begin
5154
β = 0.3 / 1000.0
5255
σ = 0.2
5356
ν = 0.01
5457
p = (β, σ, ν)
5558

59+
rate1(u, p, t) = p[1] * u[1] * u[3]
60+
rate2(u, p, t) = p[2] * u[2]
61+
rate3(u, p, t) = p[3] * u[3]
62+
affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing)
63+
affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing)
64+
affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing)
65+
jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!))
66+
67+
u0 = [999, 0, 10, 0] # Integer initial conditions
68+
tspan = (0.0, 250.0)
69+
prob_disc = DiscreteProblem(u0, tspan, p)
70+
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=StableRNG(12345))
71+
72+
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims)
73+
5674
regular_rate = (out, u, p, t) -> begin
57-
out[1] = p[1] * u[1] * u[3] # β*S*I (infection)
58-
out[2] = p[2] * u[2] # σ*E (progression)
59-
out[3] = p[3] * u[3] # ν*I (recovery)
75+
out[1] = p[1] * u[1] * u[3]
76+
out[2] = p[2] * u[2]
77+
out[3] = p[3] * u[3]
6078
end
61-
6279
regular_c = (dc, u, p, t, counts, mark) -> begin
63-
dc .= 0.0
64-
dc[1] = -counts[1] # S: -infection
65-
dc[2] = counts[1] - counts[2] # E: +infection - progression
66-
dc[3] = counts[2] - counts[3] # I: +progression - recovery
67-
dc[4] = counts[3] # R: +recovery
80+
dc .= 0
81+
dc[1] = -counts[1]
82+
dc[2] = counts[1] - counts[2]
83+
dc[3] = counts[2] - counts[3]
84+
dc[4] = counts[3]
6885
end
69-
70-
# Initial state
71-
u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R
72-
tspan = (0.0, 250.0)
73-
74-
# Create JumpProblem
75-
prob_disc = DiscreteProblem(u0, tspan, p)
7686
rj = RegularJump(regular_rate, regular_c, 3)
77-
jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
78-
79-
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
80-
mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims)
87+
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
8188

82-
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
83-
mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims)
89+
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)
8490

85-
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
86-
mean_adaptive = mean(sol.u[i][end,end] for i in 1:Nsims)
91+
t_points = 0:1.0:250.0
92+
mean_direct_R = [mean(sol_direct[i](t)[4] for i in 1:Nsims) for t in t_points]
93+
mean_implicit_R = [mean(sol_implicit[i](t)[4] for i in 1:Nsims) for t in t_points]
8794

88-
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
89-
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
95+
max_error_implicit = maximum(abs.(mean_direct_R .- mean_implicit_R))
96+
@test max_error_implicit < 0.01 * mean(mean_direct_R)
9097
end

0 commit comments

Comments
 (0)