@@ -3,110 +3,80 @@ using Test, LinearAlgebra, Statistics
33using StableRNGs
44rng = StableRNG (12345 )
55
6- function regular_rate (out, u, p, t)
7- out[1 ] = (0.1 / 1000.0 ) * u[1 ] * u[2 ]
8- out[2 ] = 0.01 u[2 ]
9- end
6+ Nsims = 8000
7+
8+ # SIR model with influx
9+ let
10+ β = 0.1 / 1000.0
11+ ν = 0.01
12+ influx_rate = 1.0
13+ p = (β, ν, influx_rate)
14+
15+ regular_rate = (out, u, p, t) -> begin
16+ out[1 ] = p[1 ] * u[1 ] * u[2 ] # β*S*I (infection)
17+ out[2 ] = p[2 ] * u[2 ] # ν*I (recovery)
18+ out[3 ] = p[3 ] # influx_rate
19+ end
20+
21+ regular_c = (dc, u, p, t, counts, mark) -> begin
22+ dc .= 0.0
23+ dc[1 ] = - counts[1 ] + counts[3 ] # S: -infection + influx
24+ dc[2 ] = counts[1 ] - counts[2 ] # I: +infection - recovery
25+ dc[3 ] = counts[2 ] # R: +recovery
26+ end
27+
28+ u0 = [999.0 , 10.0 , 0.0 ] # S, I, R
29+ tspan = (0.0 , 250.0 )
30+
31+ prob_disc = DiscreteProblem (u0, tspan, p)
32+ rj = RegularJump (regular_rate, regular_c, 3 )
33+ jump_prob = JumpProblem (prob_disc, Direct (), rj)
34+
35+ sol = solve (EnsembleProblem (jump_prob), SimpleTauLeaping (), EnsembleSerial (); trajectories = Nsims, dt = 1.0 )
36+ mean_simple = mean (sol. u[i][1 ,end ] for i in 1 : Nsims)
1037
11- const dc = zeros (3 , 2 )
12- dc[1 , 1 ] = - 1
13- dc[2 , 1 ] = 1
14- dc[2 , 2 ] = - 1
15- dc[3 , 2 ] = 1
38+ sol = solve (EnsembleProblem (jump_prob), SimpleImplicitTauLeaping (), EnsembleSerial (); trajectories = Nsims)
39+ mean_implicit = mean (sol. u[i][1 ,end ] for i in 1 : Nsims)
1640
17- function regular_c (du, u, p, t, counts, mark)
18- mul! (du, dc, counts)
41+ @test isapprox (mean_simple, mean_implicit, rtol= 0.05 )
1942end
2043
21- rj = RegularJump (regular_rate, regular_c, 2 )
22- jumps = JumpSet (rj)
23- prob = DiscreteProblem ([999 , 1 , 0 ], (0.0 , 250.0 ))
24- jump_prob = JumpProblem (prob, PureLeaping (), rj; rng)
25- sol = solve (jump_prob, SimpleTauLeaping (); dt = 1.0 )
26-
27- # Test PureLeaping aggregator functionality
28- @testset " PureLeaping Aggregator Tests" begin
29- # Test with MassActionJump
30- u0 = [10 , 5 , 0 ]
31- tspan = (0.0 , 10.0 )
32- p = [0.1 , 0.2 ]
33- prob = DiscreteProblem (u0, tspan, p)
34-
35- # Create MassActionJump
36- reactant_stoich = [[1 => 1 ], [1 => 2 ]]
37- net_stoich = [[1 => - 1 , 2 => 1 ], [1 => - 2 , 3 => 1 ]]
38- rates = [0.1 , 0.05 ]
39- maj = MassActionJump (rates, reactant_stoich, net_stoich)
40-
41- # Test PureLeaping JumpProblem creation
42- jp_pure = JumpProblem (prob, PureLeaping (), JumpSet (maj); rng)
43- @test jp_pure. aggregator isa PureLeaping
44- @test jp_pure. discrete_jump_aggregation === nothing
45- @test jp_pure. massaction_jump != = nothing
46- @test length (jp_pure. jump_callback. discrete_callbacks) == 0
47-
48- # Test with ConstantRateJump
49- rate (u, p, t) = p[1 ] * u[1 ]
50- affect! (integrator) = (integrator. u[1 ] -= 1 ; integrator. u[3 ] += 1 )
51- crj = ConstantRateJump (rate, affect!)
52-
53- jp_pure_crj = JumpProblem (prob, PureLeaping (), JumpSet (crj); rng)
54- @test jp_pure_crj. aggregator isa PureLeaping
55- @test jp_pure_crj. discrete_jump_aggregation === nothing
56- @test length (jp_pure_crj. constant_jumps) == 1
57-
58- # Test with VariableRateJump
59- vrate (u, p, t) = t * p[1 ] * u[1 ]
60- vaffect! (integrator) = (integrator. u[1 ] -= 1 ; integrator. u[3 ] += 1 )
61- vrj = VariableRateJump (vrate, vaffect!)
62-
63- jp_pure_vrj = JumpProblem (prob, PureLeaping (), JumpSet (vrj); rng)
64- @test jp_pure_vrj. aggregator isa PureLeaping
65- @test jp_pure_vrj. discrete_jump_aggregation === nothing
66- @test length (jp_pure_vrj. variable_jumps) == 1
67-
68- # Test with RegularJump
69- function rj_rate (out, u, p, t)
70- out[1 ] = p[1 ] * u[1 ]
44+
45+ # SEIR model with exposed compartment
46+ let
47+ β = 0.3 / 1000.0
48+ σ = 0.2
49+ ν = 0.01
50+ p = (β, σ, ν)
51+
52+ regular_rate = (out, u, p, t) -> begin
53+ out[1 ] = p[1 ] * u[1 ] * u[3 ] # β*S*I (infection)
54+ out[2 ] = p[2 ] * u[2 ] # σ*E (progression)
55+ out[3 ] = p[3 ] * u[3 ] # ν*I (recovery)
7156 end
72-
73- rj_dc = zeros ( 3 , 1 )
74- rj_dc[ 1 , 1 ] = - 1
75- rj_dc[ 3 , 1 ] = 1
76-
77- function rj_c (du, u, p, t, counts, mark)
78- mul! (du, rj_dc, counts)
57+
58+ regular_c = (dc, u, p, t, counts, mark) -> begin
59+ dc . = 0.0
60+ dc[ 1 ] = - counts[ 1 ] # S: -infection
61+ dc[ 2 ] = counts[ 1 ] - counts[ 2 ] # E: +infection - progression
62+ dc[ 3 ] = counts[ 2 ] - counts[ 3 ] # I: +progression - recovery
63+ dc[ 4 ] = counts[ 3 ] # R: +recovery
7964 end
80-
81- regj = RegularJump (rj_rate, rj_c, 1 )
82-
83- jp_pure_regj = JumpProblem (prob, PureLeaping (), JumpSet (regj); rng)
84- @test jp_pure_regj. aggregator isa PureLeaping
85- @test jp_pure_regj. discrete_jump_aggregation === nothing
86- @test jp_pure_regj. regular_jump != = nothing
87-
88- # Test mixed jump types
89- mixed_jumps = JumpSet (; massaction_jumps = maj, constant_jumps = (crj,),
90- variable_jumps = (vrj,), regular_jumps = regj)
91- jp_pure_mixed = JumpProblem (prob, PureLeaping (), mixed_jumps; rng)
92- @test jp_pure_mixed. aggregator isa PureLeaping
93- @test jp_pure_mixed. discrete_jump_aggregation === nothing
94- @test jp_pure_mixed. massaction_jump != = nothing
95- @test length (jp_pure_mixed. constant_jumps) == 1
96- @test length (jp_pure_mixed. variable_jumps) == 1
97- @test jp_pure_mixed. regular_jump != = nothing
98-
99- # Test spatial system error
100- spatial_sys = CartesianGrid ((2 , 2 ))
101- hopping_consts = [1.0 ]
102- @test_throws ErrorException JumpProblem (prob, PureLeaping (), JumpSet (maj); rng,
103- spatial_system = spatial_sys)
104- @test_throws ErrorException JumpProblem (prob, PureLeaping (), JumpSet (maj); rng,
105- hopping_constants = hopping_consts)
106-
107- # Test MassActionJump with parameter mapping
108- maj_params = MassActionJump (reactant_stoich, net_stoich; param_idxs = [1 , 2 ])
109- jp_params = JumpProblem (prob, PureLeaping (), JumpSet (maj_params); rng)
110- scaled_rates = [p[1 ], p[2 ]/ 2 ]
111- @test jp_params. massaction_jump. scaled_rates == scaled_rates
65+
66+ # Initial state
67+ u0 = [999.0 , 0.0 , 10.0 , 0.0 ] # S, E, I, R
68+ tspan = (0.0 , 250.0 )
69+
70+ # Create JumpProblem
71+ prob_disc = DiscreteProblem (u0, tspan, p)
72+ rj = RegularJump (regular_rate, regular_c, 3 )
73+ jump_prob = JumpProblem (prob_disc, Direct (), rj; rng= StableRNG (12345 ))
74+
75+ sol = solve (EnsembleProblem (jump_prob), SimpleTauLeaping (), EnsembleSerial (); trajectories = Nsims, dt = 1.0 )
76+ mean_simple = mean (sol. u[i][end ,end ] for i in 1 : Nsims)
77+
78+ sol = solve (EnsembleProblem (jump_prob), SimpleImplicitTauLeaping (), EnsembleSerial (); trajectories = Nsims)
79+ mean_implicit = mean (sol. u[i][end ,end ] for i in 1 : Nsims)
80+
81+ @test isapprox (mean_simple, mean_implicit, rtol= 0.05 )
11282end
0 commit comments