@@ -50,6 +50,301 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
5050 interp = DiffEqBase. ConstantInterpolation (t, u))
5151end
5252
53+ # Define ImplicitTauLeaping algorithm
54+ struct ImplicitTauLeaping <: DiffEqBase.DEAlgorithm
55+ epsilon:: Float64 # Error control parameter
56+ nc:: Int # Critical reaction threshold
57+ nstiff:: Int # Stiffness threshold multiplier
58+ delta:: Float64 # Partial equilibrium threshold
59+ end
60+
61+ ImplicitTauLeaping (; epsilon= 0.05 , nc= 10 , nstiff= 100 , delta= 0.05 ) = ImplicitTauLeaping (epsilon, nc, nstiff, delta)
62+
63+ function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: ImplicitTauLeaping ;
64+ seed = nothing ,
65+ dt = error (" dt is required for ImplicitTauLeaping." ),
66+ kwargs... )
67+
68+ # Boilerplate from SimpleTauLeaping
69+ @assert isempty (jump_prob. jump_callback. continuous_callbacks)
70+ @assert isempty (jump_prob. jump_callback. discrete_callbacks)
71+ prob = jump_prob. prob
72+ rng = DEFAULT_RNG
73+ (seed != = nothing ) && seed! (rng, seed)
74+
75+ rj = jump_prob. regular_jump
76+ rate = rj. rate # rate(out, u, p, t)
77+ numjumps = rj. numjumps
78+ c = rj. c # c(dc, u, p, t, counts, mark)
79+ reversible_pairs = get (kwargs, :reversible_pairs , Tuple{Int,Int}[])
80+
81+ if ! isnothing (rj. mark_dist)
82+ error (" Mark distributions are currently not supported in ImplicitTauLeaping" )
83+ end
84+
85+ # Initialize state and buffers
86+ u0 = copy (prob. u0)
87+ p = prob. p
88+ tspan = prob. tspan
89+ state_dim = length (u0)
90+ dt = Float64 (dt)
91+
92+ # Compute stoichiometry matrix
93+ nu = zeros (Int, state_dim, numjumps)
94+ for j in 1 : numjumps
95+ dc = zeros (state_dim)
96+ c (dc, u0, p, 0.0 , [i == j ? 1 : 0 for i in 1 : numjumps], nothing )
97+ nu[:, j] = dc
98+ end
99+
100+ # Initialize solution arrays
101+ n = Int ((tspan[2 ] - tspan[1 ]) / dt) + 1
102+ u = Vector {typeof(u0)} (undef, n)
103+ u[1 ] = u0
104+ t = range (tspan[1 ], tspan[2 ], length= n)
105+
106+ # Buffers for iteration
107+ current_u = copy (u0)
108+ rate_cache = zeros (Float64, numjumps)
109+ counts = zeros (Float64, numjumps)
110+ local_dc = zeros (Float64, state_dim)
111+ I_rs = 1 : state_dim
112+ g = ones (state_dim) # Scaling factor for tau-leaping
113+
114+ function compute_tau_explicit (u, rate, nu, num_jumps, epsilon, g, J_ncr, I_rs, p)
115+ rate_cache = zeros (eltype (u), num_jumps)
116+ rate (rate_cache, u, p, 0.0 )
117+
118+ mu = zeros (eltype (u), length (u))
119+ sigma2 = zeros (eltype (u), length (u))
120+
121+ for i in I_rs
122+ mu[i] = sum (nu[i,j] * rate_cache[j] for j in J_ncr; init= 0.0 )
123+ sigma2[i] = sum (nu[i,j]^ 2 * rate_cache[j] for j in J_ncr; init= 0.0 )
124+ end
125+
126+ tau = Inf
127+ for i in I_rs
128+ denom_mu = max (epsilon * u[i] / g[i], 1.0 )
129+ denom_sigma = denom_mu^ 2
130+ if abs (mu[i]) > 0
131+ tau = min (tau, denom_mu / abs (mu[i]))
132+ end
133+ if sigma2[i] > 0
134+ tau = min (tau, denom_sigma / sigma2[i])
135+ end
136+ end
137+ return tau
138+ end
139+
140+ function compute_tau_implicit (u, rate, nu, num_jumps, epsilon, g, J_necr, I_rs, p)
141+ rate_cache = zeros (eltype (u), num_jumps)
142+ rate (rate_cache, u, p, 0.0 )
143+
144+ mu = zeros (eltype (u), length (u))
145+ sigma2 = zeros (eltype (u), length (u))
146+
147+ for i in I_rs
148+ mu[i] = sum (nu[i,j] * rate_cache[j] for j in J_necr; init= 0.0 )
149+ sigma2[i] = sum (nu[i,j]^ 2 * rate_cache[j] for j in J_necr; init= 0.0 )
150+ end
151+
152+ tau = Inf
153+ for i in I_rs
154+ denom_mu = max (epsilon * u[i] / g[i], 1.0 )
155+ denom_sigma = denom_mu^ 2
156+ if abs (mu[i]) > 0
157+ tau = min (tau, denom_mu / abs (mu[i]))
158+ end
159+ if sigma2[i] > 0
160+ tau = min (tau, denom_sigma / sigma2[i])
161+ end
162+ end
163+ return isinf (tau) ? 1e6 : tau
164+ end
165+
166+ function identify_critical_reactions (u, nu, num_jumps, nc)
167+ L = zeros (Int, num_jumps)
168+ J_critical = Int[]
169+
170+ for j in 1 : num_jumps
171+ min_val = Inf
172+ for i in 1 : length (u)
173+ if nu[i,j] < 0
174+ val = floor (u[i] / abs (nu[i,j]))
175+ min_val = min (min_val, val)
176+ end
177+ end
178+ L[j] = min_val == Inf ? typemax (Int) : Int (min_val)
179+ if L[j] < nc
180+ push! (J_critical, j)
181+ end
182+ end
183+ J_ncr = setdiff (1 : num_jumps, J_critical)
184+ return J_critical, J_ncr
185+ end
186+
187+ function check_partial_equilibrium (rate_cache, reversible_pairs, delta)
188+ J_equilibrium = Int[]
189+ for (j_plus, j_minus) in reversible_pairs
190+ a_plus = rate_cache[j_plus]
191+ a_minus = rate_cache[j_minus]
192+ if abs (a_plus - a_minus) <= delta * min (a_plus, a_minus)
193+ push! (J_equilibrium, j_plus, j_minus)
194+ end
195+ end
196+ return J_equilibrium
197+ end
198+
199+ function newton_solve! (x_new, x, rate, nu, rate_cache, counts, p, t, tau, max_iter= 10 , tol= 1e-6 )
200+ state_dim = length (x)
201+ num_jumps = length (counts)
202+
203+ for iter in 1 : max_iter
204+ rate (rate_cache, x_new, p, t)
205+ rate_cache .*= tau
206+
207+ residual = x_new .- x
208+ for j in 1 : num_jumps
209+ residual .- = nu[:,j] * (counts[j] - rate_cache[j] + tau * rate_cache[j])
210+ end
211+
212+ if norm (residual) < tol
213+ break
214+ end
215+
216+ J = zeros (eltype (x), state_dim, state_dim)
217+ for j in 1 : num_jumps
218+ for i in 1 : state_dim
219+ for k in 1 : state_dim
220+ J[i,k] += nu[i,j] * nu[k,j] * rate_cache[j]
221+ end
222+ end
223+ end
224+ J = I - tau * J
225+
226+ delta_x = J \ residual
227+ x_new .- = delta_x
228+
229+ if norm (delta_x) < tol
230+ break
231+ end
232+ end
233+ return x_new
234+ end
235+
236+ # Main solver loop
237+ for i in 2 : n
238+ tprev = t[i - 1 ]
239+ J_critical, J_ncr = identify_critical_reactions (current_u, nu, numjumps, alg. nc)
240+
241+ rate (rate_cache, current_u, p, tprev)
242+ a0_critical = sum (rate_cache[j] for j in J_critical; init= 0.0 )
243+
244+ J_equilibrium = check_partial_equilibrium (rate_cache, reversible_pairs, alg. delta)
245+ J_necr = setdiff (J_ncr, J_equilibrium)
246+
247+ tau_ex = compute_tau_explicit (current_u, rate, nu, numjumps, alg. epsilon, g, J_ncr, I_rs, p)
248+ tau_im = compute_tau_implicit (current_u, rate, nu, numjumps, alg. epsilon, g, J_necr, I_rs, p)
249+
250+ tau2 = a0_critical > 0 ? - log (rand (rng)) / a0_critical : Inf
251+ use_implicit = tau_im > alg. nstiff * tau_ex
252+ tau1 = use_implicit ? tau_im : tau_ex
253+
254+ if tau1 < 10 / sum (rate_cache; init= 0.0 )
255+ a0 = sum (rate_cache; init= 0.0 )
256+ if a0 > 0
257+ tau = - log (rand (rng)) / a0
258+ r = rand (rng) * a0
259+ cumsum_a = 0.0
260+ jc = 1
261+ for k in 1 : numjumps
262+ cumsum_a += rate_cache[k]
263+ if cumsum_a > r
264+ jc = k
265+ break
266+ end
267+ end
268+ current_u .+ = nu[:,jc]
269+ else
270+ tau = dt
271+ end
272+ else
273+ tau = min (tau1, tau2, dt)
274+ if tau == tau2
275+ if a0_critical > 0
276+ r = rand (rng) * a0_critical
277+ cumsum_a = 0.0
278+ jc = ! isempty (J_critical) ? J_critical[1 ] : 1
279+ for k in J_critical
280+ cumsum_a += rate_cache[k]
281+ if cumsum_a > r
282+ jc = k
283+ break
284+ end
285+ end
286+ counts .= 0
287+ counts[jc] = 1
288+ if use_implicit && tau > tau_ex
289+ for k in J_ncr
290+ counts[k] = pois_rand (rng, rate_cache[k] * tau)
291+ end
292+ c (local_dc, current_u, p, tprev, counts, nothing )
293+ current_u .= newton_solve! (current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau)
294+ else
295+ for k in J_ncr
296+ counts[k] = pois_rand (rng, rate_cache[k] * tau)
297+ end
298+ c (local_dc, current_u, p, tprev, counts, nothing )
299+ current_u .+ = local_dc
300+ end
301+ else
302+ tau = tau1
303+ if use_implicit
304+ for k in 1 : numjumps
305+ counts[k] = pois_rand (rng, rate_cache[k] * tau)
306+ end
307+ c (local_dc, current_u, p, tprev, counts, nothing )
308+ current_u .= newton_solve! (current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau)
309+ else
310+ for k in 1 : numjumps
311+ counts[k] = pois_rand (rng, rate_cache[k] * tau)
312+ end
313+ c (local_dc, current_u, p, tprev, counts, nothing )
314+ current_u .+ = local_dc
315+ end
316+ end
317+ else
318+ counts .= 0
319+ if use_implicit
320+ for k in J_ncr
321+ counts[k] = pois_rand (rng, rate_cache[k] * tau)
322+ end
323+ c (local_dc, current_u, p, tprev, counts, nothing )
324+ current_u .= newton_solve! (current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau)
325+ else
326+ for k in J_ncr
327+ counts[k] = pois_rand (rng, rate_cache[k] * tau)
328+ end
329+ c (local_dc, current_u, p, tprev, counts, nothing )
330+ current_u .+ = local_dc
331+ end
332+ end
333+ end
334+
335+ if any (current_u .< 0 )
336+ tau1 /= 2
337+ continue
338+ end
339+
340+ u[i] = copy (current_u)
341+ end
342+
343+ sol = DiffEqBase. build_solution (prob, alg, t, u,
344+ calculate_error = false ,
345+ interp = DiffEqBase. ConstantInterpolation (t, u))
346+ end
347+
53348struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
54349 backend:: Backend
55350 cpu_offload:: Float64
@@ -63,14 +358,4 @@ function EnsembleGPUKernel()
63358 EnsembleGPUKernel (nothing , 0.0 )
64359end
65360
66- # Define ImplicitTauLeaping algorithm
67- struct ImplicitTauLeaping <: DiffEqBase.DEAlgorithm
68- epsilon:: Float64 # Error control parameter
69- nc:: Int # Critical reaction threshold
70- nstiff:: Int # Stiffness threshold multiplier
71- delta:: Float64 # Partial equilibrium threshold
72- end
73-
74- ImplicitTauLeaping (; epsilon= 0.05 , nc= 10 , nstiff= 100 , delta= 0.05 ) = ImplicitTauLeaping (epsilon, nc, nstiff, delta)
75-
76361export SimpleTauLeaping, EnsembleGPUKernel, ImplicitTauLeaping
0 commit comments