Skip to content

Commit e5455b6

Browse files
test refactor
1 parent bd4a452 commit e5455b6

File tree

1 file changed

+87
-48
lines changed

1 file changed

+87
-48
lines changed

test/regular_jumps.jl

Lines changed: 87 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,77 +6,116 @@ rng = StableRNG(12345)
66
Nsims = 8000
77

88
# SIR model with influx
9-
let
9+
@testset "SIR Model Correctness" begin
1010
β = 0.1 / 1000.0
1111
ν = 0.01
1212
influx_rate = 1.0
1313
p = (β, ν, influx_rate)
1414

15-
regular_rate = (out, u, p, t) -> begin
16-
out[1] = p[1] * u[1] * u[2] # β*S*I (infection)
17-
out[2] = p[2] * u[2] # ν*I (recovery)
18-
out[3] = p[3] # influx_rate
19-
end
20-
21-
regular_c = (dc, u, p, t, counts, mark) -> begin
22-
dc .= 0.0
23-
dc[1] = -counts[1] + counts[3] # S: -infection + influx
24-
dc[2] = counts[1] - counts[2] # I: +infection - recovery
25-
dc[3] = counts[2] # R: +recovery
26-
end
15+
# ConstantRateJump formulation for SSAStepper
16+
rate1(u, p, t) = p[1] * u[1] * u[2] # β*S*I (infection)
17+
rate2(u, p, t) = p[2] * u[2] # ν*I (recovery)
18+
rate3(u, p, t) = p[3] # influx_rate
19+
affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing)
20+
affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing)
21+
affect3!(integrator) = (integrator.u[1] += 1; nothing)
22+
jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!))
2723

2824
u0 = [999.0, 10.0, 0.0] # S, I, R
2925
tspan = (0.0, 250.0)
30-
3126
prob_disc = DiscreteProblem(u0, tspan, p)
32-
rj = RegularJump(regular_rate, regular_c, 3)
33-
jump_prob = JumpProblem(prob_disc, Direct(), rj)
34-
35-
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
36-
mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims)
27+
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng)
3728

38-
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
39-
mean_adaptive = mean(sol.u[i][1,end] for i in 1:Nsims)
29+
# Solve with SSAStepper
30+
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims)
4031

41-
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
32+
# RegularJump formulation for TauLeaping methods
33+
regular_rate = (out, u, p, t) -> begin
34+
out[1] = p[1] * u[1] * u[2]
35+
out[2] = p[2] * u[2]
36+
out[3] = p[3]
37+
end
38+
regular_c = (dc, u, p, t, counts, mark) -> begin
39+
dc .= 0.0
40+
dc[1] = -counts[1] + counts[3]
41+
dc[2] = counts[1] - counts[2]
42+
dc[3] = counts[2]
43+
end
44+
rj = RegularJump(regular_rate, regular_c, 3)
45+
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng = rng)
46+
47+
# Solve with SimpleTauLeaping (dt=0.1)
48+
sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1)
49+
50+
# Solve with SimpleAdaptiveTauLeaping
51+
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims)
52+
53+
# Compute mean trajectories at t = 0, 1, ..., 250
54+
t_points = 0:1.0:250.0
55+
mean_direct_S = [mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points]
56+
mean_simple_S = [mean(sol_simple[i](t)[2] for i in 1:Nsims) for t in t_points]
57+
mean_adaptive_S = [mean(sol_adaptive[i](t)[2] for i in 1:Nsims) for t in t_points]
58+
59+
for i in 1:251
60+
@test isapprox(mean_direct_S[i], mean_simple_S[i], rtol=0.10)
61+
@test isapprox(mean_direct_S[i], mean_adaptive_S[i], rtol=0.10)
62+
end
4263
end
4364

44-
4565
# SEIR model with exposed compartment
46-
let
66+
@testset "SEIR Model Correctness" begin
4767
β = 0.3 / 1000.0
4868
σ = 0.2
4969
ν = 0.01
5070
p = (β, σ, ν)
5171

52-
regular_rate = (out, u, p, t) -> begin
53-
out[1] = p[1] * u[1] * u[3] # β*S*I (infection)
54-
out[2] = p[2] * u[2] # σ*E (progression)
55-
out[3] = p[3] * u[3] # ν*I (recovery)
56-
end
57-
58-
regular_c = (dc, u, p, t, counts, mark) -> begin
59-
dc .= 0.0
60-
dc[1] = -counts[1] # S: -infection
61-
dc[2] = counts[1] - counts[2] # E: +infection - progression
62-
dc[3] = counts[2] - counts[3] # I: +progression - recovery
63-
dc[4] = counts[3] # R: +recovery
64-
end
72+
# ConstantRateJump formulation for SSAStepper
73+
rate1(u, p, t) = p[1] * u[1] * u[3] # β*S*I (infection)
74+
rate2(u, p, t) = p[2] * u[2] # σ*E (progression)
75+
rate3(u, p, t) = p[3] * u[3] # ν*I (recovery)
76+
affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing)
77+
affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing)
78+
affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing)
79+
jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!))
6580

66-
# Initial state
6781
u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R
6882
tspan = (0.0, 250.0)
69-
70-
# Create JumpProblem
7183
prob_disc = DiscreteProblem(u0, tspan, p)
72-
rj = RegularJump(regular_rate, regular_c, 3)
73-
jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
84+
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng)
7485

75-
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
76-
mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims)
86+
# Solve with SSAStepper
87+
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims)
7788

78-
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
79-
mean_adaptive = mean(sol.u[i][end,end] for i in 1:Nsims)
80-
81-
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
89+
# RegularJump formulation for TauLeaping methods
90+
regular_rate = (out, u, p, t) -> begin
91+
out[1] = p[1] * u[1] * u[3]
92+
out[2] = p[2] * u[2]
93+
out[3] = p[3] * u[3]
94+
end
95+
regular_c = (dc, u, p, t, counts, mark) -> begin
96+
dc .= 0.0
97+
dc[1] = -counts[1]
98+
dc[2] = counts[1] - counts[2]
99+
dc[3] = counts[2] - counts[3]
100+
dc[4] = counts[3]
101+
end
102+
rj = RegularJump(regular_rate, regular_c, 3)
103+
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng = rng)
104+
105+
# Solve with SimpleTauLeaping (dt=0.1)
106+
sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1)
107+
108+
# Solve with SimpleAdaptiveTauLeaping
109+
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims)
110+
111+
# Compute mean trajectories at t = 0, 1, ..., 250
112+
t_points = 0:1.0:250.0
113+
mean_direct_S = [mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points]
114+
mean_simple_S = [mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_points]
115+
mean_adaptive_S = [mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_points]
116+
117+
for i in 1:251
118+
@test isapprox(mean_direct_S[i], mean_simple_S[i], rtol=0.10)
119+
@test isapprox(mean_direct_S[i], mean_adaptive_S[i], rtol=0.10)
120+
end
82121
end

0 commit comments

Comments
 (0)