Skip to content

Commit f7ffa4d

Browse files
refactor
1 parent 2e5d82c commit f7ffa4d

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

src/simple_regular_solve.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,19 @@ function compute_tau_implicit(u, rate_cache, nu, p, t, rate)
111111
for i in 1:length(u)
112112
sum_nu_a = 0.0
113113
for j in 1:size(nu, 2)
114-
sum_nu_a += abs(nu[i, j]) * rate_cache[j]
114+
if nu[i, j] < 0 # Only sum negative stoichiometry
115+
sum_nu_a += abs(nu[i, j]) * rate_cache[j]
116+
end
115117
end
116-
if sum_nu_a > 0
117-
tau = min(tau, 1.0 / sum_nu_a)
118+
if sum_nu_a > 0 && u[i] > 0 # Avoid division by zero
119+
tau = min(tau, u[i] / sum_nu_a)
118120
end
119121
end
120122
return tau
121123
end
122124

123125
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
124-
# Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0
126+
# Nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (k_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0
125127
function f(u_new)
126128
rate_new = zeros(Float64, numjumps)
127129
rate(rate_new, u_new, p, t_prev + tau)
@@ -132,7 +134,7 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate,
132134
return residual
133135
end
134136

135-
# Compute Jacobian using finite differences
137+
# Numerical Jacobian
136138
function compute_jacobian(u_new)
137139
n = length(u_new)
138140
J = zeros(Float64, n, n)
@@ -158,12 +160,12 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate,
158160
end
159161
J = compute_jacobian(u_new)
160162
if abs(det(J)) < 1e-10 # Check for singular Jacobian
161-
return nothing # Signal failure
163+
return nothing
162164
end
163165
delta = J \ F
164166
u_new -= delta
165167
if any(isnan.(u_new)) || any(isinf.(u_new))
166-
return nothing # Signal failure
168+
return nothing
167169
end
168170
end
169171
return nothing # Failed to converge
@@ -219,31 +221,32 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
219221
rate(rate_cache, u_prev, p, t_prev)
220222
tau_prime = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate)
221223
tau_double_prime = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, rate)
222-
tau = min(tau_prime, tau_double_prime / 10.0)
224+
# Cao et al. (2007): Use tau_prime for explicit, tau_double_prime for implicit
225+
use_implicit = false
226+
tau = tau_prime # Default to explicit
227+
if tau_double_prime < tau_prime && any(u_prev .< 10) # Implicit if populations are low
228+
tau = tau_double_prime
229+
use_implicit = true
230+
end
223231
tau = max(tau, dtmin)
224232
tau = min(tau, t_end - t_prev)
225233
if !isempty(saveat_times)
226234
if save_idx <= length(saveat_times) && t_prev + tau > saveat_times[save_idx]
227235
tau = saveat_times[save_idx] - t_prev
228236
end
229237
end
230-
counts .= counts .= pois_rand.((rng,), max.(rate_cache * tau, 0.0))
238+
counts .= pois_rand.((rng,), max.(rate_cache * tau, 0.0))
231239
c(du, u_prev, p, t_prev, counts, nothing)
232240
u_new = u_prev + du
233-
if tau_prime <= tau_double_prime / 10.0
234-
# Explicit update
235-
if any(u_new .< 0)
236-
# Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
237-
tau /= 2
238-
continue
239-
end
240-
else
241+
if use_implicit
241242
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
242243
if u_new === nothing || any(u_new .< 0)
243-
# Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
244-
tau /= 2
244+
tau /= 2 # Halve tau if implicit fails or produces negative populations
245245
continue
246246
end
247+
elseif any(u_new .< 0)
248+
tau /= 2 # Halve tau if explicit produces negative populations
249+
continue
247250
end
248251
u_new = max.(u_new, 0)
249252
push!(u, u_new)

test/regular_jumps.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using JumpProcesses, DiffEqBase
22
using Test, LinearAlgebra, Statistics
3-
using StableRNGs, Plots
3+
using StableRNGs
44
rng = StableRNG(12345)
55

66
Nsims = 10
@@ -30,7 +30,6 @@ Nsims = 10
3030
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
3131

3232
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims)
33-
plot(sol_implicit)
3433

3534
t_points = 0:1.0:250.0
3635
mean_direct_S = [mean(sol_direct[i](t)[1] for i in 1:Nsims) for t in t_points]

0 commit comments

Comments
 (0)