Skip to content

Commit 0f1bcf5

Browse files
refactor
1 parent a76c015 commit 0f1bcf5

File tree

1 file changed

+160
-159
lines changed

1 file changed

+160
-159
lines changed

src/simple_regular_solve.jl

Lines changed: 160 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,169 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
5050
interp = DiffEqBase.ConstantInterpolation(t, u))
5151
end
5252

53-
# Define the SimpleImplicitTauLeaping algorithm
5453
struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
5554
epsilon::Float64 # Error control parameter
5655
nc::Int # Critical reaction threshold
5756
nstiff::Float64 # Stiffness threshold for switching
5857
delta::Float64 # Partial equilibrium threshold
5958
end
6059

61-
# Default constructor
6260
SimpleImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) =
6361
SimpleImplicitTauLeaping(epsilon, nc, nstiff, delta)
6462

65-
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed = nothing)
66-
# Boilerplate setup
63+
# Compute stoichiometry matrix from c function
64+
function compute_stoichiometry(c, u, numjumps, p, t)
65+
nu = zeros(Int, length(u), numjumps)
66+
for j in 1:numjumps
67+
counts = zeros(numjumps)
68+
counts[j] = 1
69+
du = similar(u)
70+
c(du, u, p, t, counts, nothing)
71+
nu[:, j] = round.(Int, du)
72+
end
73+
return nu
74+
end
75+
76+
# Detect reversible reaction pairs
77+
function find_reversible_pairs(nu)
78+
pairs = Vector{Tuple{Int,Int}}()
79+
for j in 1:size(nu, 2)
80+
for k in (j+1):size(nu, 2)
81+
if nu[:, j] == -nu[:, k]
82+
push!(pairs, (j, k))
83+
end
84+
end
85+
end
86+
return pairs
87+
end
88+
89+
# Compute g_i (approximation from Cao et al., 2006)
90+
function compute_gi(u, nu, i, rate, rate_cache, p, t)
91+
max_order = 1.0
92+
for j in 1:size(nu, 2)
93+
if abs(nu[i, j]) > 0
94+
rate(rate_cache, u, p, t)
95+
if rate_cache[j] > 0
96+
order = 1.0
97+
if sum(abs.(nu[:, j])) > abs(nu[i, j])
98+
order = 2.0
99+
end
100+
max_order = max(max_order, order)
101+
end
102+
end
103+
end
104+
return max_order
105+
end
106+
107+
# Tau-selection for explicit method (Equation 8)
108+
function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate)
109+
rate(rate_cache, u, p, t)
110+
mu = zeros(length(u))
111+
sigma2 = zeros(length(u))
112+
tau = Inf
113+
for i in 1:length(u)
114+
for j in 1:size(nu, 2)
115+
mu[i] += nu[i, j] * rate_cache[j]
116+
sigma2[i] += nu[i, j]^2 * rate_cache[j]
117+
end
118+
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
119+
bound = max(epsilon * u[i] / gi, 1.0)
120+
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
121+
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
122+
tau = min(tau, mu_term, sigma_term)
123+
end
124+
return max(tau, 1e-10)
125+
end
126+
127+
# Partial equilibrium check (Equation 13)
128+
function is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
129+
a_plus = rate_cache[j_plus]
130+
a_minus = rate_cache[j_minus]
131+
return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus)
132+
end
133+
134+
# Tau-selection for implicit method (Equation 14)
135+
function compute_tau_implicit(u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
136+
rate(rate_cache, u, p, t)
137+
mu = zeros(length(u))
138+
sigma2 = zeros(length(u))
139+
non_equilibrium = trues(size(nu, 2))
140+
for (j_plus, j_minus) in equilibrium_pairs
141+
if is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
142+
non_equilibrium[j_plus] = false
143+
non_equilibrium[j_minus] = false
144+
end
145+
end
146+
tau = Inf
147+
for i in 1:length(u)
148+
for j in 1:size(nu, 2)
149+
if non_equilibrium[j]
150+
mu[i] += nu[i, j] * rate_cache[j]
151+
sigma2[i] += nu[i, j]^2 * rate_cache[j]
152+
end
153+
end
154+
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
155+
bound = max(epsilon * u[i] / gi, 1.0)
156+
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
157+
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
158+
tau = min(tau, mu_term, sigma_term)
159+
end
160+
return max(tau, 1e-10)
161+
end
162+
163+
# Identify critical reactions
164+
function identify_critical_reactions(u, rate_cache, nu, nc)
165+
critical = falses(size(nu, 2))
166+
for j in 1:size(nu, 2)
167+
if rate_cache[j] > 0
168+
Lj = Inf
169+
for i in 1:length(u)
170+
if nu[i, j] < 0
171+
Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j])))
172+
end
173+
end
174+
if Lj < nc
175+
critical[j] = true
176+
end
177+
end
178+
end
179+
return critical
180+
end
181+
182+
# Implicit tau-leaping step using NonlinearSolve
183+
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
184+
# Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
185+
function f(u_new, params)
186+
rate_new = zeros(eltype(u_new), numjumps)
187+
rate(rate_new, u_new, p, t_prev + tau)
188+
residual = u_new - u_prev
189+
for j in 1:numjumps
190+
residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
191+
end
192+
return residual
193+
end
194+
195+
# Initial guess
196+
u_new = copy(u_prev)
197+
198+
# Solve the nonlinear system
199+
prob = NonlinearProblem(f, u_new, nothing)
200+
sol = solve(prob, NewtonRaphson())
201+
202+
# Check for convergence and numerical stability
203+
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
204+
return round.(Int, max.(u_prev, 0.0)) # Revert to previous state
205+
end
206+
207+
return round.(Int, max.(sol.u, 0.0))
208+
end
209+
210+
# Down-shifting condition (Equation 19)
211+
function use_down_shifting(t, tau_im, tau_ex, a0, t_end)
212+
return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
213+
end
214+
215+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing)
67216
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
68217
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
69218
prob = jump_prob.prob
@@ -92,160 +241,12 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
92241
delta = alg.delta
93242
t_end = tspan[2]
94243

