@@ -61,20 +61,169 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
6161 interp = DiffEqBase. ConstantInterpolation (t, u))
6262end
6363
64- # Define the SimpleImplicitTauLeaping algorithm
6564struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
6665 epsilon:: Float64 # Error control parameter
6766 nc:: Int # Critical reaction threshold
6867 nstiff:: Float64 # Stiffness threshold for switching
6968 delta:: Float64 # Partial equilibrium threshold
7069end
7170
72- # Default constructor
7371SimpleImplicitTauLeaping (; epsilon= 0.05 , nc= 10 , nstiff= 100.0 , delta= 0.05 ) =
7472 SimpleImplicitTauLeaping (epsilon, nc, nstiff, delta)
7573
76- function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleImplicitTauLeaping ; seed = nothing )
77- # Boilerplate setup
74+ # Compute stoichiometry matrix from c function
75+ function compute_stoichiometry (c, u, numjumps, p, t)
76+ nu = zeros (Int, length (u), numjumps)
77+ for j in 1 : numjumps
78+ counts = zeros (numjumps)
79+ counts[j] = 1
80+ du = similar (u)
81+ c (du, u, p, t, counts, nothing )
82+ nu[:, j] = round .(Int, du)
83+ end
84+ return nu
85+ end
86+
87+ # Detect reversible reaction pairs
88+ function find_reversible_pairs (nu)
89+ pairs = Vector {Tuple{Int,Int}} ()
90+ for j in 1 : size (nu, 2 )
91+ for k in (j+ 1 ): size (nu, 2 )
92+ if nu[:, j] == - nu[:, k]
93+ push! (pairs, (j, k))
94+ end
95+ end
96+ end
97+ return pairs
98+ end
99+
100+ # Compute g_i (approximation from Cao et al., 2006)
101+ function compute_gi (u, nu, i, rate, rate_cache, p, t)
102+ max_order = 1.0
103+ for j in 1 : size (nu, 2 )
104+ if abs (nu[i, j]) > 0
105+ rate (rate_cache, u, p, t)
106+ if rate_cache[j] > 0
107+ order = 1.0
108+ if sum (abs .(nu[:, j])) > abs (nu[i, j])
109+ order = 2.0
110+ end
111+ max_order = max (max_order, order)
112+ end
113+ end
114+ end
115+ return max_order
116+ end
117+
118+ # Tau-selection for explicit method (Equation 8)
119+ function compute_tau_explicit (u, rate_cache, nu, p, t, epsilon, rate)
120+ rate (rate_cache, u, p, t)
121+ mu = zeros (length (u))
122+ sigma2 = zeros (length (u))
123+ tau = Inf
124+ for i in 1 : length (u)
125+ for j in 1 : size (nu, 2 )
126+ mu[i] += nu[i, j] * rate_cache[j]
127+ sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
128+ end
129+ gi = compute_gi (u, nu, i, rate, rate_cache, p, t)
130+ bound = max (epsilon * u[i] / gi, 1.0 )
131+ mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
132+ sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
133+ tau = min (tau, mu_term, sigma_term)
134+ end
135+ return max (tau, 1e-10 )
136+ end
137+
138+ # Partial equilibrium check (Equation 13)
139+ function is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
140+ a_plus = rate_cache[j_plus]
141+ a_minus = rate_cache[j_minus]
142+ return abs (a_plus - a_minus) <= delta * min (a_plus, a_minus)
143+ end
144+
145+ # Tau-selection for implicit method (Equation 14)
146+ function compute_tau_implicit (u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
147+ rate (rate_cache, u, p, t)
148+ mu = zeros (length (u))
149+ sigma2 = zeros (length (u))
150+ non_equilibrium = trues (size (nu, 2 ))
151+ for (j_plus, j_minus) in equilibrium_pairs
152+ if is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
153+ non_equilibrium[j_plus] = false
154+ non_equilibrium[j_minus] = false
155+ end
156+ end
157+ tau = Inf
158+ for i in 1 : length (u)
159+ for j in 1 : size (nu, 2 )
160+ if non_equilibrium[j]
161+ mu[i] += nu[i, j] * rate_cache[j]
162+ sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
163+ end
164+ end
165+ gi = compute_gi (u, nu, i, rate, rate_cache, p, t)
166+ bound = max (epsilon * u[i] / gi, 1.0 )
167+ mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
168+ sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
169+ tau = min (tau, mu_term, sigma_term)
170+ end
171+ return max (tau, 1e-10 )
172+ end
173+
174+ # Identify critical reactions
175+ function identify_critical_reactions (u, rate_cache, nu, nc)
176+ critical = falses (size (nu, 2 ))
177+ for j in 1 : size (nu, 2 )
178+ if rate_cache[j] > 0
179+ Lj = Inf
180+ for i in 1 : length (u)
181+ if nu[i, j] < 0
182+ Lj = min (Lj, floor (Int, u[i] / abs (nu[i, j])))
183+ end
184+ end
185+ if Lj < nc
186+ critical[j] = true
187+ end
188+ end
189+ end
190+ return critical
191+ end
192+
193+ # Implicit tau-leaping step using NonlinearSolve
194+ function implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
195+ # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
196+ function f (u_new, params)
197+ rate_new = zeros (eltype (u_new), numjumps)
198+ rate (rate_new, u_new, p, t_prev + tau)
199+ residual = u_new - u_prev
200+ for j in 1 : numjumps
201+ residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
202+ end
203+ return residual
204+ end
205+
206+ # Initial guess
207+ u_new = copy (u_prev)
208+
209+ # Solve the nonlinear system
210+ prob = NonlinearProblem (f, u_new, nothing )
211+ sol = solve (prob, NewtonRaphson ())
212+
213+ # Check for convergence and numerical stability
214+ if sol. retcode != ReturnCode. Success || any (isnan .(sol. u)) || any (isinf .(sol. u))
215+ return round .(Int, max .(u_prev, 0.0 )) # Revert to previous state
216+ end
217+
218+ return round .(Int, max .(sol. u, 0.0 ))
219+ end
220+
221+ # Down-shifting condition (Equation 19)
222+ function use_down_shifting (t, tau_im, tau_ex, a0, t_end)
223+ return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
224+ end
225+
226+ function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleImplicitTauLeaping ; seed= nothing )
78227 @assert isempty (jump_prob. jump_callback. continuous_callbacks)
79228 @assert isempty (jump_prob. jump_callback. discrete_callbacks)
80229 prob = jump_prob. prob
@@ -103,160 +252,12 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
103252 delta = alg. delta
104253 t_end = tspan[2 ]
105254
106- # Compute stoichiometry matrix from c function
107- function compute_stoichiometry (c, u, numjumps)
108- nu = zeros (Int, length (u), numjumps)
109- for j in 1 : numjumps
110- counts = zeros (numjumps)
111- counts[j] = 1
112- du = similar (u)
113- c (du, u, p, t[1 ], counts, nothing )
114- nu[:, j] = round .(Int, du)
115- end
116- return nu
117- end
118- nu = compute_stoichiometry (c, u0, numjumps)
255+ # Compute stoichiometry matrix
256+ nu = compute_stoichiometry (c, u0, numjumps, p, t[1 ])
119257
120258 # Detect reversible reaction pairs
121- function find_reversible_pairs (nu)
122- pairs = Vector {Tuple{Int,Int}} ()
123- for j in 1 : numjumps
124- for k in (j+ 1 ): numjumps
125- if nu[:, j] == - nu[:, k]
126- push! (pairs, (j, k))
127- end
128- end
129- end
130- return pairs
131- end
132259 equilibrium_pairs = find_reversible_pairs (nu)
133260
134- # Helper function to compute g_i (approximation from Cao et al., 2006)
135- function compute_gi (u, nu, i)
136- max_order = 1.0
137- for j in 1 : numjumps
138- if abs (nu[i, j]) > 0
139- rate (rate_cache, u, p, t[end ])
140- if rate_cache[j] > 0
141- order = 1.0
142- if sum (abs .(nu[:, j])) > abs (nu[i, j])
143- order = 2.0
144- end
145- max_order = max (max_order, order)
146- end
147- end
148- end
149- return max_order
150- end
151-
152- # Tau-selection for explicit method (Equation 8)
153- function compute_tau_explicit (u, rate_cache, nu, p, t)
154- rate (rate_cache, u, p, t)
155- mu = zeros (length (u))
156- sigma2 = zeros (length (u))
157- tau = Inf
158- for i in 1 : length (u)
159- for j in 1 : numjumps
160- mu[i] += nu[i, j] * rate_cache[j]
161- sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
162- end
163- gi = compute_gi (u, nu, i)
164- bound = max (epsilon * u[i] / gi, 1.0 )
165- mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
166- sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
167- tau = min (tau, mu_term, sigma_term)
168- end
169- return max (tau, 1e-10 )
170- end
171-
172- # Partial equilibrium check (Equation 13)
173- function is_partial_equilibrium (rate_cache, j_plus, j_minus)
174- a_plus = rate_cache[j_plus]
175- a_minus = rate_cache[j_minus]
176- return abs (a_plus - a_minus) <= delta * min (a_plus, a_minus)
177- end
178-
179- # Tau-selection for implicit method (Equation 14)
180- function compute_tau_implicit (u, rate_cache, nu, p, t)
181- rate (rate_cache, u, p, t)
182- mu = zeros (length (u))
183- sigma2 = zeros (length (u))
184- non_equilibrium = trues (numjumps)
185- for (j_plus, j_minus) in equilibrium_pairs
186- if is_partial_equilibrium (rate_cache, j_plus, j_minus)
187- non_equilibrium[j_plus] = false
188- non_equilibrium[j_minus] = false
189- end
190- end
191- tau = Inf
192- for i in 1 : length (u)
193- for j in 1 : numjumps
194- if non_equilibrium[j]
195- mu[i] += nu[i, j] * rate_cache[j]
196- sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
197- end
198- end
199- gi = compute_gi (u, nu, i)
200- bound = max (epsilon * u[i] / gi, 1.0 )
201- mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
202- sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
203- tau = min (tau, mu_term, sigma_term)
204- end
205- return max (tau, 1e-10 )
206- end
207-
208- # Identify critical reactions
209- function identify_critical_reactions (u, rate_cache, nu)
210- critical = falses (numjumps)
211- for j in 1 : numjumps
212- if rate_cache[j] > 0
213- Lj = Inf
214- for i in 1 : length (u)
215- if nu[i, j] < 0
216- Lj = min (Lj, floor (Int, u[i] / abs (nu[i, j])))
217- end
218- end
219- if Lj < nc
220- critical[j] = true
221- end
222- end
223- end
224- return critical
225- end
226-
227- # Implicit tau-leaping step using NonlinearSolve
228- function implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p)
229- # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
230- function f (u_new, params)
231- rate_new = similar (rate_cache, eltype (u_new))
232- rate (rate_new, u_new, p, t_prev + tau)
233- residual = u_new - u_prev
234- for j in 1 : numjumps
235- residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
236- end
237- return residual
238- end
239-
240- # Initial guess
241- u_new = copy (u_prev)
242-
243- # Solve the nonlinear system
244- prob = NonlinearProblem (f, u_new, nothing )
245- sol = solve (prob, NewtonRaphson ())
246-
247- # Check for convergence and numerical stability
248- if sol. retcode != ReturnCode. Success || any (isnan .(sol. u)) || any (isinf .(sol. u))
249- return round .(Int, max .(u_prev, 0.0 )) # Revert to previous state
250- end
251-
252- return round .(Int, max .(sol. u, 0.0 ))
253- end
254-
255- # Down-shifting condition (Equation 19)
256- function use_down_shifting (t, tau_im, tau_ex, a0, t_end)
257- return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
258- end
259-
260261 # Main simulation loop
261262 while t[end ] < t_end
262263 u_prev = u[end ]
@@ -266,11 +267,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
266267 rate (rate_cache, u_prev, p, t_prev)
267268
268269 # Identify critical reactions
269- critical = identify_critical_reactions (u_prev, rate_cache, nu)
270+ critical = identify_critical_reactions (u_prev, rate_cache, nu, nc )
270271
271272 # Compute tau values
272- tau_ex = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev)
273- tau_im = compute_tau_implicit (u_prev, rate_cache, nu, p, t_prev)
273+ tau_ex = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate )
274+ tau_im = compute_tau_implicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta )
274275
275276 # Compute critical propensity sum
276277 ac0 = sum (rate_cache[critical])
@@ -325,7 +326,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
325326 end
326327 end
327328 if method == :implicit
328- u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p)
329+ u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps )
329330 else
330331 c (du, u_prev, p, t_prev, counts, nothing )
331332 u_new = u_prev + du
@@ -352,7 +353,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
352353 end
353354 end
354355 if method == :implicit && tau > tau_ex
355- u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p)
356+ u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps )
356357 else
357358 c (du, u_prev, p, t_prev, counts, nothing )
358359 u_new = u_prev + du
0 commit comments