Skip to content

Commit 91ff58b

Browse files
update project.toml
1 parent 4699990 commit 91ff58b

File tree

3 files changed

+81
-32
lines changed

3 files changed

+81
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1313
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
16-
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
1716
PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab"
1817
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1918
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2019
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2120
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2221
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
22+
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2323
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2424
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
2525
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

src/simple_regular_solve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate,
263263

264264
# Solve the nonlinear system
265265
prob = NonlinearProblem(f, u_new, nothing)
266-
sol = solve(prob, NewtonRaphson())
266+
sol = solve(prob, SimpleNewtonRaphson())
267267

268268
# Check for convergence and numerical stability
269269
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))

test/regular_jumps.jl

Lines changed: 79 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,90 @@
11
using JumpProcesses, DiffEqBase
2-
using Test, LinearAlgebra, Statistics
3-
using StableRNGs
2+
using Test, LinearAlgebra
3+
using StableRNGs, Plots
44
rng = StableRNG(12345)
55

6-
function regular_rate(out, u, p, t)
7-
out[1] = (0.1 / 1000.0) * u[1] * u[2]
8-
out[2] = 0.01u[2]
9-
end
6+
Nsims = 8000
107

11-
function regular_c(dc, u, p, t, mark)
12-
dc[1, 1] = -1
13-
dc[2, 1] = 1
14-
dc[2, 2] = -1
15-
dc[3, 2] = 1
16-
end
8+
# SIR model with influx
9+
let
10+
β = 0.1 / 1000.0
11+
ν = 0.01
12+
influx_rate = 1.0
13+
p = (β, ν, influx_rate)
14+
15+
regular_rate = (out, u, p, t) -> begin
16+
out[1] = p[1] * u[1] * u[2] # β*S*I (infection)
17+
out[2] = p[2] * u[2] # ν*I (recovery)
18+
out[3] = p[3] # influx_rate
19+
end
20+
21+
regular_c = (dc, u, p, t, counts, mark) -> begin
22+
dc .= 0.0
23+
dc[1] = -counts[1] + counts[3] # S: -infection + influx
24+
dc[2] = counts[1] - counts[2] # I: +infection - recovery
25+
dc[3] = counts[2] # R: +recovery
26+
end
1727

18-
dc = zeros(3, 2)
28+
u0 = [999.0, 10.0, 0.0] # S, I, R
29+
tspan = (0.0, 250.0)
1930

20-
rj = RegularJump(regular_rate, regular_c, dc; constant_c = true)
21-
jumps = JumpSet(rj)
31+
prob_disc = DiscreteProblem(u0, tspan, p)
32+
rj = RegularJump(regular_rate, regular_c, 3)
33+
jump_prob = JumpProblem(prob_disc, Direct(), rj)
2234

23-
prob = DiscreteProblem([999.0, 1.0, 0.0], (0.0, 250.0))
24-
jump_prob = JumpProblem(prob, Direct(), rj; rng = rng)
25-
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
35+
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
36+
mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims)
2637

27-
const _dc = zeros(3, 2)
28-
dc[1, 1] = -1
29-
dc[2, 1] = 1
30-
dc[2, 2] = -1
31-
dc[3, 2] = 1
38+
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
39+
mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims)
3240

33-
function regular_c(du, u, p, t, counts, mark)
34-
mul!(du, dc, counts)
41+
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
42+
mean_adaptive = mean(sol.u[i][1,end] for i in 1:Nsims)
43+
44+
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
45+
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
3546
end
3647

37-
rj = RegularJump(regular_rate, regular_c, 2)
38-
jumps = JumpSet(rj)
39-
prob = DiscreteProblem([999, 1, 0], (0.0, 250.0))
40-
jump_prob = JumpProblem(prob, Direct(), rj; rng = rng)
41-
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
48+
49+
# SEIR model with exposed compartment
50+
let
51+
β = 0.3 / 1000.0
52+
σ = 0.2
53+
ν = 0.01
54+
p = (β, σ, ν)
55+
56+
regular_rate = (out, u, p, t) -> begin
57+
out[1] = p[1] * u[1] * u[3] # β*S*I (infection)
58+
out[2] = p[2] * u[2] # σ*E (progression)
59+
out[3] = p[3] * u[3] # ν*I (recovery)
60+
end
61+
62+
regular_c = (dc, u, p, t, counts, mark) -> begin
63+
dc .= 0.0
64+
dc[1] = -counts[1] # S: -infection
65+
dc[2] = counts[1] - counts[2] # E: +infection - progression
66+
dc[3] = counts[2] - counts[3] # I: +progression - recovery
67+
dc[4] = counts[3] # R: +recovery
68+
end
69+
70+
# Initial state
71+
u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R
72+
tspan = (0.0, 250.0)
73+
74+
# Create JumpProblem
75+
prob_disc = DiscreteProblem(u0, tspan, p)
76+
rj = RegularJump(regular_rate, regular_c, 3)
77+
jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
78+
79+
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
80+
mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims)
81+
82+
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
83+
mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims)
84+
85+
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
86+
mean_adaptive = mean(sol.u[i][end,end] for i in 1:Nsims)
87+
88+
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
89+
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
90+
end

0 commit comments

Comments
 (0)