@@ -15,7 +15,7 @@ Nsims = 1000
1515 # ConstantRateJump formulation for SSAStepper
1616 rate1 (u, p, t) = p[1 ] * u[1 ] * u[2 ] # β*S*I (infection)
1717 rate2 (u, p, t) = p[2 ] * u[2 ] # ν*I (recovery)
18- rate3 (u, p, t) = p[3 ] # influx_rate
18+ rate3 (u, p, t) = p[3 ] # influx_rate (S influx)
1919 affect1! (integrator) = (integrator. u[1 ] -= 1 ; integrator. u[2 ] += 1 ; nothing )
2020 affect2! (integrator) = (integrator. u[2 ] -= 1 ; integrator. u[3 ] += 1 ; nothing )
2121 affect3! (integrator) = (integrator. u[1 ] += 1 ; nothing )
@@ -24,41 +24,49 @@ Nsims = 1000
2424 u0 = [999.0 , 10.0 , 0.0 ] # S, I, R
2525 tspan = (0.0 , 250.0 )
2626 prob_disc = DiscreteProblem (u0, tspan, p)
27- jump_prob = JumpProblem (prob_disc, Direct (), jumps... ; rng = rng)
27+ jump_prob = JumpProblem (prob_disc, Direct (), jumps... ; rng= rng)
2828
2929 # Solve with SSAStepper
30- sol_direct = solve (EnsembleProblem (jump_prob), SSAStepper (), EnsembleSerial (); trajectories= Nsims)
30+ sol_direct = solve (EnsembleProblem (jump_prob), SSAStepper (), EnsembleSerial (); trajectories= Nsims, saveat = 1.0 )
3131
32- # RegularJump formulation for TauLeaping methods
32+ # RegularJump formulation for SimpleTauLeaping
3333 regular_rate = (out, u, p, t) -> begin
3434 out[1 ] = p[1 ] * u[1 ] * u[2 ]
3535 out[2 ] = p[2 ] * u[2 ]
3636 out[3 ] = p[3 ]
3737 end
3838 regular_c = (dc, u, p, t, counts, mark) -> begin
39- dc .= 0.0
39+ dc .= 0
4040 dc[1 ] = - counts[1 ] + counts[3 ]
4141 dc[2 ] = counts[1 ] - counts[2 ]
4242 dc[3 ] = counts[2 ]
4343 end
4444 rj = RegularJump (regular_rate, regular_c, 3 )
45- jump_prob_tau = JumpProblem (prob_disc, PureLeaping (), rj; rng = rng)
45+ jump_prob_tau = JumpProblem (prob_disc, PureLeaping (), rj; rng= rng)
4646
47- # Solve with SimpleTauLeaping (dt=0.1)
47+ # Solve with SimpleTauLeaping
4848 sol_simple = solve (EnsembleProblem (jump_prob_tau), SimpleTauLeaping (), EnsembleSerial (); trajectories= Nsims, dt= 0.1 )
49-
49+
50+ # MassActionJump formulation for SimpleAdaptiveTauLeaping
51+ reactant_stoich = [[1 => 1 , 2 => 1 ], [2 => 1 ], Pair{Int,Int}[]]
52+ net_stoich = [[1 => - 1 , 2 => 1 ], [2 => - 1 , 3 => 1 ], [1 => 1 ]]
53+ param_idxs = [1 , 2 , 3 ]
54+ maj = MassActionJump (reactant_stoich, net_stoich; param_idxs= param_idxs)
55+ jump_prob_maj = JumpProblem (prob_disc, PureLeaping (), maj; rng= rng)
56+
5057 # Solve with SimpleAdaptiveTauLeaping
51- sol_adaptive = solve (EnsembleProblem (jump_prob_tau ), SimpleAdaptiveTauLeaping (), EnsembleSerial (); trajectories= Nsims, saveat = 1.0 )
58+ sol_adaptive = solve (EnsembleProblem (jump_prob_maj ), SimpleAdaptiveTauLeaping (), EnsembleSerial (); trajectories= Nsims, saveat= 1.0 )
5259
53- # Compute mean trajectories at t = 0, 1, ..., 250
60+ # Compute mean infected (I) trajectories
5461 t_points = 0 : 1.0 : 250.0
55- mean_direct_S = [mean (sol_direct[i](t)[2 ] for i in 1 : Nsims) for t in t_points]
56- mean_simple_S = [mean (sol_simple[i](t)[2 ] for i in 1 : Nsims) for t in t_points]
57- mean_adaptive_S = [mean (sol_adaptive[i](t)[2 ] for i in 1 : Nsims) for t in t_points]
62+ mean_direct_I = [mean (sol_direct[i](t)[2 ] for i in 1 : Nsims) for t in t_points]
63+ mean_simple_I = [mean (sol_simple[i](t)[2 ] for i in 1 : Nsims) for t in t_points]
64+ mean_adaptive_I = [mean (sol_adaptive[i](t)[2 ] for i in 1 : Nsims) for t in t_points]
5865
66+ # Test mean infected trajectories
5967 for i in 1 : 251
60- @test isapprox (mean_direct_S [i], mean_simple_S [i], rtol= 0.10 )
61- @test isapprox (mean_direct_S [i], mean_adaptive_S [i], rtol= 0.10 )
68+ @test isapprox (mean_direct_I [i], mean_simple_I [i], rtol= 0.10 )
69+ @test isapprox (mean_direct_I [i], mean_adaptive_I [i], rtol= 0.10 )
6270 end
6371end
6472
8189 u0 = [999.0 , 0.0 , 10.0 , 0.0 ] # S, E, I, R
8290 tspan = (0.0 , 250.0 )
8391 prob_disc = DiscreteProblem (u0, tspan, p)
84- jump_prob = JumpProblem (prob_disc, Direct (), jumps... ; rng = rng)
92+ jump_prob = JumpProblem (prob_disc, Direct (), jumps... ; rng= rng)
8593
8694 # Solve with SSAStepper
87- sol_direct = solve (EnsembleProblem (jump_prob), SSAStepper (), EnsembleSerial (); trajectories= Nsims)
95+ sol_direct = solve (EnsembleProblem (jump_prob), SSAStepper (), EnsembleSerial (); trajectories= Nsims, saveat = 1.0 )
8896
89- # RegularJump formulation for TauLeaping methods
97+ # RegularJump formulation for SimpleTauLeaping
9098 regular_rate = (out, u, p, t) -> begin
9199 out[1 ] = p[1 ] * u[1 ] * u[3 ]
92100 out[2 ] = p[2 ] * u[2 ]
@@ -100,22 +108,117 @@ end
100108 dc[4 ] = counts[3 ]
101109 end
102110 rj = RegularJump (regular_rate, regular_c, 3 )
103- jump_prob_tau = JumpProblem (prob_disc, PureLeaping (), rj; rng = rng)
111+ jump_prob_tau = JumpProblem (prob_disc, PureLeaping (), rj; rng= rng)
104112
105- # Solve with SimpleTauLeaping (dt=0.1)
113+ # Solve with SimpleTauLeaping
106114 sol_simple = solve (EnsembleProblem (jump_prob_tau), SimpleTauLeaping (), EnsembleSerial (); trajectories= Nsims, dt= 0.1 )
107-
115+
116+ # MassActionJump formulation for SimpleAdaptiveTauLeaping
117+ reactant_stoich = [[1 => 1 , 3 => 1 ], [2 => 1 ], [3 => 1 ]]
118+ net_stoich = [[1 => - 1 , 2 => 1 ], [2 => - 1 , 3 => 1 ], [3 => - 1 , 4 => 1 ]]
119+ param_idxs = [1 , 2 , 3 ]
120+ maj = MassActionJump (reactant_stoich, net_stoich; param_idxs= param_idxs)
121+ jump_prob_maj = JumpProblem (prob_disc, PureLeaping (), maj; rng= rng)
122+
108123 # Solve with SimpleAdaptiveTauLeaping
109- sol_adaptive = solve (EnsembleProblem (jump_prob_tau ), SimpleAdaptiveTauLeaping (), EnsembleSerial (); trajectories= Nsims, saveat = 1.0 )
124+ sol_adaptive = solve (EnsembleProblem (jump_prob_maj ), SimpleAdaptiveTauLeaping (), EnsembleSerial (); trajectories= Nsims, saveat= 1.0 )
110125
111- # Compute mean trajectories at t = 0, 1, ..., 250
126+ # Compute mean infected (I) trajectories
112127 t_points = 0 : 1.0 : 250.0
113- mean_direct_S = [mean (sol_direct[i](t)[3 ] for i in 1 : Nsims) for t in t_points]
114- mean_simple_S = [mean (sol_simple[i](t)[3 ] for i in 1 : Nsims) for t in t_points]
115- mean_adaptive_S = [mean (sol_adaptive[i](t)[3 ] for i in 1 : Nsims) for t in t_points]
128+ mean_direct_I = [mean (sol_direct[i](t)[3 ] for i in 1 : Nsims) for t in t_points]
129+ mean_simple_I = [mean (sol_simple[i](t)[3 ] for i in 1 : Nsims) for t in t_points]
130+ mean_adaptive_I = [mean (sol_adaptive[i](t)[3 ] for i in 1 : Nsims) for t in t_points]
116131
132+ # Test mean infected trajectories
117133 for i in 1 : 251
118- @test isapprox (mean_direct_S[i], mean_simple_S[i], rtol= 0.10 )
119- @test isapprox (mean_direct_S[i], mean_adaptive_S[i], rtol= 0.10 )
134+ @test isapprox (mean_direct_I[i], mean_simple_I[i], rtol= 0.10 )
135+ @test isapprox (mean_direct_I[i], mean_adaptive_I[i], rtol= 0.10 )
136+ end
137+ end
138+
139+ # Test PureLeaping aggregator functionality
140+ @testset " PureLeaping Aggregator Tests" begin
141+ # Test with MassActionJump
142+ u0 = [10 , 5 , 0 ]
143+ tspan = (0.0 , 10.0 )
144+ p = [0.1 , 0.2 ]
145+ prob = DiscreteProblem (u0, tspan, p)
146+
147+ # Create MassActionJump
148+ reactant_stoich = [[1 => 1 ], [1 => 2 ]]
149+ net_stoich = [[1 => - 1 , 2 => 1 ], [1 => - 2 , 3 => 1 ]]
150+ rates = [0.1 , 0.05 ]
151+ maj = MassActionJump (rates, reactant_stoich, net_stoich)
152+
153+ # Test PureLeaping JumpProblem creation
154+ jp_pure = JumpProblem (prob, PureLeaping (), JumpSet (maj))
155+ @test jp_pure. aggregator isa PureLeaping
156+ @test jp_pure. discrete_jump_aggregation === nothing
157+ @test jp_pure. massaction_jump != = nothing
158+ @test length (jp_pure. jump_callback. discrete_callbacks) == 0
159+
160+ # Test with ConstantRateJump
161+ rate (u, p, t) = p[1 ] * u[1 ]
162+ affect! (integrator) = (integrator. u[1 ] -= 1 ; integrator. u[3 ] += 1 )
163+ crj = ConstantRateJump (rate, affect!)
164+
165+ jp_pure_crj = JumpProblem (prob, PureLeaping (), JumpSet (crj))
166+ @test jp_pure_crj. aggregator isa PureLeaping
167+ @test jp_pure_crj. discrete_jump_aggregation === nothing
168+ @test length (jp_pure_crj. constant_jumps) == 1
169+
170+ # Test with VariableRateJump
171+ vrate (u, p, t) = t * p[1 ] * u[1 ]
172+ vaffect! (integrator) = (integrator. u[1 ] -= 1 ; integrator. u[3 ] += 1 )
173+ vrj = VariableRateJump (vrate, vaffect!)
174+
175+ jp_pure_vrj = JumpProblem (prob, PureLeaping (), JumpSet (vrj))
176+ @test jp_pure_vrj. aggregator isa PureLeaping
177+ @test jp_pure_vrj. discrete_jump_aggregation === nothing
178+ @test length (jp_pure_vrj. variable_jumps) == 1
179+
180+ # Test with RegularJump
181+ function rj_rate (out, u, p, t)
182+ out[1 ] = p[1 ] * u[1 ]
183+ end
184+
185+ rj_dc = zeros (3 , 1 )
186+ rj_dc[1 , 1 ] = - 1
187+ rj_dc[3 , 1 ] = 1
188+
189+ function rj_c (du, u, p, t, counts, mark)
190+ mul! (du, rj_dc, counts)
120191 end
192+
193+ regj = RegularJump (rj_rate, rj_c, 1 )
194+
195+ jp_pure_regj = JumpProblem (prob, PureLeaping (), JumpSet (regj))
196+ @test jp_pure_regj. aggregator isa PureLeaping
197+ @test jp_pure_regj. discrete_jump_aggregation === nothing
198+ @test jp_pure_regj. regular_jump != = nothing
199+
200+ # Test mixed jump types
201+ mixed_jumps = JumpSet (; massaction_jumps = maj, constant_jumps = (crj,),
202+ variable_jumps = (vrj,), regular_jumps = regj)
203+ jp_pure_mixed = JumpProblem (prob, PureLeaping (), mixed_jumps)
204+ @test jp_pure_mixed. aggregator isa PureLeaping
205+ @test jp_pure_mixed. discrete_jump_aggregation === nothing
206+ @test jp_pure_mixed. massaction_jump != = nothing
207+ @test length (jp_pure_mixed. constant_jumps) == 1
208+ @test length (jp_pure_mixed. variable_jumps) == 1
209+ @test jp_pure_mixed. regular_jump != = nothing
210+
211+ # Test spatial system error
212+ spatial_sys = CartesianGrid ((2 , 2 ))
213+ hopping_consts = [1.0 ]
214+ @test_throws ErrorException JumpProblem (prob, PureLeaping (), JumpSet (maj);
215+ spatial_system = spatial_sys)
216+ @test_throws ErrorException JumpProblem (prob, PureLeaping (), JumpSet (maj);
217+ hopping_constants = hopping_consts)
218+
219+ # Test MassActionJump with parameter mapping
220+ maj_params = MassActionJump (reactant_stoich, net_stoich; param_idxs = [1 , 2 ])
221+ jp_params = JumpProblem (prob, PureLeaping (), JumpSet (maj_params))
222+ scaled_rates = [p[1 ], p[2 ]/ 2 ]
223+ @test jp_params. massaction_jump. scaled_rates == scaled_rates
121224end
0 commit comments