@@ -62,6 +62,376 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
6262 interp = DiffEqBase. ConstantInterpolation (t, u))
6363end
6464
65+ struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm
66+ epsilon:: Float64 # Error control parameter
67+ end
68+
69+ SimpleAdaptiveTauLeaping (; epsilon= 0.05 ) = SimpleAdaptiveTauLeaping (epsilon)
70+
71+ function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleAdaptiveTauLeaping ; seed= nothing )
72+ @assert isempty (jump_prob. jump_callback. continuous_callbacks)
73+ @assert isempty (jump_prob. jump_callback. discrete_callbacks)
74+ prob = jump_prob. prob
75+ rng = DEFAULT_RNG
76+ (seed != = nothing ) && seed! (rng, seed)
77+
78+ rj = jump_prob. regular_jump
79+ rate = rj. rate
80+ numjumps = rj. numjumps
81+ c = rj. c
82+ u0 = copy (prob. u0)
83+ tspan = prob. tspan
84+ p = prob. p
85+
86+ u = [copy (u0)]
87+ t = [tspan[1 ]]
88+ rate_cache = zeros (Float64, numjumps)
89+ counts = zeros (Int, numjumps)
90+ du = similar (u0)
91+ t_end = tspan[2 ]
92+ epsilon = alg. epsilon
93+
94+ nu = compute_stoichiometry (c, u0, numjumps, p, t[1 ])
95+
96+ while t[end ] < t_end
97+ u_prev = u[end ]
98+ t_prev = t[end ]
99+ rate (rate_cache, u_prev, p, t_prev)
100+ tau = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
101+ tau = min (tau, t_end - t_prev)
102+ counts .= pois_rand .(rng, max .(rate_cache * tau, 0.0 ))
103+ c (du, u_prev, p, t_prev, counts, nothing )
104+ u_new = u_prev + du
105+ if any (u_new .< 0 )
106+ tau /= 2
107+ continue
108+ end
109+ push! (u, u_new)
110+ push! (t, t_prev + tau)
111+ end
112+
113+ sol = DiffEqBase. build_solution (prob, alg, t, u,
114+ calculate_error= false ,
115+ interp= DiffEqBase. ConstantInterpolation (t, u))
116+ return sol
117+ end
118+
119+ struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
120+ epsilon:: Float64 # Error control parameter
121+ nc:: Int # Critical reaction threshold
122+ nstiff:: Float64 # Stiffness threshold for switching
123+ delta:: Float64 # Partial equilibrium threshold
124+ end
125+
126+ SimpleImplicitTauLeaping (; epsilon= 0.05 , nc= 10 , nstiff= 100.0 , delta= 0.05 ) =
127+ SimpleImplicitTauLeaping (epsilon, nc, nstiff, delta)
128+
129+ # Compute stoichiometry matrix from c function
130+ function compute_stoichiometry (c, u, numjumps, p, t)
131+ nu = zeros (Int, length (u), numjumps)
132+ for j in 1 : numjumps
133+ counts = zeros (numjumps)
134+ counts[j] = 1
135+ du = similar (u)
136+ c (du, u, p, t, counts, nothing )
137+ nu[:, j] = round .(Int, du)
138+ end
139+ return nu
140+ end
141+
142+ # Detect reversible reaction pairs
143+ function find_reversible_pairs (nu)
144+ pairs = Vector {Tuple{Int,Int}} ()
145+ for j in 1 : size (nu, 2 )
146+ for k in (j+ 1 ): size (nu, 2 )
147+ if nu[:, j] == - nu[:, k]
148+ push! (pairs, (j, k))
149+ end
150+ end
151+ end
152+ return pairs
153+ end
154+
155+ # Compute g_i (approximation from Cao et al., 2006)
156+ function compute_gi (u, nu, i, rate, rate_cache, p, t)
157+ max_order = 1.0
158+ for j in 1 : size (nu, 2 )
159+ if abs (nu[i, j]) > 0
160+ rate (rate_cache, u, p, t)
161+ if rate_cache[j] > 0
162+ order = 1.0
163+ if sum (abs .(nu[:, j])) > abs (nu[i, j])
164+ order = 2.0
165+ end
166+ max_order = max (max_order, order)
167+ end
168+ end
169+ end
170+ return max_order
171+ end
172+
173+ # Tau-selection for explicit method (Equation 8)
174+ function compute_tau_explicit (u, rate_cache, nu, p, t, epsilon, rate)
175+ rate (rate_cache, u, p, t)
176+ mu = zeros (length (u))
177+ sigma2 = zeros (length (u))
178+ tau = Inf
179+ for i in 1 : length (u)
180+ for j in 1 : size (nu, 2 )
181+ mu[i] += nu[i, j] * rate_cache[j]
182+ sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
183+ end
184+ gi = compute_gi (u, nu, i, rate, rate_cache, p, t)
185+ bound = max (epsilon * u[i] / gi, 1.0 )
186+ mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
187+ sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
188+ tau = min (tau, mu_term, sigma_term)
189+ end
190+ return max (tau, 1e-10 )
191+ end
192+
193+ # Partial equilibrium check (Equation 13)
194+ function is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
195+ a_plus = rate_cache[j_plus]
196+ a_minus = rate_cache[j_minus]
197+ return abs (a_plus - a_minus) <= delta * min (a_plus, a_minus)
198+ end
199+
200+ # Tau-selection for implicit method (Equation 14)
201+ function compute_tau_implicit (u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
202+ rate (rate_cache, u, p, t)
203+ mu = zeros (length (u))
204+ sigma2 = zeros (length (u))
205+ non_equilibrium = trues (size (nu, 2 ))
206+ for (j_plus, j_minus) in equilibrium_pairs
207+ if is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
208+ non_equilibrium[j_plus] = false
209+ non_equilibrium[j_minus] = false
210+ end
211+ end
212+ tau = Inf
213+ for i in 1 : length (u)
214+ for j in 1 : size (nu, 2 )
215+ if non_equilibrium[j]
216+ mu[i] += nu[i, j] * rate_cache[j]
217+ sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
218+ end
219+ end
220+ gi = compute_gi (u, nu, i, rate, rate_cache, p, t)
221+ bound = max (epsilon * u[i] / gi, 1.0 )
222+ mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
223+ sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
224+ tau = min (tau, mu_term, sigma_term)
225+ end
226+ return max (tau, 1e-10 )
227+ end
228+
229+ # Identify critical reactions
230+ function identify_critical_reactions (u, rate_cache, nu, nc)
231+ critical = falses (size (nu, 2 ))
232+ for j in 1 : size (nu, 2 )
233+ if rate_cache[j] > 0
234+ Lj = Inf
235+ for i in 1 : length (u)
236+ if nu[i, j] < 0
237+ Lj = min (Lj, floor (Int, u[i] / abs (nu[i, j])))
238+ end
239+ end
240+ if Lj < nc
241+ critical[j] = true
242+ end
243+ end
244+ end
245+ return critical
246+ end
247+
248+ # Implicit tau-leaping step using NonlinearSolve
249+ function implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
250+ # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
251+ function f (u_new, params)
252+ rate_new = zeros (eltype (u_new), numjumps)
253+ rate (rate_new, u_new, p, t_prev + tau)
254+ residual = u_new - u_prev
255+ for j in 1 : numjumps
256+ residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
257+ end
258+ return residual
259+ end
260+
261+ # Initial guess
262+ u_new = copy (u_prev)
263+
264+ # Solve the nonlinear system
265+ prob = NonlinearProblem (f, u_new, nothing )
266+ sol = solve (prob, NewtonRaphson ())
267+
268+ # Check for convergence and numerical stability
269+ if sol. retcode != ReturnCode. Success || any (isnan .(sol. u)) || any (isinf .(sol. u))
270+ return round .(Int, max .(u_prev, 0.0 )) # Revert to previous state
271+ end
272+
273+ return round .(Int, max .(sol. u, 0.0 ))
274+ end
275+
276+ # Down-shifting condition (Equation 19)
277+ function use_down_shifting (t, tau_im, tau_ex, a0, t_end)
278+ return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
279+ end
280+
281+ function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleImplicitTauLeaping ; seed= nothing )
282+ @assert isempty (jump_prob. jump_callback. continuous_callbacks)
283+ @assert isempty (jump_prob. jump_callback. discrete_callbacks)
284+ prob = jump_prob. prob
285+ rng = DEFAULT_RNG
286+ (seed != = nothing ) && seed! (rng, seed)
287+
288+ rj = jump_prob. regular_jump
289+ rate = rj. rate
290+ numjumps = rj. numjumps
291+ c = rj. c
292+ u0 = copy (prob. u0)
293+ tspan = prob. tspan
294+ p = prob. p
295+
296+ # Initialize storage
297+ rate_cache = zeros (Float64, numjumps)
298+ counts = zeros (Int, numjumps)
299+ du = similar (u0)
300+ u = [copy (u0)]
301+ t = [tspan[1 ]]
302+
303+ # Algorithm parameters
304+ epsilon = alg. epsilon
305+ nc = alg. nc
306+ nstiff = alg. nstiff
307+ delta = alg. delta
308+ t_end = tspan[2 ]
309+
310+ # Compute stoichiometry matrix
311+ nu = compute_stoichiometry (c, u0, numjumps, p, t[1 ])
312+
313+ # Detect reversible reaction pairs
314+ equilibrium_pairs = find_reversible_pairs (nu)
315+
316+ # Main simulation loop
317+ while t[end ] < t_end
318+ u_prev = u[end ]
319+ t_prev = t[end ]
320+
321+ # Compute propensities
322+ rate (rate_cache, u_prev, p, t_prev)
323+
324+ # Identify critical reactions
325+ critical = identify_critical_reactions (u_prev, rate_cache, nu, nc)
326+
327+ # Compute tau values
328+ tau_ex = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
329+ tau_im = compute_tau_implicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta)
330+
331+ # Compute critical propensity sum
332+ ac0 = sum (rate_cache[critical])
333+ tau2 = ac0 > 0 ? randexp (rng) / ac0 : Inf
334+
335+ # Choose method and stepsize
336+ a0 = sum (rate_cache)
337+ use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && ! use_down_shifting (t_prev, tau_im, tau_ex, a0, t_end)
338+ tau1 = use_implicit ? tau_im : tau_ex
339+ method = use_implicit ? :implicit : :explicit
340+
341+ # Cap tau to prevent large updates
342+ tau1 = min (tau1, 1.0 )
343+
344+ # Check if tau1 is too small
345+ if a0 > 0 && tau1 < 10 / a0
346+ # Use SSA for a few steps
347+ steps = method == :implicit ? 10 : 100
348+ for _ in 1 : steps
349+ if t_prev >= t_end
350+ break
351+ end
352+ rate (rate_cache, u_prev, p, t_prev)
353+ a0 = sum (rate_cache)
354+ if a0 == 0
355+ break
356+ end
357+ tau = randexp (rng) / a0
358+ r = rand (rng) * a0
359+ cumsum_rate = 0.0
360+ for j in 1 : numjumps
361+ cumsum_rate += rate_cache[j]
362+ if cumsum_rate > r
363+ u_prev += nu[:, j]
364+ break
365+ end
366+ end
367+ t_prev += tau
368+ push! (u, copy (u_prev))
369+ push! (t, t_prev)
370+ end
371+ continue
372+ end
373+
374+ # Choose stepsize and compute firings
375+ if tau2 > tau1
376+ tau = min (tau1, t_end - t_prev)
377+ counts .= 0
378+ for j in 1 : numjumps
379+ if ! critical[j]
380+ counts[j] = pois_rand (rng, max (rate_cache[j] * tau, 0.0 ))
381+ end
382+ end
383+ if method == :implicit
384+ u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
385+ else
386+ c (du, u_prev, p, t_prev, counts, nothing )
387+ u_new = u_prev + du
388+ end
389+ else
390+ tau = min (tau2, t_end - t_prev)
391+ counts .= 0
392+ if ac0 > 0
393+ r = rand (rng) * ac0
394+ cumsum_rate = 0.0
395+ for j in 1 : numjumps
396+ if critical[j]
397+ cumsum_rate += rate_cache[j]
398+ if cumsum_rate > r
399+ counts[j] = 1
400+ break
401+ end
402+ end
403+ end
404+ end
405+ for j in 1 : numjumps
406+ if ! critical[j]
407+ counts[j] = pois_rand (rng, max (rate_cache[j] * tau, 0.0 ))
408+ end
409+ end
410+ if method == :implicit && tau > tau_ex
411+ u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
412+ else
413+ c (du, u_prev, p, t_prev, counts, nothing )
414+ u_new = u_prev + du
415+ end
416+ end
417+
418+ # Check for negative populations
419+ if any (u_new .< 0 )
420+ tau1 /= 2
421+ continue
422+ end
423+
424+ # Update state and time
425+ push! (u, u_new)
426+ push! (t, t_prev + tau)
427+ end
428+
429+ # Build solution
430+ sol = DiffEqBase. build_solution (prob, alg, t, u,
431+ calculate_error = false ,
432+ interp = DiffEqBase. ConstantInterpolation (t, u))
433+ end
434+
65435struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
66436 backend:: Backend
67437 cpu_offload:: Float64
74444function EnsembleGPUKernel ()
75445 EnsembleGPUKernel (nothing , 0.0 )
76446end
447+
448+ export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping
0 commit comments