6868
6969SimpleAdaptiveTauLeaping (; epsilon= 0.05 ) = SimpleAdaptiveTauLeaping (epsilon)
7070
71- function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleAdaptiveTauLeaping ; seed= nothing )
71+ function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleAdaptiveTauLeaping ;
72+ seed = nothing ,
73+ dtmin = 1e-10 )
7274 @assert isempty (jump_prob. jump_callback. continuous_callbacks)
7375 @assert isempty (jump_prob. jump_callback. discrete_callbacks)
7476 prob = jump_prob. prob
@@ -97,7 +99,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
9799 u_prev = u[end ]
98100 t_prev = t[end ]
99101 rate (rate_cache, u_prev, p, t_prev)
100- tau = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
102+ tau = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate, dtmin )
101103 tau = min (tau, t_end - t_prev)
102104 counts .= pois_rand .(rng, max .(rate_cache * tau, 0.0 ))
103105 c (du, u_prev, p, t_prev, counts, nothing )
@@ -116,16 +118,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
116118 return sol
117119end
118120
119- struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
120- epsilon:: Float64 # Error control parameter
121- nc:: Int # Critical reaction threshold
122- nstiff:: Float64 # Stiffness threshold for switching
123- delta:: Float64 # Partial equilibrium threshold
124- end
125-
126- SimpleImplicitTauLeaping (; epsilon= 0.05 , nc= 10 , nstiff= 100.0 , delta= 0.05 ) =
127- SimpleImplicitTauLeaping (epsilon, nc, nstiff, delta)
128-
129121# Compute stoichiometry matrix from c function
130122function compute_stoichiometry (c, u, numjumps, p, t)
131123 nu = zeros (Int, length (u), numjumps)
@@ -139,19 +131,6 @@ function compute_stoichiometry(c, u, numjumps, p, t)
139131 return nu
140132end
141133
142- # Detect reversible reaction pairs
143- function find_reversible_pairs (nu)
144- pairs = Vector {Tuple{Int,Int}} ()
145- for j in 1 : size (nu, 2 )
146- for k in (j+ 1 ): size (nu, 2 )
147- if nu[:, j] == - nu[:, k]
148- push! (pairs, (j, k))
149- end
150- end
151- end
152- return pairs
153- end
154-
155134# Compute g_i (approximation from Cao et al., 2006)
156135function compute_gi (u, nu, i, rate, rate_cache, p, t)
157136 max_order = 1.0
@@ -171,7 +150,7 @@ function compute_gi(u, nu, i, rate, rate_cache, p, t)
171150end
172151
173152# Tau-selection for explicit method (Equation 8)
174- function compute_tau_explicit (u, rate_cache, nu, p, t, epsilon, rate)
153+ function compute_tau_explicit (u, rate_cache, nu, p, t, epsilon, rate, dtmin )
175154 rate (rate_cache, u, p, t)
176155 mu = zeros (length (u))
177156 sigma2 = zeros (length (u))
@@ -187,249 +166,7 @@ function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate)
187166 sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
188167 tau = min (tau, mu_term, sigma_term)
189168 end
190- return max (tau, 1e-10 )
191- end
192-
193- # Partial equilibrium check (Equation 13)
194- function is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
195- a_plus = rate_cache[j_plus]
196- a_minus = rate_cache[j_minus]
197- return abs (a_plus - a_minus) <= delta * min (a_plus, a_minus)
198- end
199-
200- # Tau-selection for implicit method (Equation 14)
201- function compute_tau_implicit (u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
202- rate (rate_cache, u, p, t)
203- mu = zeros (length (u))
204- sigma2 = zeros (length (u))
205- non_equilibrium = trues (size (nu, 2 ))
206- for (j_plus, j_minus) in equilibrium_pairs
207- if is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
208- non_equilibrium[j_plus] = false
209- non_equilibrium[j_minus] = false
210- end
211- end
212- tau = Inf
213- for i in 1 : length (u)
214- for j in 1 : size (nu, 2 )
215- if non_equilibrium[j]
216- mu[i] += nu[i, j] * rate_cache[j]
217- sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
218- end
219- end
220- gi = compute_gi (u, nu, i, rate, rate_cache, p, t)
221- bound = max (epsilon * u[i] / gi, 1.0 )
222- mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
223- sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
224- tau = min (tau, mu_term, sigma_term)
225- end
226- return max (tau, 1e-10 )
227- end
228-
229- # Identify critical reactions
230- function identify_critical_reactions (u, rate_cache, nu, nc)
231- critical = falses (size (nu, 2 ))
232- for j in 1 : size (nu, 2 )
233- if rate_cache[j] > 0
234- Lj = Inf
235- for i in 1 : length (u)
236- if nu[i, j] < 0
237- Lj = min (Lj, floor (Int, u[i] / abs (nu[i, j])))
238- end
239- end
240- if Lj < nc
241- critical[j] = true
242- end
243- end
244- end
245- return critical
246- end
247-
248- # Implicit tau-leaping step using NonlinearSolve
249- function implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
250- # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
251- function f (u_new, params)
252- rate_new = zeros (eltype (u_new), numjumps)
253- rate (rate_new, u_new, p, t_prev + tau)
254- residual = u_new - u_prev
255- for j in 1 : numjumps
256- residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
257- end
258- return residual
259- end
260-
261- # Initial guess
262- u_new = copy (u_prev)
263-
264- # Solve the nonlinear system
265- prob = NonlinearProblem (f, u_new, nothing )
266- sol = solve (prob, SimpleNewtonRaphson ())
267-
268- # Check for convergence and numerical stability
269- if sol. retcode != ReturnCode. Success || any (isnan .(sol. u)) || any (isinf .(sol. u))
270- return round .(Int, max .(u_prev, 0.0 )) # Revert to previous state
271- end
272-
273- return round .(Int, max .(sol. u, 0.0 ))
274- end
275-
276- # Down-shifting condition (Equation 19)
277- function use_down_shifting (t, tau_im, tau_ex, a0, t_end)
278- return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
279- end
280-
281- function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleImplicitTauLeaping ; seed= nothing )
282- @assert isempty (jump_prob. jump_callback. continuous_callbacks)
283- @assert isempty (jump_prob. jump_callback. discrete_callbacks)
284- prob = jump_prob. prob
285- rng = DEFAULT_RNG
286- (seed != = nothing ) && seed! (rng, seed)
287-
288- rj = jump_prob. regular_jump
289- rate = rj. rate
290- numjumps = rj. numjumps
291- c = rj. c
292- u0 = copy (prob. u0)
293- tspan = prob. tspan
294- p = prob. p
295-
296- # Initialize storage
297- rate_cache = zeros (Float64, numjumps)
298- counts = zeros (Int, numjumps)
299- du = similar (u0)
300- u = [copy (u0)]
301- t = [tspan[1 ]]
302-
303- # Algorithm parameters
304- epsilon = alg. epsilon
305- nc = alg. nc
306- nstiff = alg. nstiff
307- delta = alg. delta
308- t_end = tspan[2 ]
309-
310- # Compute stoichiometry matrix
311- nu = compute_stoichiometry (c, u0, numjumps, p, t[1 ])
312-
313- # Detect reversible reaction pairs
314- equilibrium_pairs = find_reversible_pairs (nu)
315-
316- # Main simulation loop
317- while t[end ] < t_end
318- u_prev = u[end ]
319- t_prev = t[end ]
320-
321- # Compute propensities
322- rate (rate_cache, u_prev, p, t_prev)
323-
324- # Identify critical reactions
325- critical = identify_critical_reactions (u_prev, rate_cache, nu, nc)
326-
327- # Compute tau values
328- tau_ex = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
329- tau_im = compute_tau_implicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta)
330-
331- # Compute critical propensity sum
332- ac0 = sum (rate_cache[critical])
333- tau2 = ac0 > 0 ? randexp (rng) / ac0 : Inf
334-
335- # Choose method and stepsize
336- a0 = sum (rate_cache)
337- use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && ! use_down_shifting (t_prev, tau_im, tau_ex, a0, t_end)
338- tau1 = use_implicit ? tau_im : tau_ex
339- method = use_implicit ? :implicit : :explicit
340-
341- # Cap tau to prevent large updates
342- tau1 = min (tau1, 1.0 )
343-
344- # Check if tau1 is too small
345- if a0 > 0 && tau1 < 10 / a0
346- # Use SSA for a few steps
347- steps = method == :implicit ? 10 : 100
348- for _ in 1 : steps
349- if t_prev >= t_end
350- break
351- end
352- rate (rate_cache, u_prev, p, t_prev)
353- a0 = sum (rate_cache)
354- if a0 == 0
355- break
356- end
357- tau = randexp (rng) / a0
358- r = rand (rng) * a0
359- cumsum_rate = 0.0
360- for j in 1 : numjumps
361- cumsum_rate += rate_cache[j]
362- if cumsum_rate > r
363- u_prev += nu[:, j]
364- break
365- end
366- end
367- t_prev += tau
368- push! (u, copy (u_prev))
369- push! (t, t_prev)
370- end
371- continue
372- end
373-
374- # Choose stepsize and compute firings
375- if tau2 > tau1
376- tau = min (tau1, t_end - t_prev)
377- counts .= 0
378- for j in 1 : numjumps
379- if ! critical[j]
380- counts[j] = pois_rand (rng, max (rate_cache[j] * tau, 0.0 ))
381- end
382- end
383- if method == :implicit
384- u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
385- else
386- c (du, u_prev, p, t_prev, counts, nothing )
387- u_new = u_prev + du
388- end
389- else
390- tau = min (tau2, t_end - t_prev)
391- counts .= 0
392- if ac0 > 0
393- r = rand (rng) * ac0
394- cumsum_rate = 0.0
395- for j in 1 : numjumps
396- if critical[j]
397- cumsum_rate += rate_cache[j]
398- if cumsum_rate > r
399- counts[j] = 1
400- break
401- end
402- end
403- end
404- end
405- for j in 1 : numjumps
406- if ! critical[j]
407- counts[j] = pois_rand (rng, max (rate_cache[j] * tau, 0.0 ))
408- end
409- end
410- if method == :implicit && tau > tau_ex
411- u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
412- else
413- c (du, u_prev, p, t_prev, counts, nothing )
414- u_new = u_prev + du
415- end
416- end
417-
418- # Check for negative populations
419- if any (u_new .< 0 )
420- tau1 /= 2
421- continue
422- end
423-
424- # Update state and time
425- push! (u, u_new)
426- push! (t, t_prev + tau)
427- end
428-
429- # Build solution
430- sol = DiffEqBase. build_solution (prob, alg, t, u,
431- calculate_error = false ,
432- interp = DiffEqBase. ConstantInterpolation (t, u))
169+ return max (tau, dtmin)
433170end
434171
435172struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
@@ -445,4 +182,4 @@ function EnsembleGPUKernel()
445182 EnsembleGPUKernel (nothing , 0.0 )
446183end
447184
448- export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping
185+ export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping
0 commit comments