Skip to content

Commit 0eb637a

Browse files
refactor
1 parent 9994a08 commit 0eb637a

File tree

1 file changed

+160
-159
lines changed

1 file changed

+160
-159
lines changed

src/simple_regular_solve.jl

Lines changed: 160 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,169 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
6161
interp = DiffEqBase.ConstantInterpolation(t, u))
6262
end
6363

64-
# Define the SimpleImplicitTauLeaping algorithm
6564
struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
6665
epsilon::Float64 # Error control parameter
6766
nc::Int # Critical reaction threshold
6867
nstiff::Float64 # Stiffness threshold for switching
6968
delta::Float64 # Partial equilibrium threshold
7069
end
7170

72-
# Default constructor
7371
SimpleImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) =
7472
SimpleImplicitTauLeaping(epsilon, nc, nstiff, delta)
7573

76-
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed = nothing)
77-
# Boilerplate setup
74+
# Compute stoichiometry matrix from c function
75+
function compute_stoichiometry(c, u, numjumps, p, t)
76+
nu = zeros(Int, length(u), numjumps)
77+
for j in 1:numjumps
78+
counts = zeros(numjumps)
79+
counts[j] = 1
80+
du = similar(u)
81+
c(du, u, p, t, counts, nothing)
82+
nu[:, j] = round.(Int, du)
83+
end
84+
return nu
85+
end
86+
87+
# Detect reversible reaction pairs
88+
function find_reversible_pairs(nu)
89+
pairs = Vector{Tuple{Int,Int}}()
90+
for j in 1:size(nu, 2)
91+
for k in (j+1):size(nu, 2)
92+
if nu[:, j] == -nu[:, k]
93+
push!(pairs, (j, k))
94+
end
95+
end
96+
end
97+
return pairs
98+
end
99+
100+
# Compute g_i (approximation from Cao et al., 2006)
101+
function compute_gi(u, nu, i, rate, rate_cache, p, t)
102+
max_order = 1.0
103+
for j in 1:size(nu, 2)
104+
if abs(nu[i, j]) > 0
105+
rate(rate_cache, u, p, t)
106+
if rate_cache[j] > 0
107+
order = 1.0
108+
if sum(abs.(nu[:, j])) > abs(nu[i, j])
109+
order = 2.0
110+
end
111+
max_order = max(max_order, order)
112+
end
113+
end
114+
end
115+
return max_order
116+
end
117+
118+
# Tau-selection for explicit method (Equation 8)
119+
function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate)
120+
rate(rate_cache, u, p, t)
121+
mu = zeros(length(u))
122+
sigma2 = zeros(length(u))
123+
tau = Inf
124+
for i in 1:length(u)
125+
for j in 1:size(nu, 2)
126+
mu[i] += nu[i, j] * rate_cache[j]
127+
sigma2[i] += nu[i, j]^2 * rate_cache[j]
128+
end
129+
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
130+
bound = max(epsilon * u[i] / gi, 1.0)
131+
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
132+
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
133+
tau = min(tau, mu_term, sigma_term)
134+
end
135+
return max(tau, 1e-10)
136+
end
137+
138+
# Partial equilibrium check (Equation 13)
139+
function is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
140+
a_plus = rate_cache[j_plus]
141+
a_minus = rate_cache[j_minus]
142+
return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus)
143+
end
144+
145+
# Tau-selection for implicit method (Equation 14)
146+
function compute_tau_implicit(u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
147+
rate(rate_cache, u, p, t)
148+
mu = zeros(length(u))
149+
sigma2 = zeros(length(u))
150+
non_equilibrium = trues(size(nu, 2))
151+
for (j_plus, j_minus) in equilibrium_pairs
152+
if is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
153+
non_equilibrium[j_plus] = false
154+
non_equilibrium[j_minus] = false
155+
end
156+
end
157+
tau = Inf
158+
for i in 1:length(u)
159+
for j in 1:size(nu, 2)
160+
if non_equilibrium[j]
161+
mu[i] += nu[i, j] * rate_cache[j]
162+
sigma2[i] += nu[i, j]^2 * rate_cache[j]
163+
end
164+
end
165+
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
166+
bound = max(epsilon * u[i] / gi, 1.0)
167+
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
168+
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
169+
tau = min(tau, mu_term, sigma_term)
170+
end
171+
return max(tau, 1e-10)
172+
end
173+
174+
# Identify critical reactions
175+
function identify_critical_reactions(u, rate_cache, nu, nc)
176+
critical = falses(size(nu, 2))
177+
for j in 1:size(nu, 2)
178+
if rate_cache[j] > 0
179+
Lj = Inf
180+
for i in 1:length(u)
181+
if nu[i, j] < 0
182+
Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j])))
183+
end
184+
end
185+
if Lj < nc
186+
critical[j] = true
187+
end
188+
end
189+
end
190+
return critical
191+
end
192+
193+
# Implicit tau-leaping step using NonlinearSolve
194+
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
195+
# Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
196+
function f(u_new, params)
197+
rate_new = zeros(eltype(u_new), numjumps)
198+
rate(rate_new, u_new, p, t_prev + tau)
199+
residual = u_new - u_prev
200+
for j in 1:numjumps
201+
residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
202+
end
203+
return residual
204+
end
205+
206+
# Initial guess
207+
u_new = copy(u_prev)
208+
209+
# Solve the nonlinear system
210+
prob = NonlinearProblem(f, u_new, nothing)
211+
sol = solve(prob, NewtonRaphson())
212+
213+
# Check for convergence and numerical stability
214+
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
215+
return round.(Int, max.(u_prev, 0.0)) # Revert to previous state
216+
end
217+
218+
return round.(Int, max.(sol.u, 0.0))
219+
end
220+
221+
# Down-shifting condition (Equation 19)
222+
function use_down_shifting(t, tau_im, tau_ex, a0, t_end)
223+
return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
224+
end
225+
226+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing)
78227
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
79228
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
80229
prob = jump_prob.prob
@@ -103,160 +252,12 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
103252
delta = alg.delta
104253
t_end = tspan[2]
105254

