Skip to content

Commit bbe9dc5

Browse files
basic version of inplicit tau leap is done
1 parent 439bb7d commit bbe9dc5

File tree

3 files changed

+41
-114
lines changed

3 files changed

+41
-114
lines changed

src/JumpProcesses.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ using Base.Threads: Threads, @threads
2020
using Base.FastMath: add_fast
2121
using Setfield: @set, @set!
2222

23+
using SimpleNonlinearSolve
24+
2325
# Import functions we extend from Base
2426
import Base: size, getindex, setindex!, length, similar, show, merge!, merge
2527

src/simple_regular_solve.jl

Lines changed: 18 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ end
6969
SimpleImplicitTauLeaping(; epsilon=0.05) = SimpleImplicitTauLeaping(epsilon)
7070

7171
function compute_hor(nu)
72-
hor = zeros(Int, size(nu, 2))
72+
hor = zeros(Int64, size(nu, 2))
7373
for j in 1:size(nu, 2)
7474
hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1
7575
end
@@ -88,8 +88,8 @@ end
8888

8989
function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate)
9090
rate(rate_cache, u, p, t)
91-
mu = zeros(length(u))
92-
sigma2 = zeros(length(u))
91+
mu = zeros(Float64, length(u))
92+
sigma2 = zeros(Float64, length(u))
9393
tau = Inf
9494
for i in 1:length(u)
9595
for j in 1:size(nu, 2)
@@ -111,21 +111,20 @@ function compute_tau_implicit(u, rate_cache, nu, p, t, rate)
111111
for i in 1:length(u)
112112
sum_nu_a = 0.0
113113
for j in 1:size(nu, 2)
114-
if nu[i, j] < 0 # Only sum negative stoichiometry
114+
if nu[i, j] < 0
115115
sum_nu_a += abs(nu[i, j]) * rate_cache[j]
116116
end
117117
end
118-
if sum_nu_a > 0 && u[i] > 0 # Avoid division by zero
118+
if sum_nu_a > 0 && u[i] > 0
119119
tau = min(tau, u[i] / sum_nu_a)
120120
end
121121
end
122122
return tau
123123
end
124124

125125
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
126-
# Nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (k_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0
127-
function f(u_new)
128-
rate_new = zeros(Float64, numjumps)
126+
function f(u_new, p)
127+
rate_new = zeros(eltype(u_new), numjumps)
129128
rate(rate_new, u_new, p, t_prev + tau)
130129
residual = u_new - u_prev
131130
for j in 1:numjumps
@@ -134,41 +133,14 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate,
134133
return residual
135134
end
136135

137-
# Numerical Jacobian
138-
function compute_jacobian(u_new)
139-
n = length(u_new)
140-
J = zeros(Float64, n, n)
141-
h = 1e-6
142-
f_u = f(u_new)
143-
for j in 1:n
144-
u_pert = copy(u_new)
145-
u_pert[j] += h
146-
f_pert = f(u_pert)
147-
J[:, j] = (f_pert - f_u) / h
148-
end
149-
return J
150-
end
136+
u_new = float.(u_prev + sum(nu[:, j] * counts[j] for j in 1:numjumps))
137+
prob = NonlinearProblem{false}(f, u_new, p)
138+
sol = solve(prob, SimpleNewtonRaphson(), abstol=1e-6, maxiters=100)
151139

152-
# Inline Newton-Raphson
153-
u_new = float.(u_prev + sum(nu[:, j] * counts[j] for j in 1:numjumps)) # Initial guess: explicit step
154-
tol = 1e-6
155-
maxiters = 100
156-
for iter in 1:maxiters
157-
F = f(u_new)
158-
if norm(F) < tol
159-
return round.(Int, max.(u_new, 0.0)) # Converged
160-
end
161-
J = compute_jacobian(u_new)
162-
if abs(det(J)) < 1e-10 # Check for singular Jacobian
163-
return nothing
164-
end
165-
delta = J \ F
166-
u_new -= delta
167-
if any(isnan.(u_new)) || any(isinf.(u_new))
168-
return nothing
169-
end
140+
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
141+
return nothing
170142
end
171-
return nothing # Failed to converge
143+
return round.(Int64, max.(sol.u, 0.0))
172144
end
173145

