@@ -6,77 +6,116 @@ rng = StableRNG(12345)
66Nsims = 8000
77
88# SIR model with influx
9- let
9+ @testset " SIR Model Correctness " begin
1010 β = 0.1 / 1000.0
1111 ν = 0.01
1212 influx_rate = 1.0
1313 p = (β, ν, influx_rate)
1414
15- regular_rate = (out, u, p, t) -> begin
16- out[1 ] = p[1 ] * u[1 ] * u[2 ] # β*S*I (infection)
17- out[2 ] = p[2 ] * u[2 ] # ν*I (recovery)
18- out[3 ] = p[3 ] # influx_rate
19- end
20-
21- regular_c = (dc, u, p, t, counts, mark) -> begin
22- dc .= 0.0
23- dc[1 ] = - counts[1 ] + counts[3 ] # S: -infection + influx
24- dc[2 ] = counts[1 ] - counts[2 ] # I: +infection - recovery
25- dc[3 ] = counts[2 ] # R: +recovery
26- end
15+ # ConstantRateJump formulation for SSAStepper
16+ rate1 (u, p, t) = p[1 ] * u[1 ] * u[2 ] # β*S*I (infection)
17+ rate2 (u, p, t) = p[2 ] * u[2 ] # ν*I (recovery)
18+ rate3 (u, p, t) = p[3 ] # influx_rate
19+ affect1! (integrator) = (integrator. u[1 ] -= 1 ; integrator. u[2 ] += 1 ; nothing )
20+ affect2! (integrator) = (integrator. u[2 ] -= 1 ; integrator. u[3 ] += 1 ; nothing )
21+ affect3! (integrator) = (integrator. u[1 ] += 1 ; nothing )
22+ jumps = (ConstantRateJump (rate1, affect1!), ConstantRateJump (rate2, affect2!), ConstantRateJump (rate3, affect3!))
2723
2824 u0 = [999.0 , 10.0 , 0.0 ] # S, I, R
2925 tspan = (0.0 , 250.0 )
30-
3126 prob_disc = DiscreteProblem (u0, tspan, p)
32- rj = RegularJump (regular_rate, regular_c, 3 )
33- jump_prob = JumpProblem (prob_disc, Direct (), rj)
34-
35- sol = solve (EnsembleProblem (jump_prob), SimpleTauLeaping (), EnsembleSerial (); trajectories = Nsims, dt = 1.0 )
36- mean_simple = mean (sol. u[i][1 ,end ] for i in 1 : Nsims)
27+ jump_prob = JumpProblem (prob_disc, Direct (), jumps... ; rng = rng)
3728
38- sol = solve ( EnsembleProblem (jump_prob), SimpleAdaptiveTauLeaping (), EnsembleSerial (); trajectories = Nsims)
39- mean_adaptive = mean (sol . u[i][ 1 , end ] for i in 1 : Nsims)
29+ # Solve with SSAStepper
30+ sol_direct = solve ( EnsembleProblem (jump_prob), SSAStepper (), EnsembleSerial (); trajectories = Nsims)
4031
41- @test isapprox (mean_simple, mean_adaptive, rtol= 0.05 )
32+ # RegularJump formulation for TauLeaping methods
33+ regular_rate = (out, u, p, t) -> begin
34+ out[1 ] = p[1 ] * u[1 ] * u[2 ]
35+ out[2 ] = p[2 ] * u[2 ]
36+ out[3 ] = p[3 ]
37+ end
38+ regular_c = (dc, u, p, t, counts, mark) -> begin
39+ dc .= 0.0
40+ dc[1 ] = - counts[1 ] + counts[3 ]
41+ dc[2 ] = counts[1 ] - counts[2 ]
42+ dc[3 ] = counts[2 ]
43+ end
44+ rj = RegularJump (regular_rate, regular_c, 3 )
45+ jump_prob_tau = JumpProblem (prob_disc, Direct (), rj; rng = rng)
46+
47+ # Solve with SimpleTauLeaping (dt=0.1)
48+ sol_simple = solve (EnsembleProblem (jump_prob_tau), SimpleTauLeaping (), EnsembleSerial (); trajectories= Nsims, dt= 0.1 )
49+
50+ # Solve with SimpleAdaptiveTauLeaping
51+ sol_adaptive = solve (EnsembleProblem (jump_prob_tau), SimpleAdaptiveTauLeaping (), EnsembleSerial (); trajectories= Nsims)
52+
53+ # Compute mean trajectories at t = 0, 1, ..., 250
54+ t_points = 0 : 1.0 : 250.0
55+ mean_direct_S = [mean (sol_direct[i](t)[2 ] for i in 1 : Nsims) for t in t_points]
56+ mean_simple_S = [mean (sol_simple[i](t)[2 ] for i in 1 : Nsims) for t in t_points]
57+ mean_adaptive_S = [mean (sol_adaptive[i](t)[2 ] for i in 1 : Nsims) for t in t_points]
58+
59+ for i in 1 : 251
60+ @test isapprox (mean_direct_S[i], mean_simple_S[i], rtol= 0.10 )
61+ @test isapprox (mean_direct_S[i], mean_adaptive_S[i], rtol= 0.10 )
62+ end
4263end
4364
44-
4565# SEIR model with exposed compartment
46- let
66+ @testset " SEIR Model Correctness " begin
4767 β = 0.3 / 1000.0
4868 σ = 0.2
4969 ν = 0.01
5070 p = (β, σ, ν)
5171
52- regular_rate = (out, u, p, t) -> begin
53- out[1 ] = p[1 ] * u[1 ] * u[3 ] # β*S*I (infection)
54- out[2 ] = p[2 ] * u[2 ] # σ*E (progression)
55- out[3 ] = p[3 ] * u[3 ] # ν*I (recovery)
56- end
57-
58- regular_c = (dc, u, p, t, counts, mark) -> begin
59- dc .= 0.0
60- dc[1 ] = - counts[1 ] # S: -infection
61- dc[2 ] = counts[1 ] - counts[2 ] # E: +infection - progression
62- dc[3 ] = counts[2 ] - counts[3 ] # I: +progression - recovery
63- dc[4 ] = counts[3 ] # R: +recovery
64- end
72+ # ConstantRateJump formulation for SSAStepper
73+ rate1 (u, p, t) = p[1 ] * u[1 ] * u[3 ] # β*S*I (infection)
74+ rate2 (u, p, t) = p[2 ] * u[2 ] # σ*E (progression)
75+ rate3 (u, p, t) = p[3 ] * u[3 ] # ν*I (recovery)
76+ affect1! (integrator) = (integrator. u[1 ] -= 1 ; integrator. u[2 ] += 1 ; nothing )
77+ affect2! (integrator) = (integrator. u[2 ] -= 1 ; integrator. u[3 ] += 1 ; nothing )
78+ affect3! (integrator) = (integrator. u[3 ] -= 1 ; integrator. u[4 ] += 1 ; nothing )
79+ jumps = (ConstantRateJump (rate1, affect1!), ConstantRateJump (rate2, affect2!), ConstantRateJump (rate3, affect3!))
6580
66- # Initial state
6781 u0 = [999.0 , 0.0 , 10.0 , 0.0 ] # S, E, I, R
6882 tspan = (0.0 , 250.0 )
69-
70- # Create JumpProblem
7183 prob_disc = DiscreteProblem (u0, tspan, p)
72- rj = RegularJump (regular_rate, regular_c, 3 )
73- jump_prob = JumpProblem (prob_disc, Direct (), rj; rng= StableRNG (12345 ))
84+ jump_prob = JumpProblem (prob_disc, Direct (), jumps... ; rng = rng)
7485
75- sol = solve ( EnsembleProblem (jump_prob), SimpleTauLeaping (), EnsembleSerial (); trajectories = Nsims, dt = 1.0 )
76- mean_simple = mean (sol . u[i][ end , end ] for i in 1 : Nsims)
86+ # Solve with SSAStepper
87+ sol_direct = solve ( EnsembleProblem (jump_prob), SSAStepper (), EnsembleSerial (); trajectories = Nsims)
7788
78- sol = solve (EnsembleProblem (jump_prob), SimpleAdaptiveTauLeaping (), EnsembleSerial (); trajectories = Nsims)
79- mean_adaptive = mean (sol. u[i][end ,end ] for i in 1 : Nsims)
80-
81- @test isapprox (mean_simple, mean_adaptive, rtol= 0.05 )
89+ # RegularJump formulation for TauLeaping methods
90+ regular_rate = (out, u, p, t) -> begin
91+ out[1 ] = p[1 ] * u[1 ] * u[3 ]
92+ out[2 ] = p[2 ] * u[2 ]
93+ out[3 ] = p[3 ] * u[3 ]
94+ end
95+ regular_c = (dc, u, p, t, counts, mark) -> begin
96+ dc .= 0.0
97+ dc[1 ] = - counts[1 ]
98+ dc[2 ] = counts[1 ] - counts[2 ]
99+ dc[3 ] = counts[2 ] - counts[3 ]
100+ dc[4 ] = counts[3 ]
101+ end
102+ rj = RegularJump (regular_rate, regular_c, 3 )
103+ jump_prob_tau = JumpProblem (prob_disc, Direct (), rj; rng = rng)
104+
105+ # Solve with SimpleTauLeaping (dt=0.1)
106+ sol_simple = solve (EnsembleProblem (jump_prob_tau), SimpleTauLeaping (), EnsembleSerial (); trajectories= Nsims, dt= 0.1 )
107+
108+ # Solve with SimpleAdaptiveTauLeaping
109+ sol_adaptive = solve (EnsembleProblem (jump_prob_tau), SimpleAdaptiveTauLeaping (), EnsembleSerial (); trajectories= Nsims)
110+
111+ # Compute mean trajectories at t = 0, 1, ..., 250
112+ t_points = 0 : 1.0 : 250.0
113+ mean_direct_S = [mean (sol_direct[i](t)[3 ] for i in 1 : Nsims) for t in t_points]
114+ mean_simple_S = [mean (sol_simple[i](t)[3 ] for i in 1 : Nsims) for t in t_points]
115+ mean_adaptive_S = [mean (sol_adaptive[i](t)[3 ] for i in 1 : Nsims) for t in t_points]
116+
117+ for i in 1 : 251
118+ @test isapprox (mean_direct_S[i], mean_simple_S[i], rtol= 0.10 )
119+ @test isapprox (mean_direct_S[i], mean_adaptive_S[i], rtol= 0.10 )
120+ end
82121end
0 commit comments