@@ -50,20 +50,169 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
5050 interp = DiffEqBase. ConstantInterpolation (t, u))
5151end
5252
53- # Define the SimpleImplicitTauLeaping algorithm
5453struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
5554 epsilon:: Float64 # Error control parameter
5655 nc:: Int # Critical reaction threshold
5756 nstiff:: Float64 # Stiffness threshold for switching
5857 delta:: Float64 # Partial equilibrium threshold
5958end
6059
61- # Default constructor
6260SimpleImplicitTauLeaping (; epsilon= 0.05 , nc= 10 , nstiff= 100.0 , delta= 0.05 ) =
6361 SimpleImplicitTauLeaping (epsilon, nc, nstiff, delta)
6462
65- function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleImplicitTauLeaping ; seed = nothing )
66- # Boilerplate setup
63+ # Compute stoichiometry matrix from c function
64+ function compute_stoichiometry (c, u, numjumps, p, t)
65+ nu = zeros (Int, length (u), numjumps)
66+ for j in 1 : numjumps
67+ counts = zeros (numjumps)
68+ counts[j] = 1
69+ du = similar (u)
70+ c (du, u, p, t, counts, nothing )
71+ nu[:, j] = round .(Int, du)
72+ end
73+ return nu
74+ end
75+
76+ # Detect reversible reaction pairs
77+ function find_reversible_pairs (nu)
78+ pairs = Vector {Tuple{Int,Int}} ()
79+ for j in 1 : size (nu, 2 )
80+ for k in (j+ 1 ): size (nu, 2 )
81+ if nu[:, j] == - nu[:, k]
82+ push! (pairs, (j, k))
83+ end
84+ end
85+ end
86+ return pairs
87+ end
88+
89+ # Compute g_i (approximation from Cao et al., 2006)
90+ function compute_gi (u, nu, i, rate, rate_cache, p, t)
91+ max_order = 1.0
92+ for j in 1 : size (nu, 2 )
93+ if abs (nu[i, j]) > 0
94+ rate (rate_cache, u, p, t)
95+ if rate_cache[j] > 0
96+ order = 1.0
97+ if sum (abs .(nu[:, j])) > abs (nu[i, j])
98+ order = 2.0
99+ end
100+ max_order = max (max_order, order)
101+ end
102+ end
103+ end
104+ return max_order
105+ end
106+
107+ # Tau-selection for explicit method (Equation 8)
108+ function compute_tau_explicit (u, rate_cache, nu, p, t, epsilon, rate)
109+ rate (rate_cache, u, p, t)
110+ mu = zeros (length (u))
111+ sigma2 = zeros (length (u))
112+ tau = Inf
113+ for i in 1 : length (u)
114+ for j in 1 : size (nu, 2 )
115+ mu[i] += nu[i, j] * rate_cache[j]
116+ sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
117+ end
118+ gi = compute_gi (u, nu, i, rate, rate_cache, p, t)
119+ bound = max (epsilon * u[i] / gi, 1.0 )
120+ mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
121+ sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
122+ tau = min (tau, mu_term, sigma_term)
123+ end
124+ return max (tau, 1e-10 )
125+ end
126+
127+ # Partial equilibrium check (Equation 13)
128+ function is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
129+ a_plus = rate_cache[j_plus]
130+ a_minus = rate_cache[j_minus]
131+ return abs (a_plus - a_minus) <= delta * min (a_plus, a_minus)
132+ end
133+
134+ # Tau-selection for implicit method (Equation 14)
135+ function compute_tau_implicit (u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
136+ rate (rate_cache, u, p, t)
137+ mu = zeros (length (u))
138+ sigma2 = zeros (length (u))
139+ non_equilibrium = trues (size (nu, 2 ))
140+ for (j_plus, j_minus) in equilibrium_pairs
141+ if is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
142+ non_equilibrium[j_plus] = false
143+ non_equilibrium[j_minus] = false
144+ end
145+ end
146+ tau = Inf
147+ for i in 1 : length (u)
148+ for j in 1 : size (nu, 2 )
149+ if non_equilibrium[j]
150+ mu[i] += nu[i, j] * rate_cache[j]
151+ sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
152+ end
153+ end
154+ gi = compute_gi (u, nu, i, rate, rate_cache, p, t)
155+ bound = max (epsilon * u[i] / gi, 1.0 )
156+ mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
157+ sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
158+ tau = min (tau, mu_term, sigma_term)
159+ end
160+ return max (tau, 1e-10 )
161+ end
162+
163+ # Identify critical reactions
164+ function identify_critical_reactions (u, rate_cache, nu, nc)
165+ critical = falses (size (nu, 2 ))
166+ for j in 1 : size (nu, 2 )
167+ if rate_cache[j] > 0
168+ Lj = Inf
169+ for i in 1 : length (u)
170+ if nu[i, j] < 0
171+ Lj = min (Lj, floor (Int, u[i] / abs (nu[i, j])))
172+ end
173+ end
174+ if Lj < nc
175+ critical[j] = true
176+ end
177+ end
178+ end
179+ return critical
180+ end
181+
182+ # Implicit tau-leaping step using NonlinearSolve
183+ function implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
184+ # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
185+ function f (u_new, params)
186+ rate_new = zeros (eltype (u_new), numjumps)
187+ rate (rate_new, u_new, p, t_prev + tau)
188+ residual = u_new - u_prev
189+ for j in 1 : numjumps
190+ residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
191+ end
192+ return residual
193+ end
194+
195+ # Initial guess
196+ u_new = copy (u_prev)
197+
198+ # Solve the nonlinear system
199+ prob = NonlinearProblem (f, u_new, nothing )
200+ sol = solve (prob, NewtonRaphson ())
201+
202+ # Check for convergence and numerical stability
203+ if sol. retcode != ReturnCode. Success || any (isnan .(sol. u)) || any (isinf .(sol. u))
204+ return round .(Int, max .(u_prev, 0.0 )) # Revert to previous state
205+ end
206+
207+ return round .(Int, max .(sol. u, 0.0 ))
208+ end
209+
210+ # Down-shifting condition (Equation 19)
211+ function use_down_shifting (t, tau_im, tau_ex, a0, t_end)
212+ return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
213+ end
214+
215+ function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleImplicitTauLeaping ; seed= nothing )
67216 @assert isempty (jump_prob. jump_callback. continuous_callbacks)
68217 @assert isempty (jump_prob. jump_callback. discrete_callbacks)
69218 prob = jump_prob. prob
@@ -92,160 +241,12 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
92241 delta = alg. delta
93242 t_end = tspan[2 ]
94243
95- # Compute stoichiometry matrix from c function
96- function compute_stoichiometry (c, u, numjumps)
97- nu = zeros (Int, length (u), numjumps)
98- for j in 1 : numjumps
99- counts = zeros (numjumps)
100- counts[j] = 1
101- du = similar (u)
102- c (du, u, p, t[1 ], counts, nothing )
103- nu[:, j] = round .(Int, du)
104- end
105- return nu
106- end
107- nu = compute_stoichiometry (c, u0, numjumps)
244+ # Compute stoichiometry matrix
245+ nu = compute_stoichiometry (c, u0, numjumps, p, t[1 ])
108246
109247 # Detect reversible reaction pairs
110- function find_reversible_pairs (nu)
111- pairs = Vector {Tuple{Int,Int}} ()
112- for j in 1 : numjumps
113- for k in (j+ 1 ): numjumps
114- if nu[:, j] == - nu[:, k]
115- push! (pairs, (j, k))
116- end
117- end
118- end
119- return pairs
120- end
121248 equilibrium_pairs = find_reversible_pairs (nu)
122249
123- # Helper function to compute g_i (approximation from Cao et al., 2006)
124- function compute_gi (u, nu, i)
125- max_order = 1.0
126- for j in 1 : numjumps
127- if abs (nu[i, j]) > 0
128- rate (rate_cache, u, p, t[end ])
129- if rate_cache[j] > 0
130- order = 1.0
131- if sum (abs .(nu[:, j])) > abs (nu[i, j])
132- order = 2.0
133- end
134- max_order = max (max_order, order)
135- end
136- end
137- end
138- return max_order
139- end
140-
141- # Tau-selection for explicit method (Equation 8)
142- function compute_tau_explicit (u, rate_cache, nu, p, t)
143- rate (rate_cache, u, p, t)
144- mu = zeros (length (u))
145- sigma2 = zeros (length (u))
146- tau = Inf
147- for i in 1 : length (u)
148- for j in 1 : numjumps
149- mu[i] += nu[i, j] * rate_cache[j]
150- sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
151- end
152- gi = compute_gi (u, nu, i)
153- bound = max (epsilon * u[i] / gi, 1.0 )
154- mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
155- sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
156- tau = min (tau, mu_term, sigma_term)
157- end
158- return max (tau, 1e-10 )
159- end
160-
161- # Partial equilibrium check (Equation 13)
162- function is_partial_equilibrium (rate_cache, j_plus, j_minus)
163- a_plus = rate_cache[j_plus]
164- a_minus = rate_cache[j_minus]
165- return abs (a_plus - a_minus) <= delta * min (a_plus, a_minus)
166- end
167-
168- # Tau-selection for implicit method (Equation 14)
169- function compute_tau_implicit (u, rate_cache, nu, p, t)
170- rate (rate_cache, u, p, t)
171- mu = zeros (length (u))
172- sigma2 = zeros (length (u))
173- non_equilibrium = trues (numjumps)
174- for (j_plus, j_minus) in equilibrium_pairs
175- if is_partial_equilibrium (rate_cache, j_plus, j_minus)
176- non_equilibrium[j_plus] = false
177- non_equilibrium[j_minus] = false
178- end
179- end
180- tau = Inf
181- for i in 1 : length (u)
182- for j in 1 : numjumps
183- if non_equilibrium[j]
184- mu[i] += nu[i, j] * rate_cache[j]
185- sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
186- end
187- end
188- gi = compute_gi (u, nu, i)
189- bound = max (epsilon * u[i] / gi, 1.0 )
190- mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
191- sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
192- tau = min (tau, mu_term, sigma_term)
193- end
194- return max (tau, 1e-10 )
195- end
196-
197- # Identify critical reactions
198- function identify_critical_reactions (u, rate_cache, nu)
199- critical = falses (numjumps)
200- for j in 1 : numjumps
201- if rate_cache[j] > 0
202- Lj = Inf
203- for i in 1 : length (u)
204- if nu[i, j] < 0
205- Lj = min (Lj, floor (Int, u[i] / abs (nu[i, j])))
206- end
207- end
208- if Lj < nc
209- critical[j] = true
210- end
211- end
212- end
213- return critical
214- end
215-
216- # Implicit tau-leaping step using NonlinearSolve
217- function implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p)
218- # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
219- function f (u_new, params)
220- rate_new = similar (rate_cache, eltype (u_new))
221- rate (rate_new, u_new, p, t_prev + tau)
222- residual = u_new - u_prev
223- for j in 1 : numjumps
224- residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
225- end
226- return residual
227- end
228-
229- # Initial guess
230- u_new = copy (u_prev)
231-
232- # Solve the nonlinear system
233- prob = NonlinearProblem (f, u_new, nothing )
234- sol = solve (prob, NewtonRaphson ())
235-
236- # Check for convergence and numerical stability
237- if sol. retcode != ReturnCode. Success || any (isnan .(sol. u)) || any (isinf .(sol. u))
238- return round .(Int, max .(u_prev, 0.0 )) # Revert to previous state
239- end
240-
241- return round .(Int, max .(sol. u, 0.0 ))
242- end
243-
244- # Down-shifting condition (Equation 19)
245- function use_down_shifting (t, tau_im, tau_ex, a0, t_end)
246- return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
247- end
248-
249250 # Main simulation loop
250251 while t[end ] < t_end
251252 u_prev = u[end ]
@@ -255,11 +256,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
255256 rate (rate_cache, u_prev, p, t_prev)
256257
257258 # Identify critical reactions
258- critical = identify_critical_reactions (u_prev, rate_cache, nu)
259+ critical = identify_critical_reactions (u_prev, rate_cache, nu, nc )
259260
260261 # Compute tau values
261- tau_ex = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev)
262- tau_im = compute_tau_implicit (u_prev, rate_cache, nu, p, t_prev)
262+ tau_ex = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate )
263+ tau_im = compute_tau_implicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta )
263264
264265 # Compute critical propensity sum
265266 ac0 = sum (rate_cache[critical])
@@ -314,7 +315,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
314315 end
315316 end
316317 if method == :implicit
317- u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p)
318+ u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps )
318319 else
319320 c (du, u_prev, p, t_prev, counts, nothing )
320321 u_new = u_prev + du
@@ -341,7 +342,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
341342 end
342343 end
343344 if method == :implicit && tau > tau_ex
344- u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p)
345+ u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps )
345346 else
346347 c (du, u_prev, p, t_prev, counts, nothing )
347348 u_new = u_prev + du
0 commit comments