Skip to content

Commit 683a2c0

Browse files
refactor
1 parent e5455b6 commit 683a2c0

File tree

2 files changed

+52
-55
lines changed

2 files changed

+52
-55
lines changed

src/simple_regular_solve.jl

Lines changed: 51 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,35 @@ end
6868

6969
SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)
7070

71+
function compute_gi(u, nu, hor, i)
72+
max_order = 1.0
73+
for j in 1:size(nu, 2)
74+
if abs(nu[i, j]) > 0
75+
max_order = max(max_order, Float64(hor[j]))
76+
end
77+
end
78+
return max_order
79+
end
80+
81+
function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin)
82+
rate(rate_cache, u, p, t)
83+
mu = zeros(length(u))
84+
sigma2 = zeros(length(u))
85+
tau = Inf
86+
for i in 1:length(u)
87+
for j in 1:size(nu, 2)
88+
mu[i] += nu[i, j] * rate_cache[j]
89+
sigma2[i] += nu[i, j]^2 * rate_cache[j]
90+
end
91+
gi = compute_gi(u, nu, hor, i)
92+
bound = max(epsilon * u[i] / gi, 1.0)
93+
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
94+
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
95+
tau = min(tau, mu_term, sigma_term)
96+
end
97+
return max(tau, dtmin)
98+
end
99+
71100
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
72101
seed = nothing,
73102
dtmin = 1e-10)
@@ -93,17 +122,36 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
93122
t_end = tspan[2]
94123
epsilon = alg.epsilon
95124

96-
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
125+
# Compute initial stoichiometry and HOR
126+
nu = zeros(Int, length(u0), numjumps)
127+
for j in 1:numjumps
128+
counts_temp = zeros(numjumps)
129+
counts_temp[j] = 1
130+
c(du, u0, p, t[1], counts_temp, nothing)
131+
nu[:, j] = du
132+
end
133+
134+
hor = zeros(Int, size(nu, 2))
135+
for j in 1:size(nu, 2)
136+
hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1
137+
end
97138

98139
while t[end] < t_end
99140
u_prev = u[end]
100141
t_prev = t[end]
142+
# Recompute stoichiometry
143+
for j in 1:numjumps
144+
counts_temp = zeros(numjumps)
145+
counts_temp[j] = 1
146+
c(du, u_prev, p, t_prev, counts_temp, nothing)
147+
nu[:, j] = du
148+
end
101149
rate(rate_cache, u_prev, p, t_prev)
102-
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, dtmin)
150+
tau = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin)
103151
tau = min(tau, t_end - t_prev)
104152
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
105153
c(du, u_prev, p, t_prev, counts, nothing)
106-
u_new = u_prev + du
154+
u_new = max.(u_prev + du, 0)
107155
if any(u_new .< 0)
108156
tau /= 2
109157
continue
@@ -118,57 +166,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
118166
return sol
119167
end
120168

121-
# Compute stoichiometry matrix from c function
122-
function compute_stoichiometry(c, u, numjumps, p, t)
123-
nu = zeros(Int, length(u), numjumps)
124-
for j in 1:numjumps
125-
counts = zeros(numjumps)
126-
counts[j] = 1
127-
du = similar(u)
128-
c(du, u, p, t, counts, nothing)
129-
nu[:, j] = round.(Int, du)
130-
end
131-
return nu
132-
end
133-
134-
# Compute g_i (approximation from Cao et al., 2006)
135-
function compute_gi(u, nu, i, rate, rate_cache, p, t)
136-
max_order = 1.0
137-
for j in 1:size(nu, 2)
138-
if abs(nu[i, j]) > 0
139-
rate(rate_cache, u, p, t)
140-
if rate_cache[j] > 0
141-
order = 1.0
142-
if sum(abs.(nu[:, j])) > abs(nu[i, j])
143-
order = 2.0
144-
end
145-
max_order = max(max_order, order)
146-
end
147-
end
148-
end
149-
return max_order
150-
end
151-
152-
# Tau-selection for explicit method (Equation 8)
153-
function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate, dtmin)
154-
rate(rate_cache, u, p, t)
155-
mu = zeros(length(u))
156-
sigma2 = zeros(length(u))
157-
tau = Inf
158-
for i in 1:length(u)
159-
for j in 1:size(nu, 2)
160-
mu[i] += nu[i, j] * rate_cache[j]
161-
sigma2[i] += nu[i, j]^2 * rate_cache[j]
162-
end
163-
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
164-
bound = max(epsilon * u[i] / gi, 1.0)
165-
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
166-
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
167-
tau = min(tau, mu_term, sigma_term)
168-
end
169-
return max(tau, dtmin)
170-
end
171-
172169
struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
173170
backend::Backend
174171
cpu_offload::Float64

test/regular_jumps.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test, LinearAlgebra, Statistics
33
using StableRNGs
44
rng = StableRNG(12345)
55

6-
Nsims = 8000
6+
Nsims = 1000
77

88
# SIR model with influx
99
@testset "SIR Model Correctness" begin

0 commit comments

Comments
 (0)