106-
# Compute stoichiometry matrix from c function
107-
function compute_stoichiometry(c, u, numjumps)
108-
nu = zeros(Int, length(u), numjumps)
109-
for j in 1:numjumps
110-
counts = zeros(numjumps)
111-
counts[j] = 1
112-
du = similar(u)
113-
c(du, u, p, t[1], counts, nothing)
114-
nu[:, j] = round.(Int, du)
115-
end
116-
return nu
117-
end
118-
nu = compute_stoichiometry(c, u0, numjumps)
255+
# Compute stoichiometry matrix
256+
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
119257

120258
# Detect reversible reaction pairs
121-
function find_reversible_pairs(nu)
122-
pairs = Vector{Tuple{Int,Int}}()
123-
for j in 1:numjumps
124-
for k in (j+1):numjumps
125-
if nu[:, j] == -nu[:, k]
126-
push!(pairs, (j, k))
127-
end
128-
end
129-
end
130-
return pairs
131-
end
132259
equilibrium_pairs = find_reversible_pairs(nu)
133260

134-
# Helper function to compute g_i (approximation from Cao et al., 2006)
135-
function compute_gi(u, nu, i)
136-
max_order = 1.0
137-
for j in 1:numjumps
138-
if abs(nu[i, j]) > 0
139-
rate(rate_cache, u, p, t[end])
140-
if rate_cache[j] > 0
141-
order = 1.0
142-
if sum(abs.(nu[:, j])) > abs(nu[i, j])
143-
order = 2.0
144-
end
145-
max_order = max(max_order, order)
146-
end
147-
end
148-
end
149-
return max_order
150-
end
151-
152-
# Tau-selection for explicit method (Equation 8)
153-
function compute_tau_explicit(u, rate_cache, nu, p, t)
154-
rate(rate_cache, u, p, t)
155-
mu = zeros(length(u))
156-
sigma2 = zeros(length(u))
157-
tau = Inf
158-
for i in 1:length(u)
159-
for j in 1:numjumps
160-
mu[i] += nu[i, j] * rate_cache[j]
161-
sigma2[i] += nu[i, j]^2 * rate_cache[j]
162-
end
163-
gi = compute_gi(u, nu, i)
164-
bound = max(epsilon * u[i] / gi, 1.0)
165-
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
166-
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
167-
tau = min(tau, mu_term, sigma_term)
168-
end
169-
return max(tau, 1e-10)
170-
end
171-
172-
# Partial equilibrium check (Equation 13)
173-
function is_partial_equilibrium(rate_cache, j_plus, j_minus)
174-
a_plus = rate_cache[j_plus]
175-
a_minus = rate_cache[j_minus]
176-
return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus)
177-
end
178-
179-
# Tau-selection for implicit method (Equation 14)
180-
function compute_tau_implicit(u, rate_cache, nu, p, t)
181-
rate(rate_cache, u, p, t)
182-
mu = zeros(length(u))
183-
sigma2 = zeros(length(u))
184-
non_equilibrium = trues(numjumps)
185-
for (j_plus, j_minus) in equilibrium_pairs
186-
if is_partial_equilibrium(rate_cache, j_plus, j_minus)
187-
non_equilibrium[j_plus] = false
188-
non_equilibrium[j_minus] = false
189-
end
190-
end
191-
tau = Inf
192-
for i in 1:length(u)
193-
for j in 1:numjumps
194-
if non_equilibrium[j]
195-
mu[i] += nu[i, j] * rate_cache[j]
196-
sigma2[i] += nu[i, j]^2 * rate_cache[j]
197-
end
198-
end
199-
gi = compute_gi(u, nu, i)
200-
bound = max(epsilon * u[i] / gi, 1.0)
201-
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
202-
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
203-
tau = min(tau, mu_term, sigma_term)
204-
end
205-
return max(tau, 1e-10)
206-
end
207-
208-
# Identify critical reactions
209-
function identify_critical_reactions(u, rate_cache, nu)
210-
critical = falses(numjumps)
211-
for j in 1:numjumps
212-
if rate_cache[j] > 0
213-
Lj = Inf
214-
for i in 1:length(u)
215-
if nu[i, j] < 0
216-
Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j])))
217-
end
218-
end
219-
if Lj < nc
220-
critical[j] = true
221-
end
222-
end
223-
end
224-
return critical
225-
end
226-
227-
# Implicit tau-leaping step using NonlinearSolve
228-
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p)
229-
# Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
230-
function f(u_new, params)
231-
rate_new = similar(rate_cache, eltype(u_new))
232-
rate(rate_new, u_new, p, t_prev + tau)
233-
residual = u_new - u_prev
234-
for j in 1:numjumps
235-
residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
236-
end
237-
return residual
238-
end
239-
240-
# Initial guess
241-
u_new = copy(u_prev)
242-
243-
# Solve the nonlinear system
244-
prob = NonlinearProblem(f, u_new, nothing)
245-
sol = solve(prob, NewtonRaphson())
246-
247-
# Check for convergence and numerical stability
248-
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
249-
return round.(Int, max.(u_prev, 0.0)) # Revert to previous state
250-
end
251-
252-
return round.(Int, max.(sol.u, 0.0))
253-
end
254-
255-
# Down-shifting condition (Equation 19)
256-
function use_down_shifting(t, tau_im, tau_ex, a0, t_end)
257-
return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
258-
end
259-
260261
# Main simulation loop
261262
while t[end] < t_end
262263
u_prev = u[end]
@@ -266,11 +267,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
266267
rate(rate_cache, u_prev, p, t_prev)
267268

268269
# Identify critical reactions
269-
critical = identify_critical_reactions(u_prev, rate_cache, nu)
270+
critical = identify_critical_reactions(u_prev, rate_cache, nu, nc)
270271

271272
# Compute tau values
272-
tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev)
273-
tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev)
273+
tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
274+
tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta)
274275

275276
# Compute critical propensity sum
276277
ac0 = sum(rate_cache[critical])
@@ -325,7 +326,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
325326
end
326327
end
327328
if method == :implicit
328-
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p)
329+
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
329330
else
330331
c(du, u_prev, p, t_prev, counts, nothing)
331332
u_new = u_prev + du
@@ -352,7 +353,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
352353
end
353354
end
354355
if method == :implicit && tau > tau_ex
355-
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p)
356+
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
356357
else
357358
c(du, u_prev, p, t_prev, counts, nothing)
358359
u_new = u_prev + du

0 commit comments

Comments
 (0)