Skip to content

Commit 6f868df

Browse files
poiss change
1 parent b5226f7 commit 6f868df

File tree

2 files changed

+6
-16
lines changed

2 files changed

+6
-16
lines changed

src/simple_regular_solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate,
137137

138138
# Solve the nonlinear system
139139
prob = NonlinearProblem(f, u_new, nothing)
140-
sol = solve(prob, SimpleNewtonRaphson(), tol=1e-6)
140+
sol = solve(prob, SimpleNewtonRaphson())
141141

142142
# Check for convergence and numerical stability
143143
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
@@ -205,7 +205,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
205205
tau = saveat_times[save_idx] - t_prev
206206
end
207207
end
208-
counts .= rand(rng, Poisson.(max.(rate_cache * tau, 0.0)))
208+
counts .= counts .= pois_rand.((rng,), max.(rate_cache * tau, 0.0))
209209
c(du, u_prev, p, t_prev, counts, nothing)
210210
u_new = u_prev + du
211211
if tau_prime <= tau_double_prime / 10.0

test/regular_jumps.jl

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,17 @@ using Test, LinearAlgebra, Statistics
33
using StableRNGs
44
rng = StableRNG(12345)
55

6-
Nsims = 8000
6+
Nsims = 10
77

8-
@testset "SIR Model Correctness" begin
8+
# @testset "SIR Model Correctness" begin
99
β = 0.1 / 1000.0
1010
ν = 0.01
1111
influx_rate = 1.0
1212
p = (β, ν, influx_rate)
1313

14-
rate1(u, p, t) = p[1] * u[1] * u[2]
15-
rate2(u, p, t) = p[2] * u[2]
16-
rate3(u, p, t) = p[3]
17-
affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing)
18-
affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing)
19-
affect3!(integrator) = (integrator.u[1] += 1; nothing)
20-
jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!))
21-
2214
u0 = [999, 10, 0] # Integer initial conditions
2315
tspan = (0.0, 250.0)
2416
prob_disc = DiscreteProblem(u0, tspan, p)
25-
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=StableRNG(12345))
26-
27-
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims)
2817

2918
regular_rate = (out, u, p, t) -> begin
3019
out[1] = p[1] * u[1] * u[2]
@@ -41,14 +30,15 @@ Nsims = 8000
4130
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
4231

4332
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)
33+
plot(sol_implicit)
4434

4535
t_points = 0:1.0:250.0
4636
mean_direct_S = [mean(sol_direct[i](t)[1] for i in 1:Nsims) for t in t_points]
4737
mean_implicit_S = [mean(sol_implicit[i](t)[1] for i in 1:Nsims) for t in t_points]
4838

4939
max_error_implicit = maximum(abs.(mean_direct_S .- mean_implicit_S))
5040
@test max_error_implicit < 0.01 * mean(mean_direct_S)
51-
end
41+
# end
5242

5343
@testset "SEIR Model Correctness" begin
5444
β = 0.3 / 1000.0

0 commit comments

Comments
 (0)