Skip to content

Commit 06735ba

Browse files
refactor
1 parent 6b44daa commit 06735ba

File tree

2 files changed

+52
-55
lines changed

2 files changed

+52
-55
lines changed

src/simple_regular_solve.jl

Lines changed: 51 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,35 @@ end
5656

5757
SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)
5858

59+
function compute_gi(u, nu, hor, i)
60+
max_order = 1.0
61+
for j in 1:size(nu, 2)
62+
if abs(nu[i, j]) > 0
63+
max_order = max(max_order, Float64(hor[j]))
64+
end
65+
end
66+
return max_order
67+
end
68+
69+
function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin)
70+
rate(rate_cache, u, p, t)
71+
mu = zeros(length(u))
72+
sigma2 = zeros(length(u))
73+
tau = Inf
74+
for i in 1:length(u)
75+
for j in 1:size(nu, 2)
76+
mu[i] += nu[i, j] * rate_cache[j]
77+
sigma2[i] += nu[i, j]^2 * rate_cache[j]
78+
end
79+
gi = compute_gi(u, nu, hor, i)
80+
bound = max(epsilon * u[i] / gi, 1.0)
81+
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
82+
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
83+
tau = min(tau, mu_term, sigma_term)
84+
end
85+
return max(tau, dtmin)
86+
end
87+
5988
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
6089
seed = nothing,
6190
dtmin = 1e-10)
@@ -81,17 +110,36 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
81110
t_end = tspan[2]
82111
epsilon = alg.epsilon
83112

84-
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
113+
# Compute initial stoichiometry and HOR
114+
nu = zeros(Int, length(u0), numjumps)
115+
for j in 1:numjumps
116+
counts_temp = zeros(numjumps)
117+
counts_temp[j] = 1
118+
c(du, u0, p, t[1], counts_temp, nothing)
119+
nu[:, j] = du
120+
end
121+
122+
hor = zeros(Int, size(nu, 2))
123+
for j in 1:size(nu, 2)
124+
hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1
125+
end
85126

86127
while t[end] < t_end
87128
u_prev = u[end]
88129
t_prev = t[end]
130+
# Recompute stoichiometry
131+
for j in 1:numjumps
132+
counts_temp = zeros(numjumps)
133+
counts_temp[j] = 1
134+
c(du, u_prev, p, t_prev, counts_temp, nothing)
135+
nu[:, j] = du
136+
end
89137
rate(rate_cache, u_prev, p, t_prev)
90-
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, dtmin)
138+
tau = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin)
91139
tau = min(tau, t_end - t_prev)
92140
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
93141
c(du, u_prev, p, t_prev, counts, nothing)
94-
u_new = u_prev + du
142+
u_new = max.(u_prev + du, 0)
95143
if any(u_new .< 0)
96144
tau /= 2
97145
continue
@@ -106,57 +154,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
106154
return sol
107155
end
108156

109-
# Compute stoichiometry matrix from c function
110-
function compute_stoichiometry(c, u, numjumps, p, t)
111-
nu = zeros(Int, length(u), numjumps)
112-
for j in 1:numjumps
113-
counts = zeros(numjumps)
114-
counts[j] = 1
115-
du = similar(u)
116-
c(du, u, p, t, counts, nothing)
117-
nu[:, j] = round.(Int, du)
118-
end
119-
return nu
120-
end
121-
122-
# Compute g_i (approximation from Cao et al., 2006)
123-
function compute_gi(u, nu, i, rate, rate_cache, p, t)
124-
max_order = 1.0
125-
for j in 1:size(nu, 2)
126-
if abs(nu[i, j]) > 0
127-
rate(rate_cache, u, p, t)
128-
if rate_cache[j] > 0
129-
order = 1.0
130-
if sum(abs.(nu[:, j])) > abs(nu[i, j])
131-
order = 2.0
132-
end
133-
max_order = max(max_order, order)
134-
end
135-
end
136-
end
137-
return max_order
138-
end
139-
140-
# Tau-selection for explicit method (Equation 8)
141-
function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate, dtmin)
142-
rate(rate_cache, u, p, t)
143-
mu = zeros(length(u))
144-
sigma2 = zeros(length(u))
145-
tau = Inf
146-
for i in 1:length(u)
147-
for j in 1:size(nu, 2)
148-
mu[i] += nu[i, j] * rate_cache[j]
149-
sigma2[i] += nu[i, j]^2 * rate_cache[j]
150-
end
151-
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
152-
bound = max(epsilon * u[i] / gi, 1.0)
153-
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
154-
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
155-
tau = min(tau, mu_term, sigma_term)
156-
end
157-
return max(tau, dtmin)
158-
end
159-
160157
struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
161158
backend::Backend
162159
cpu_offload::Float64

test/regular_jumps.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test, LinearAlgebra, Statistics
33
using StableRNGs
44
rng = StableRNG(12345)
55

6-
Nsims = 8000
6+
Nsims = 1000
77

88
# SIR model with influx
99
@testset "SIR Model Correctness" begin

0 commit comments

Comments
 (0)