95-
# Compute stoichiometry matrix from c function
96-
function compute_stoichiometry(c, u, numjumps)
97-
nu = zeros(Int, length(u), numjumps)
98-
for j in 1:numjumps
99-
counts = zeros(numjumps)
100-
counts[j] = 1
101-
du = similar(u)
102-
c(du, u, p, t[1], counts, nothing)
103-
nu[:, j] = round.(Int, du)
104-
end
105-
return nu
106-
end
107-
nu = compute_stoichiometry(c, u0, numjumps)
244+
# Compute stoichiometry matrix
245+
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
108246

109247
# Detect reversible reaction pairs
110-
function find_reversible_pairs(nu)
111-
pairs = Vector{Tuple{Int,Int}}()
112-
for j in 1:numjumps
113-
for k in (j+1):numjumps
114-
if nu[:, j] == -nu[:, k]
115-
push!(pairs, (j, k))
116-
end
117-
end
118-
end
119-
return pairs
120-
end
121248
equilibrium_pairs = find_reversible_pairs(nu)
122249

123-
# Helper function to compute g_i (approximation from Cao et al., 2006)
124-
function compute_gi(u, nu, i)
125-
max_order = 1.0
126-
for j in 1:numjumps
127-
if abs(nu[i, j]) > 0
128-
rate(rate_cache, u, p, t[end])
129-
if rate_cache[j] > 0
130-
order = 1.0
131-
if sum(abs.(nu[:, j])) > abs(nu[i, j])
132-
order = 2.0
133-
end
134-
max_order = max(max_order, order)
135-
end
136-
end
137-
end
138-
return max_order
139-
end
140-
141-
# Tau-selection for explicit method (Equation 8)
142-
function compute_tau_explicit(u, rate_cache, nu, p, t)
143-
rate(rate_cache, u, p, t)
144-
mu = zeros(length(u))
145-
sigma2 = zeros(length(u))
146-
tau = Inf
147-
for i in 1:length(u)
148-
for j in 1:numjumps
149-
mu[i] += nu[i, j] * rate_cache[j]
150-
sigma2[i] += nu[i, j]^2 * rate_cache[j]
151-
end
152-
gi = compute_gi(u, nu, i)
153-
bound = max(epsilon * u[i] / gi, 1.0)
154-
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
155-
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
156-
tau = min(tau, mu_term, sigma_term)
157-
end
158-
return max(tau, 1e-10)
159-
end
160-
161-
# Partial equilibrium check (Equation 13)
162-
function is_partial_equilibrium(rate_cache, j_plus, j_minus)
163-
a_plus = rate_cache[j_plus]
164-
a_minus = rate_cache[j_minus]
165-
return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus)
166-
end
167-
168-
# Tau-selection for implicit method (Equation 14)
169-
function compute_tau_implicit(u, rate_cache, nu, p, t)
170-
rate(rate_cache, u, p, t)
171-
mu = zeros(length(u))
172-
sigma2 = zeros(length(u))
173-
non_equilibrium = trues(numjumps)
174-
for (j_plus, j_minus) in equilibrium_pairs
175-
if is_partial_equilibrium(rate_cache, j_plus, j_minus)
176-
non_equilibrium[j_plus] = false
177-
non_equilibrium[j_minus] = false
178-
end
179-
end
180-
tau = Inf
181-
for i in 1:length(u)
182-
for j in 1:numjumps
183-
if non_equilibrium[j]
184-
mu[i] += nu[i, j] * rate_cache[j]
185-
sigma2[i] += nu[i, j]^2 * rate_cache[j]
186-
end
187-
end
188-
gi = compute_gi(u, nu, i)
189-
bound = max(epsilon * u[i] / gi, 1.0)
190-
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
191-
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
192-
tau = min(tau, mu_term, sigma_term)
193-
end
194-
return max(tau, 1e-10)
195-
end
196-
197-
# Identify critical reactions
198-
function identify_critical_reactions(u, rate_cache, nu)
199-
critical = falses(numjumps)
200-
for j in 1:numjumps
201-
if rate_cache[j] > 0
202-
Lj = Inf
203-
for i in 1:length(u)
204-
if nu[i, j] < 0
205-
Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j])))
206-
end
207-
end
208-
if Lj < nc
209-
critical[j] = true
210-
end
211-
end
212-
end
213-
return critical
214-
end
215-
216-
# Implicit tau-leaping step using NonlinearSolve
217-
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p)
218-
# Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
219-
function f(u_new, params)
220-
rate_new = similar(rate_cache, eltype(u_new))
221-
rate(rate_new, u_new, p, t_prev + tau)
222-
residual = u_new - u_prev
223-
for j in 1:numjumps
224-
residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
225-
end
226-
return residual
227-
end
228-
229-
# Initial guess
230-
u_new = copy(u_prev)
231-
232-
# Solve the nonlinear system
233-
prob = NonlinearProblem(f, u_new, nothing)
234-
sol = solve(prob, NewtonRaphson())
235-
236-
# Check for convergence and numerical stability
237-
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
238-
return round.(Int, max.(u_prev, 0.0)) # Revert to previous state
239-
end
240-
241-
return round.(Int, max.(sol.u, 0.0))
242-
end
243-
244-
# Down-shifting condition (Equation 19)
245-
function use_down_shifting(t, tau_im, tau_ex, a0, t_end)
246-
return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
247-
end
248-
249250
# Main simulation loop
250251
while t[end] < t_end
251252
u_prev = u[end]
@@ -255,11 +256,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
255256
rate(rate_cache, u_prev, p, t_prev)
256257

257258
# Identify critical reactions
258-
critical = identify_critical_reactions(u_prev, rate_cache, nu)
259+
critical = identify_critical_reactions(u_prev, rate_cache, nu, nc)
259260

260261
# Compute tau values
261-
tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev)
262-
tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev)
262+
tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
263+
tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta)
263264

264265
# Compute critical propensity sum
265266
ac0 = sum(rate_cache[critical])
@@ -314,7 +315,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
314315
end
315316
end
316317
if method == :implicit
317-
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p)
318+
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
318319
else
319320
c(du, u_prev, p, t_prev, counts, nothing)
320321
u_new = u_prev + du
@@ -341,7 +342,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
341342
end
342343
end
343344
if method == :implicit && tau > tau_ex
344-
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p)
345+
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
345346
else
346347
c(du, u_prev, p, t_prev, counts, nothing)
347348
u_new = u_prev + du

0 commit comments

Comments
 (0)