Skip to content

Commit 48d699e

Browse files
refactor
1 parent d302a47 commit 48d699e

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

src/simple_regular_solve.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,19 @@ function compute_tau_implicit(u, rate_cache, nu, p, t, rate)
100100
for i in 1:length(u)
101101
sum_nu_a = 0.0
102102
for j in 1:size(nu, 2)
103-
sum_nu_a += abs(nu[i, j]) * rate_cache[j]
103+
if nu[i, j] < 0 # Only sum negative stoichiometry
104+
sum_nu_a += abs(nu[i, j]) * rate_cache[j]
105+
end
104106
end
105-
if sum_nu_a > 0
106-
tau = min(tau, 1.0 / sum_nu_a)
107+
if sum_nu_a > 0 && u[i] > 0 # Avoid division by zero
108+
tau = min(tau, u[i] / sum_nu_a)
107109
end
108110
end
109111
return tau
110112
end
111113

112114
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
113-
# Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0
115+
# Nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (k_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0
114116
function f(u_new)
115117
rate_new = zeros(Float64, numjumps)
116118
rate(rate_new, u_new, p, t_prev + tau)
@@ -121,7 +123,7 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate,
121123
return residual
122124
end
123125

124-
# Compute Jacobian using finite differences
126+
# Numerical Jacobian
125127
function compute_jacobian(u_new)
126128
n = length(u_new)
127129
J = zeros(Float64, n, n)
@@ -147,12 +149,12 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate,
147149
end
148150
J = compute_jacobian(u_new)
149151
if abs(det(J)) < 1e-10 # Check for singular Jacobian
150-
return nothing # Signal failure
152+
return nothing
151153
end
152154
delta = J \ F
153155
u_new -= delta
154156
if any(isnan.(u_new)) || any(isinf.(u_new))
155-
return nothing # Signal failure
157+
return nothing
156158
end
157159
end
158160
return nothing # Failed to converge
@@ -208,31 +210,32 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
208210
rate(rate_cache, u_prev, p, t_prev)
209211
tau_prime = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate)
210212
tau_double_prime = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, rate)
211-
tau = min(tau_prime, tau_double_prime / 10.0)
213+
# Cao et al. (2007): Use tau_prime for explicit, tau_double_prime for implicit
214+
use_implicit = false
215+
tau = tau_prime # Default to explicit
216+
if tau_double_prime < tau_prime && any(u_prev .< 10) # Implicit if populations are low
217+
tau = tau_double_prime
218+
use_implicit = true
219+
end
212220
tau = max(tau, dtmin)
213221
tau = min(tau, t_end - t_prev)
214222
if !isempty(saveat_times)
215223
if save_idx <= length(saveat_times) && t_prev + tau > saveat_times[save_idx]
216224
tau = saveat_times[save_idx] - t_prev
217225
end
218226
end
219-
counts .= counts .= pois_rand.((rng,), max.(rate_cache * tau, 0.0))
227+
counts .= pois_rand.((rng,), max.(rate_cache * tau, 0.0))
220228
c(du, u_prev, p, t_prev, counts, nothing)
221229
u_new = u_prev + du
222-
if tau_prime <= tau_double_prime / 10.0
223-
# Explicit update
224-
if any(u_new .< 0)
225-
# Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
226-
tau /= 2
227-
continue
228-
end
229-
else
230+
if use_implicit
230231
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
231232
if u_new === nothing || any(u_new .< 0)
232-
# Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
233-
tau /= 2
233+
tau /= 2 # Halve tau if implicit fails or produces negative populations
234234
continue
235235
end
236+
elseif any(u_new .< 0)
237+
tau /= 2 # Halve tau if explicit produces negative populations
238+
continue
236239
end
237240
u_new = max.(u_new, 0)
238241
push!(u, u_new)

test/regular_jumps.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using JumpProcesses, DiffEqBase
22
using Test, LinearAlgebra, Statistics
3-
using StableRNGs, Plots
3+
using StableRNGs
44
rng = StableRNG(12345)
55

66
Nsims = 10
@@ -30,7 +30,6 @@ Nsims = 10
3030
jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
3131

3232
sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims)
33-
plot(sol_implicit)
3433

3534
t_points = 0:1.0:250.0
3635
mean_direct_S = [mean(sol_direct[i](t)[1] for i in 1:Nsims) for t in t_points]

0 commit comments

Comments
 (0)