174146
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing, dtmin=1e-10, saveat=nothing)
@@ -211,7 +183,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
211183
while t[end] < t_end
212184
u_prev = u[end]
213185
t_prev = t[end]
214-
# Recompute stoichiometry
215186
for j in 1:numjumps
216187
fill!(counts_temp, 0)
217188
counts_temp[j] = 1
@@ -221,11 +192,10 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
221192
rate(rate_cache, u_prev, p, t_prev)
222193
tau_prime = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate)
223194
tau_double_prime = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, rate)
224-
# Cao et al. (2007): Use tau_prime for explicit, tau_double_prime for implicit
225195
use_implicit = false
226-
tau = tau_prime # Default to explicit
227-
if tau_double_prime < tau_prime && any(u_prev .< 10) # Implicit if populations are low
228-
tau = tau_double_prime
196+
tau = tau_prime
197+
if any(u_prev .< 10)
198+
tau = min(tau_double_prime, tau_prime) # Tighter cap for accuracy
229199
use_implicit = true
230200
end
231201
tau = max(tau, dtmin)
@@ -241,11 +211,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
241211
if use_implicit
242212
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
243213
if u_new === nothing || any(u_new .< 0)
244-
tau /= 2 # Halve tau if implicit fails or produces negative populations
214+
tau /= 2
245215
continue
246216
end
247217
elseif any(u_new .< 0)
248-
tau /= 2 # Halve tau if explicit produces negative populations
218+
tau /= 2
249219
continue
250220
end
251221
u_new = max.(u_new, 0)

test/regular_jumps.jl

Lines changed: 21 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -5,76 +5,31 @@ rng = StableRNG(12345)
55

66
Nsims = 10
77

8-
# @testset "SIR Model Correctness" begin
9-
β = 0.1 / 1000.0
10-
ν = 0.01
11-
influx_rate = 1.0
12-
p = (β, ν, influx_rate)
138

14-
u0 = [999, 10, 0] # Integer initial conditions
15-
tspan = (0.0, 250.0)
16-
prob_disc = DiscreteProblem(u0, tspan, p)
9+
β = 0.1 / 1000.0
10+
ν = 0.01
11+
influx_rate = 1.0
12+
p = (β, ν, influx_rate)
1713

18-
regular_rate = (out, u, p, t) -> begin
19-
out[1] = p[1] * u[1] * u[2]
20-
out[2] = p[2] * u[2]
21-
out[3] = p[3]
22-
end
23-
regular_c = (dc, u, p, t, counts, mark) -> begin
24-
dc .= 0
25-
dc[1] = -counts[1] + counts[3]
26-
dc[2] = counts[1] - counts[2]
27-
dc[3] = counts[2]
28-
end
29-
rj = RegularJump(regular_rate, regular_c, 3)
30-
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
14+
function regular_rate(out, u, p, t)
15+
out[1] = p[1] * u[1] * u[2]
16+
out[2] = p[2] * u[2]
17+
out[3] = p[3]
18+
end
3119

32-
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims)
20+
regular_c = (dc, u, p, t, counts, mark) -> begin
21+
dc .= 0
22+
dc[1] = -counts[1] + counts[3]
23+
dc[2] = counts[1] - counts[2]
24+
dc[3] = counts[2]
25+
end
3326

34-
# end
27+
u0 = [999, 5, 0]
28+
tspan = (0.0, 250.0)
29+
prob_disc = DiscreteProblem(u0, tspan, p)
3530

36-
# @testset "SEIR Model Correctness" begin
37-
β = 0.3 / 1000.0
38-
σ = 0.2
39-
ν = 0.01
40-
p = (β, σ, ν)
4131

42-
rate1(u, p, t) = p[1] * u[1] * u[3]
43-
rate2(u, p, t) = p[2] * u[2]
44-
rate3(u, p, t) = p[3] * u[3]
45-
affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing)
46-
affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing)
47-
affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing)
48-
jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!))
32+
rj = RegularJump(regular_rate, regular_c, 3)
33+
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
4934

50-
u0 = [999, 0, 10, 0] # Integer initial conditions
51-
tspan = (0.0, 250.0)
52-
prob_disc = DiscreteProblem(u0, tspan, p)
53-
jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=StableRNG(12345))
54-
55-
sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims)
56-
57-
regular_rate = (out, u, p, t) -> begin
58-
out[1] = p[1] * u[1] * u[3]
59-
out[2] = p[2] * u[2]
60-
out[3] = p[3] * u[3]
61-
end
62-
regular_c = (dc, u, p, t, counts, mark) -> begin
63-
dc .= 0
64-
dc[1] = -counts[1]
65-
dc[2] = counts[1] - counts[2]
66-
dc[3] = counts[2] - counts[3]
67-
dc[4] = counts[3]
68-
end
69-
rj = RegularJump(regular_rate, regular_c, 3)
70-
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
71-
72-
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)
73-
74-
t_points = 0:1.0:250.0
75-
mean_direct_R = [mean(sol_direct[i](t)[4] for i in 1:Nsims) for t in t_points]
76-
mean_implicit_R = [mean(sol_implicit[i](t)[4] for i in 1:Nsims) for t in t_points]
77-
78-
max_error_implicit = maximum(abs.(mean_direct_R .- mean_implicit_R))
79-
# @test max_error_implicit < 0.01 * mean(mean_direct_R)
80-
# end
35+
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0)

0 commit comments

Comments
 (0)