@@ -50,6 +50,376 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
5050 interp = DiffEqBase. ConstantInterpolation (t, u))
5151end
5252
53+ struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm
54+ epsilon:: Float64 # Error control parameter
55+ end
56+
57+ SimpleAdaptiveTauLeaping (; epsilon= 0.05 ) = SimpleAdaptiveTauLeaping (epsilon)
58+
59+ function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleAdaptiveTauLeaping ; seed= nothing )
60+ @assert isempty (jump_prob. jump_callback. continuous_callbacks)
61+ @assert isempty (jump_prob. jump_callback. discrete_callbacks)
62+ prob = jump_prob. prob
63+ rng = DEFAULT_RNG
64+ (seed != = nothing ) && seed! (rng, seed)
65+
66+ rj = jump_prob. regular_jump
67+ rate = rj. rate
68+ numjumps = rj. numjumps
69+ c = rj. c
70+ u0 = copy (prob. u0)
71+ tspan = prob. tspan
72+ p = prob. p
73+
74+ u = [copy (u0)]
75+ t = [tspan[1 ]]
76+ rate_cache = zeros (Float64, numjumps)
77+ counts = zeros (Int, numjumps)
78+ du = similar (u0)
79+ t_end = tspan[2 ]
80+ epsilon = alg. epsilon
81+
82+ nu = compute_stoichiometry (c, u0, numjumps, p, t[1 ])
83+
84+ while t[end ] < t_end
85+ u_prev = u[end ]
86+ t_prev = t[end ]
87+ rate (rate_cache, u_prev, p, t_prev)
88+ tau = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
89+ tau = min (tau, t_end - t_prev)
90+ counts .= pois_rand .(rng, max .(rate_cache * tau, 0.0 ))
91+ c (du, u_prev, p, t_prev, counts, nothing )
92+ u_new = u_prev + du
93+ if any (u_new .< 0 )
94+ tau /= 2
95+ continue
96+ end
97+ push! (u, u_new)
98+ push! (t, t_prev + tau)
99+ end
100+
101+ sol = DiffEqBase. build_solution (prob, alg, t, u,
102+ calculate_error= false ,
103+ interp= DiffEqBase. ConstantInterpolation (t, u))
104+ return sol
105+ end
106+
107+ struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
108+ epsilon:: Float64 # Error control parameter
109+ nc:: Int # Critical reaction threshold
110+ nstiff:: Float64 # Stiffness threshold for switching
111+ delta:: Float64 # Partial equilibrium threshold
112+ end
113+
114+ SimpleImplicitTauLeaping (; epsilon= 0.05 , nc= 10 , nstiff= 100.0 , delta= 0.05 ) =
115+ SimpleImplicitTauLeaping (epsilon, nc, nstiff, delta)
116+
117+ # Compute stoichiometry matrix from c function
118+ function compute_stoichiometry (c, u, numjumps, p, t)
119+ nu = zeros (Int, length (u), numjumps)
120+ for j in 1 : numjumps
121+ counts = zeros (numjumps)
122+ counts[j] = 1
123+ du = similar (u)
124+ c (du, u, p, t, counts, nothing )
125+ nu[:, j] = round .(Int, du)
126+ end
127+ return nu
128+ end
129+
130+ # Detect reversible reaction pairs
131+ function find_reversible_pairs (nu)
132+ pairs = Vector {Tuple{Int,Int}} ()
133+ for j in 1 : size (nu, 2 )
134+ for k in (j+ 1 ): size (nu, 2 )
135+ if nu[:, j] == - nu[:, k]
136+ push! (pairs, (j, k))
137+ end
138+ end
139+ end
140+ return pairs
141+ end
142+
143+ # Compute g_i (approximation from Cao et al., 2006)
144+ function compute_gi (u, nu, i, rate, rate_cache, p, t)
145+ max_order = 1.0
146+ for j in 1 : size (nu, 2 )
147+ if abs (nu[i, j]) > 0
148+ rate (rate_cache, u, p, t)
149+ if rate_cache[j] > 0
150+ order = 1.0
151+ if sum (abs .(nu[:, j])) > abs (nu[i, j])
152+ order = 2.0
153+ end
154+ max_order = max (max_order, order)
155+ end
156+ end
157+ end
158+ return max_order
159+ end
160+
161+ # Tau-selection for explicit method (Equation 8)
162+ function compute_tau_explicit (u, rate_cache, nu, p, t, epsilon, rate)
163+ rate (rate_cache, u, p, t)
164+ mu = zeros (length (u))
165+ sigma2 = zeros (length (u))
166+ tau = Inf
167+ for i in 1 : length (u)
168+ for j in 1 : size (nu, 2 )
169+ mu[i] += nu[i, j] * rate_cache[j]
170+ sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
171+ end
172+ gi = compute_gi (u, nu, i, rate, rate_cache, p, t)
173+ bound = max (epsilon * u[i] / gi, 1.0 )
174+ mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
175+ sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
176+ tau = min (tau, mu_term, sigma_term)
177+ end
178+ return max (tau, 1e-10 )
179+ end
180+
181+ # Partial equilibrium check (Equation 13)
182+ function is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
183+ a_plus = rate_cache[j_plus]
184+ a_minus = rate_cache[j_minus]
185+ return abs (a_plus - a_minus) <= delta * min (a_plus, a_minus)
186+ end
187+
188+ # Tau-selection for implicit method (Equation 14)
189+ function compute_tau_implicit (u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
190+ rate (rate_cache, u, p, t)
191+ mu = zeros (length (u))
192+ sigma2 = zeros (length (u))
193+ non_equilibrium = trues (size (nu, 2 ))
194+ for (j_plus, j_minus) in equilibrium_pairs
195+ if is_partial_equilibrium (rate_cache, j_plus, j_minus, delta)
196+ non_equilibrium[j_plus] = false
197+ non_equilibrium[j_minus] = false
198+ end
199+ end
200+ tau = Inf
201+ for i in 1 : length (u)
202+ for j in 1 : size (nu, 2 )
203+ if non_equilibrium[j]
204+ mu[i] += nu[i, j] * rate_cache[j]
205+ sigma2[i] += nu[i, j]^ 2 * rate_cache[j]
206+ end
207+ end
208+ gi = compute_gi (u, nu, i, rate, rate_cache, p, t)
209+ bound = max (epsilon * u[i] / gi, 1.0 )
210+ mu_term = abs (mu[i]) > 0 ? bound / abs (mu[i]) : Inf
211+ sigma_term = sigma2[i] > 0 ? bound^ 2 / sigma2[i] : Inf
212+ tau = min (tau, mu_term, sigma_term)
213+ end
214+ return max (tau, 1e-10 )
215+ end
216+
217+ # Identify critical reactions
218+ function identify_critical_reactions (u, rate_cache, nu, nc)
219+ critical = falses (size (nu, 2 ))
220+ for j in 1 : size (nu, 2 )
221+ if rate_cache[j] > 0
222+ Lj = Inf
223+ for i in 1 : length (u)
224+ if nu[i, j] < 0
225+ Lj = min (Lj, floor (Int, u[i] / abs (nu[i, j])))
226+ end
227+ end
228+ if Lj < nc
229+ critical[j] = true
230+ end
231+ end
232+ end
233+ return critical
234+ end
235+
236+ # Implicit tau-leaping step using NonlinearSolve
237+ function implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
238+ # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
239+ function f (u_new, params)
240+ rate_new = zeros (eltype (u_new), numjumps)
241+ rate (rate_new, u_new, p, t_prev + tau)
242+ residual = u_new - u_prev
243+ for j in 1 : numjumps
244+ residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
245+ end
246+ return residual
247+ end
248+
249+ # Initial guess
250+ u_new = copy (u_prev)
251+
252+ # Solve the nonlinear system
253+ prob = NonlinearProblem (f, u_new, nothing )
254+ sol = solve (prob, NewtonRaphson ())
255+
256+ # Check for convergence and numerical stability
257+ if sol. retcode != ReturnCode. Success || any (isnan .(sol. u)) || any (isinf .(sol. u))
258+ return round .(Int, max .(u_prev, 0.0 )) # Revert to previous state
259+ end
260+
261+ return round .(Int, max .(sol. u, 0.0 ))
262+ end
263+
264+ # Down-shifting condition (Equation 19)
265+ function use_down_shifting (t, tau_im, tau_ex, a0, t_end)
266+ return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
267+ end
268+
269+ function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleImplicitTauLeaping ; seed= nothing )
270+ @assert isempty (jump_prob. jump_callback. continuous_callbacks)
271+ @assert isempty (jump_prob. jump_callback. discrete_callbacks)
272+ prob = jump_prob. prob
273+ rng = DEFAULT_RNG
274+ (seed != = nothing ) && seed! (rng, seed)
275+
276+ rj = jump_prob. regular_jump
277+ rate = rj. rate
278+ numjumps = rj. numjumps
279+ c = rj. c
280+ u0 = copy (prob. u0)
281+ tspan = prob. tspan
282+ p = prob. p
283+
284+ # Initialize storage
285+ rate_cache = zeros (Float64, numjumps)
286+ counts = zeros (Int, numjumps)
287+ du = similar (u0)
288+ u = [copy (u0)]
289+ t = [tspan[1 ]]
290+
291+ # Algorithm parameters
292+ epsilon = alg. epsilon
293+ nc = alg. nc
294+ nstiff = alg. nstiff
295+ delta = alg. delta
296+ t_end = tspan[2 ]
297+
298+ # Compute stoichiometry matrix
299+ nu = compute_stoichiometry (c, u0, numjumps, p, t[1 ])
300+
301+ # Detect reversible reaction pairs
302+ equilibrium_pairs = find_reversible_pairs (nu)
303+
304+ # Main simulation loop
305+ while t[end ] < t_end
306+ u_prev = u[end ]
307+ t_prev = t[end ]
308+
309+ # Compute propensities
310+ rate (rate_cache, u_prev, p, t_prev)
311+
312+ # Identify critical reactions
313+ critical = identify_critical_reactions (u_prev, rate_cache, nu, nc)
314+
315+ # Compute tau values
316+ tau_ex = compute_tau_explicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
317+ tau_im = compute_tau_implicit (u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta)
318+
319+ # Compute critical propensity sum
320+ ac0 = sum (rate_cache[critical])
321+ tau2 = ac0 > 0 ? randexp (rng) / ac0 : Inf
322+
323+ # Choose method and stepsize
324+ a0 = sum (rate_cache)
325+ use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && ! use_down_shifting (t_prev, tau_im, tau_ex, a0, t_end)
326+ tau1 = use_implicit ? tau_im : tau_ex
327+ method = use_implicit ? :implicit : :explicit
328+
329+ # Cap tau to prevent large updates
330+ tau1 = min (tau1, 1.0 )
331+
332+ # Check if tau1 is too small
333+ if a0 > 0 && tau1 < 10 / a0
334+ # Use SSA for a few steps
335+ steps = method == :implicit ? 10 : 100
336+ for _ in 1 : steps
337+ if t_prev >= t_end
338+ break
339+ end
340+ rate (rate_cache, u_prev, p, t_prev)
341+ a0 = sum (rate_cache)
342+ if a0 == 0
343+ break
344+ end
345+ tau = randexp (rng) / a0
346+ r = rand (rng) * a0
347+ cumsum_rate = 0.0
348+ for j in 1 : numjumps
349+ cumsum_rate += rate_cache[j]
350+ if cumsum_rate > r
351+ u_prev += nu[:, j]
352+ break
353+ end
354+ end
355+ t_prev += tau
356+ push! (u, copy (u_prev))
357+ push! (t, t_prev)
358+ end
359+ continue
360+ end
361+
362+ # Choose stepsize and compute firings
363+ if tau2 > tau1
364+ tau = min (tau1, t_end - t_prev)
365+ counts .= 0
366+ for j in 1 : numjumps
367+ if ! critical[j]
368+ counts[j] = pois_rand (rng, max (rate_cache[j] * tau, 0.0 ))
369+ end
370+ end
371+ if method == :implicit
372+ u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
373+ else
374+ c (du, u_prev, p, t_prev, counts, nothing )
375+ u_new = u_prev + du
376+ end
377+ else
378+ tau = min (tau2, t_end - t_prev)
379+ counts .= 0
380+ if ac0 > 0
381+ r = rand (rng) * ac0
382+ cumsum_rate = 0.0
383+ for j in 1 : numjumps
384+ if critical[j]
385+ cumsum_rate += rate_cache[j]
386+ if cumsum_rate > r
387+ counts[j] = 1
388+ break
389+ end
390+ end
391+ end
392+ end
393+ for j in 1 : numjumps
394+ if ! critical[j]
395+ counts[j] = pois_rand (rng, max (rate_cache[j] * tau, 0.0 ))
396+ end
397+ end
398+ if method == :implicit && tau > tau_ex
399+ u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
400+ else
401+ c (du, u_prev, p, t_prev, counts, nothing )
402+ u_new = u_prev + du
403+ end
404+ end
405+
406+ # Check for negative populations
407+ if any (u_new .< 0 )
408+ tau1 /= 2
409+ continue
410+ end
411+
412+ # Update state and time
413+ push! (u, u_new)
414+ push! (t, t_prev + tau)
415+ end
416+
417+ # Build solution
418+ sol = DiffEqBase. build_solution (prob, alg, t, u,
419+ calculate_error = false ,
420+ interp = DiffEqBase. ConstantInterpolation (t, u))
421+ end
422+
53423struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
54424 backend:: Backend
55425 cpu_offload:: Float64
@@ -63,4 +433,4 @@ function EnsembleGPUKernel()
63433 EnsembleGPUKernel (nothing , 0.0 )
64434end
65435
66- export SimpleTauLeaping, EnsembleGPUKernel
436+ export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping
0 commit comments