Skip to content

Commit 2b82dc8

Browse files
basic version of inplicit tau leap is done
1 parent b36f470 commit 2b82dc8

File tree

3 files changed

+41
-114
lines changed

3 files changed

+41
-114
lines changed

src/JumpProcesses.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ using Base.Threads: Threads, @threads
2121
using Base.FastMath: add_fast
2222
using Setfield: @set, @set!
2323

24+
using SimpleNonlinearSolve
25+
2426
# Import functions we extend from Base
2527
import Base: size, getindex, setindex!, length, similar, show, merge!, merge
2628

src/simple_regular_solve.jl

Lines changed: 18 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ end
5858
SimpleImplicitTauLeaping(; epsilon=0.05) = SimpleImplicitTauLeaping(epsilon)
5959

6060
function compute_hor(nu)
61-
hor = zeros(Int, size(nu, 2))
61+
hor = zeros(Int64, size(nu, 2))
6262
for j in 1:size(nu, 2)
6363
hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1
6464
end
@@ -77,8 +77,8 @@ end
7777

7878
function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate)
7979
rate(rate_cache, u, p, t)
80-
mu = zeros(length(u))
81-
sigma2 = zeros(length(u))
80+
mu = zeros(Float64, length(u))
81+
sigma2 = zeros(Float64, length(u))
8282
tau = Inf
8383
for i in 1:length(u)
8484
for j in 1:size(nu, 2)
@@ -100,21 +100,20 @@ function compute_tau_implicit(u, rate_cache, nu, p, t, rate)
100100
for i in 1:length(u)
101101
sum_nu_a = 0.0
102102
for j in 1:size(nu, 2)
103-
if nu[i, j] < 0 # Only sum negative stoichiometry
103+
if nu[i, j] < 0
104104
sum_nu_a += abs(nu[i, j]) * rate_cache[j]
105105
end
106106
end
107-
if sum_nu_a > 0 && u[i] > 0 # Avoid division by zero
107+
if sum_nu_a > 0 && u[i] > 0
108108
tau = min(tau, u[i] / sum_nu_a)
109109
end
110110
end
111111
return tau
112112
end
113113

114114
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
115-
# Nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (k_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0
116-
function f(u_new)
117-
rate_new = zeros(Float64, numjumps)
115+
function f(u_new, p)
116+
rate_new = zeros(eltype(u_new), numjumps)
118117
rate(rate_new, u_new, p, t_prev + tau)
119118
residual = u_new - u_prev
120119
for j in 1:numjumps
@@ -123,41 +122,14 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate,
123122
return residual
124123
end
125124

126-
# Numerical Jacobian
127-
function compute_jacobian(u_new)
128-
n = length(u_new)
129-
J = zeros(Float64, n, n)
130-
h = 1e-6
131-
f_u = f(u_new)
132-
for j in 1:n
133-
u_pert = copy(u_new)
134-
u_pert[j] += h
135-
f_pert = f(u_pert)
136-
J[:, j] = (f_pert - f_u) / h
137-
end
138-
return J
139-
end
125+
u_new = float.(u_prev + sum(nu[:, j] * counts[j] for j in 1:numjumps))
126+
prob = NonlinearProblem{false}(f, u_new, p)
127+
sol = solve(prob, SimpleNewtonRaphson(), abstol=1e-6, maxiters=100)
140128

141-
# Inline Newton-Raphson
142-
u_new = float.(u_prev + sum(nu[:, j] * counts[j] for j in 1:numjumps)) # Initial guess: explicit step
143-
tol = 1e-6
144-
maxiters = 100
145-
for iter in 1:maxiters
146-
F = f(u_new)
147-
if norm(F) < tol
148-
return round.(Int, max.(u_new, 0.0)) # Converged
149-
end
150-
J = compute_jacobian(u_new)
151-
if abs(det(J)) < 1e-10 # Check for singular Jacobian
152-
return nothing
153-
end
154-
delta = J \ F
155-
u_new -= delta
156-
if any(isnan.(u_new)) || any(isinf.(u_new))
157-
return nothing
158-
end
129+
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
130+
return nothing
159131
end
160-
return nothing # Failed to converge
132+
return round.(Int64, max.(sol.u, 0.0))
161133
end
162134

