5656
5757SimpleAdaptiveTauLeaping (; epsilon= 0.05 ) = SimpleAdaptiveTauLeaping (epsilon)
5858
59- function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleAdaptiveTauLeaping ; seed= nothing )
59+ function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleAdaptiveTauLeaping ;
60+ seed = nothing ,
61+ dtmin = 1e-10 )
6062 @assert isempty (jump_prob. jump_callback. continuous_callbacks)
6163 @assert isempty (jump_prob. jump_callback. discrete_callbacks)
6264 prob = jump_prob. prob
@@ -85,7 +87,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
8587 u_prev = u[end ]
8688 t_prev = t[end ]
8789 rate (rate_cache, u_prev, p, t_prev)
88- tau = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
90+ tau = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate, dtmin )
8991 tau = min (tau, t_end - t_prev)
9092 counts .= pois_rand .(rng, max .(rate_cache * tau, 0.0 ))
9193 c (du, u_prev, p, t_prev, counts, nothing )
@@ -104,16 +106,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
104106 return sol
105107end
106108
107- struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
108- epsilon:: Float64 # Error control parameter
109- nc:: Int # Critical reaction threshold
110- nstiff:: Float64 # Stiffness threshold for switching
111- delta:: Float64 # Partial equilibrium threshold
112- end
113-
114- SimpleImplicitTauLeaping (; epsilon= 0.05 , nc= 10 , nstiff= 100.0 , delta= 0.05 ) =
115- SimpleImplicitTauLeaping (epsilon, nc, nstiff, delta)
116-
117109# Compute stoichiometry matrix from c function
118110function compute_stoichiometry (c, u, numjumps, p, t)
119111 nu = zeros (Int, length (u), numjumps)
@@ -127,19 +119,6 @@ function compute_stoichiometry(c, u, numjumps, p, t)
127119 return nu
128120end
129121
130- # Detect reversible reaction pairs
131- function find_reversible_pairs (nu)
132- pairs = Vector {Tuple{Int,Int}} ()
133- for j in 1 : size (nu, 2 )
134- for k in (j+ 1 ): size (nu, 2 )
135- if nu[:, j] == - nu[:, k]
136- push! (pairs, (j, k))
137- end
138- end
139- end
140- return pairs
141- end
142-
143122# Compute g_i (approximation from Cao et al., 2006)
144123function compute_gi (u, nu, i, rate, rate_cache, p, t)
145124 max_order = 1.0
@@ -159,7 +138,7 @@ function compute_gi(u, nu, i, rate, rate_cache, p, t)
159138end
160139
161140# Tau-selection for explicit method (Equation 8)
162- function compute_tau_explicit (u, rate_cache, nu, p, t, epsilon, rate)
141+ function compute_tau_explicit (u, rate_cache, nu, p, t, epsilon, rate, dtmin )
163142 rate (rate_cache, u, p, t)
164143 mu = zeros (length (u))
165144 sigma2 = zeros (length (u))
@@ -175,249 +154,7 @@ function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate)
175154 sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
176155 tau = min (tau, mu_term, sigma_term)
177156 end
178- return max (tau, 1e-10 )
179- end
180-
181- # Partial equilibrium check (Equation 13)
182- function is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
183- a_plus = rate_cache[j_plus]
184- a_minus = rate_cache[j_minus]
185- return abs (a_plus - a_minus) <= delta * min (a_plus, a_minus)
186- end
187-
188- # Tau-selection for implicit method (Equation 14)
189- function compute_tau_implicit (u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
190- rate (rate_cache, u, p, t)
191- mu = zeros (length (u))
192- sigma2 = zeros (length (u))
193- non_equilibrium = trues (size (nu, 2 ))
194- for (j_plus, j_minus) in equilibrium_pairs
195- if is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
196- non_equilibrium[j_plus] = false
197- non_equilibrium[j_minus] = false
198- end
199- end
200- tau = Inf
201- for i in 1 : length (u)
202- for j in 1 : size (nu, 2 )
203- if non_equilibrium[j]
204- mu[i] += nu[i, j] * rate_cache[j]
205- sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
206- end
207- end
208- gi = compute_gi (u, nu, i, rate, rate_cache, p, t)
209- bound = max (epsilon * u[i] / gi, 1.0 )
210- mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
211- sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
212- tau = min (tau, mu_term, sigma_term)
213- end
214- return max (tau, 1e-10 )
215- end
216-
217- # Identify critical reactions
218- function identify_critical_reactions (u, rate_cache, nu, nc)
219- critical = falses (size (nu, 2 ))
220- for j in 1 : size (nu, 2 )
221- if rate_cache[j] > 0
222- Lj = Inf
223- for i in 1 : length (u)
224- if nu[i, j] < 0
225- Lj = min (Lj, floor (Int, u[i] / abs (nu[i, j])))
226- end
227- end
228- if Lj < nc
229- critical[j] = true
230- end
231- end
232- end
233- return critical
234- end
235-
236- # Implicit tau-leaping step using NonlinearSolve
237- function implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
238- # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
239- function f (u_new, params)
240- rate_new = zeros (eltype (u_new), numjumps)
241- rate (rate_new, u_new, p, t_prev + tau)
242- residual = u_new - u_prev
243- for j in 1 : numjumps
244- residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
245- end
246- return residual
247- end
248-
249- # Initial guess
250- u_new = copy (u_prev)
251-
252- # Solve the nonlinear system
253- prob = NonlinearProblem (f, u_new, nothing )
254- sol = solve (prob, SimpleNewtonRaphson ())
255-
256- # Check for convergence and numerical stability
257- if sol. retcode != ReturnCode. Success || any (isnan .(sol. u)) || any (isinf .(sol. u))
258- return round .(Int, max .(u_prev, 0.0 )) # Revert to previous state
259- end
260-
261- return round .(Int, max .(sol. u, 0.0 ))
262- end
263-
264- # Down-shifting condition (Equation 19)
265- function use_down_shifting (t, tau_im, tau_ex, a0, t_end)
266- return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
267- end
268-
269- function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleImplicitTauLeaping ; seed= nothing )
270- @assert isempty (jump_prob. jump_callback. continuous_callbacks)
271- @assert isempty (jump_prob. jump_callback. discrete_callbacks)
272- prob = jump_prob. prob
273- rng = DEFAULT_RNG
274- (seed != = nothing ) && seed! (rng, seed)
275-
276- rj = jump_prob. regular_jump
277- rate = rj. rate
278- numjumps = rj. numjumps
279- c = rj. c
280- u0 = copy (prob. u0)
281- tspan = prob. tspan
282- p = prob. p
283-
284- # Initialize storage
285- rate_cache = zeros (Float64, numjumps)
286- counts = zeros (Int, numjumps)
287- du = similar (u0)
288- u = [copy (u0)]
289- t = [tspan[1 ]]
290-
291- # Algorithm parameters
292- epsilon = alg. epsilon
293- nc = alg. nc
294- nstiff = alg. nstiff
295- delta = alg. delta
296- t_end = tspan[2 ]
297-
298- # Compute stoichiometry matrix
299- nu = compute_stoichiometry (c, u0, numjumps, p, t[1 ])
300-
301- # Detect reversible reaction pairs
302- equilibrium_pairs = find_reversible_pairs (nu)
303-
304- # Main simulation loop
305- while t[end ] < t_end
306- u_prev = u[end ]
307- t_prev = t[end ]
308-
309- # Compute propensities
310- rate (rate_cache, u_prev, p, t_prev)
311-
312- # Identify critical reactions
313- critical = identify_critical_reactions (u_prev, rate_cache, nu, nc)
314-
315- # Compute tau values
316- tau_ex = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
317- tau_im = compute_tau_implicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta)
318-
319- # Compute critical propensity sum
320- ac0 = sum (rate_cache[critical])
321- tau2 = ac0 > 0 ? randexp (rng) / ac0 : Inf
322-
323- # Choose method and stepsize
324- a0 = sum (rate_cache)
325- use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && ! use_down_shifting (t_prev, tau_im, tau_ex, a0, t_end)
326- tau1 = use_implicit ? tau_im : tau_ex
327- method = use_implicit ? :implicit : :explicit
328-
329- # Cap tau to prevent large updates
330- tau1 = min (tau1, 1.0 )
331-
332- # Check if tau1 is too small
333- if a0 > 0 && tau1 < 10 / a0
334- # Use SSA for a few steps
335- steps = method == :implicit ? 10 : 100
336- for _ in 1 : steps
337- if t_prev >= t_end
338- break
339- end
340- rate (rate_cache, u_prev, p, t_prev)
341- a0 = sum (rate_cache)
342- if a0 == 0
343- break
344- end
345- tau = randexp (rng) / a0
346- r = rand (rng) * a0
347- cumsum_rate = 0.0
348- for j in 1 : numjumps
349- cumsum_rate += rate_cache[j]
350- if cumsum_rate > r
351- u_prev += nu[:, j]
352- break
353- end
354- end
355- t_prev += tau
356- push! (u, copy (u_prev))
357- push! (t, t_prev)
358- end
359- continue
360- end
361-
362- # Choose stepsize and compute firings
363- if tau2 > tau1
364- tau = min (tau1, t_end - t_prev)
365- counts .= 0
366- for j in 1 : numjumps
367- if ! critical[j]
368- counts[j] = pois_rand (rng, max (rate_cache[j] * tau, 0.0 ))
369- end
370- end
371- if method == :implicit
372- u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
373- else
374- c (du, u_prev, p, t_prev, counts, nothing )
375- u_new = u_prev + du
376- end
377- else
378- tau = min (tau2, t_end - t_prev)
379- counts .= 0
380- if ac0 > 0
381- r = rand (rng) * ac0
382- cumsum_rate = 0.0
383- for j in 1 : numjumps
384- if critical[j]
385- cumsum_rate += rate_cache[j]
386- if cumsum_rate > r
387- counts[j] = 1
388- break
389- end
390- end
391- end
392- end
393- for j in 1 : numjumps
394- if ! critical[j]
395- counts[j] = pois_rand (rng, max (rate_cache[j] * tau, 0.0 ))
396- end
397- end
398- if method == :implicit && tau > tau_ex
399- u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
400- else
401- c (du, u_prev, p, t_prev, counts, nothing )
402- u_new = u_prev + du
403- end
404- end
405-
406- # Check for negative populations
407- if any (u_new .< 0 )
408- tau1 /= 2
409- continue
410- end
411-
412- # Update state and time
413- push! (u, u_new)
414- push! (t, t_prev + tau)
415- end
416-
417- # Build solution
418- sol = DiffEqBase. build_solution (prob, alg, t, u,
419- calculate_error = false ,
420- interp = DiffEqBase. ConstantInterpolation (t, u))
157+ return max (tau, dtmin)
421158end
422159
423160struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
@@ -433,4 +170,4 @@ function EnsembleGPUKernel()
433170 EnsembleGPUKernel (nothing , 0.0 )
434171end
435172
436- export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping
173+ export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping
0 commit comments