163135
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing, dtmin=1e-10, saveat=nothing)
@@ -200,7 +172,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
200172
while t[end] < t_end
201173
u_prev = u[end]
202174
t_prev = t[end]
203-
# Recompute stoichiometry
204175
for j in 1:numjumps
205176
fill!(counts_temp, 0)
206177
counts_temp[j] = 1
@@ -210,11 +181,10 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
210181
rate(rate_cache, u_prev, p, t_prev)
211182
tau_prime = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate)
212183
tau_double_prime = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, rate)
213-
# Cao et al. (2007): Use tau_prime for explicit, tau_double_prime for implicit
214184
use_implicit = false
215-
tau = tau_prime # Default to explicit
216-
if tau_double_prime < tau_prime && any(u_prev .< 10) # Implicit if populations are low
217-
tau = tau_double_prime
185+
tau = tau_prime
186+
if any(u_prev .< 10)
187+
tau = min(tau_double_prime, tau_prime) # Tighter cap for accuracy
218188
use_implicit = true
219189
end
220190
tau = max(tau, dtmin)
@@ -230,11 +200,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
230200
if use_implicit
231201
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
232202
if u_new === nothing || any(u_new .< 0)
233-
tau /= 2 # Halve tau if implicit fails or produces negative populations
203+
tau /= 2
234204
continue
235205
end
236206
elseif any(u_new .< 0)
237-
tau /= 2 # Halve tau if explicit produces negative populations
207+
tau /= 2
238208
continue
239209
end
240210
u_new = max.(u_new, 0)

test/regular_jumps.jl

Lines changed: 21 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -5,76 +5,31 @@ rng = StableRNG(12345)
55

66
Nsims = 10
77

8-
# @testset "SIR Model Correctness" begin
9-
β = 0.1 / 1000.0
10-
ν = 0.01
11-
influx_rate = 1.0
12-
p = (β, ν, influx_rate)
138

14-
u0 = [999, 10, 0] # Integer initial conditions
15-
tspan = (0.0, 250.0)
16-
prob_disc = DiscreteProblem(u0, tspan, p)
9+
β = 0.1 / 1000.0
10+
ν = 0.01
11+
influx_rate = 1.0
12+
p = (β, ν, influx_rate)
1713

18-
regular_rate = (out, u, p, t) -> begin
19-
out[1] = p[1] * u[1] * u[2]
20-
out[2] = p[2] * u[2]
21-
out[3] = p[3]
22-
end
23-
regular_c = (dc, u, p, t, counts, mark) -> begin
24-
dc .= 0
25-
dc[1] = -counts[1] + counts[3]
26-
dc[2] = counts[1] - counts[2]
27-
dc[3] = counts[2]
28-
end
29-
rj = RegularJump(regular_rate, regular_c, 3)
30-
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
14+
function regular_rate(out, u, p, t)
15+
out[1] = p[1] * u[1] * u[2]
16+
out[2] = p[2] * u[2]
17+
out[3] = p[3]
18+
end
3119

32-
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims)
20+
regular_c = (dc, u, p, t, counts, mark) -> begin
21+
dc .= 0
22+
dc[1] = -counts[1] + counts[3]
23+
dc[2] = counts[1] - counts[2]
24+
dc[3] = counts[2]
25+
end
3326

34-
# end
27+
u0 = [999, 5, 0]
28+
tspan = (0.0, 250.0)
29+
prob_disc = DiscreteProblem(u0, tspan, p)
3530

36-
# @testset "SEIR Model Correctness" begin
37-
β = 0.3 / 1000.0
38-
σ = 0.2
39-
ν = 0.01
40-
p = (β, σ, ν)
4131

42-
rate1(u, p, t) = p[1] * u[1] * u[3]
43-
rate2(u, p, t) = p[2] * u[2]
44-
rate3(u, p, t) = p[3] * u[3]
45-
affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing)
46-
affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing)
47-
affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing)
48-
jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!))
32+
rj = RegularJump(regular_rate, regular_c, 3)
33+
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
4934

50-
u0 = [999, 0, 10, 0] # Integer initial conditions
51-
tspan = (0.0, 250.0)
52-
prob_disc = DiscreteProblem(u0, tspan, p)
53-
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=StableRNG(12345))
54-
55-
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims)
56-
57-
regular_rate = (out, u, p, t) -> begin
58-
out[1] = p[1] * u[1] * u[3]
59-
out[2] = p[2] * u[2]
60-
out[3] = p[3] * u[3]
61-
end
62-
regular_c = (dc, u, p, t, counts, mark) -> begin
63-
dc .= 0
64-
dc[1] = -counts[1]
65-
dc[2] = counts[1] - counts[2]
66-
dc[3] = counts[2] - counts[3]
67-
dc[4] = counts[3]
68-
end
69-
rj = RegularJump(regular_rate, regular_c, 3)
70-
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
71-
72-
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)
73-
74-
t_points = 0:1.0:250.0
75-
mean_direct_R = [mean(sol_direct[i](t)[4] for i in 1:Nsims) for t in t_points]
76-
mean_implicit_R = [mean(sol_implicit[i](t)[4] for i in 1:Nsims) for t in t_points]
77-
78-
max_error_implicit = maximum(abs.(mean_direct_R .- mean_implicit_R))
79-
# @test max_error_implicit < 0.01 * mean(mean_direct_R)
80-
# end
35+
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)

0 commit comments

Comments
 (0)