6969SimpleImplicitTauLeaping (; epsilon= 0.05 ) = SimpleImplicitTauLeaping (epsilon)
7070
7171function compute_hor (nu)
72- hor = zeros (Int , size (nu, 2 ))
72+ hor = zeros (Int64 , size (nu, 2 ))
7373 for j in 1 : size (nu, 2 )
7474 hor[j] = sum (abs .(nu[:, j])) > maximum (abs .(nu[:, j])) ? 2 : 1
7575 end
8888
8989function compute_tau_explicit (u, rate_cache, nu, hor, p, t, epsilon, rate)
9090 rate (rate_cache, u, p, t)
91- mu = zeros (length (u))
92- sigma2 = zeros (length (u))
91+ mu = zeros (Float64, length (u))
92+ sigma2 = zeros (Float64, length (u))
9393 tau = Inf
9494 for i in 1 : length (u)
9595 for j in 1 : size (nu, 2 )
@@ -111,21 +111,20 @@ function compute_tau_implicit(u, rate_cache, nu, p, t, rate)
111111 for i in 1 : length (u)
112112 sum_nu_a = 0.0
113113 for j in 1 : size (nu, 2 )
114- if nu[i, j] < 0 # Only sum negative stoichiometry
114+ if nu[i, j] < 0
115115 sum_nu_a += abs (nu[i, j]) * rate_cache[j]
116116 end
117117 end
118- if sum_nu_a > 0 && u[i] > 0 # Avoid division by zero
118+ if sum_nu_a > 0 && u[i] > 0
119119 tau = min (tau, u[i] / sum_nu_a)
120120 end
121121 end
122122 return tau
123123end
124124
125125function implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
126- # Nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (k_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0
127- function f (u_new)
128- rate_new = zeros (Float64, numjumps)
126+ function f (u_new, p)
127+ rate_new = zeros (eltype (u_new), numjumps)
129128 rate (rate_new, u_new, p, t_prev + tau)
130129 residual = u_new - u_prev
131130 for j in 1 : numjumps
@@ -134,41 +133,14 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate,
134133 return residual
135134 end
136135
137- # Numerical Jacobian
138- function compute_jacobian (u_new)
139- n = length (u_new)
140- J = zeros (Float64, n, n)
141- h = 1e-6
142- f_u = f (u_new)
143- for j in 1 : n
144- u_pert = copy (u_new)
145- u_pert[j] += h
146- f_pert = f (u_pert)
147- J[:, j] = (f_pert - f_u) / h
148- end
149- return J
150- end
136+ u_new = float .(u_prev + sum (nu[:, j] * counts[j] for j in 1 : numjumps))
137+ prob = NonlinearProblem {false} (f, u_new, p)
138+ sol = solve (prob, SimpleNewtonRaphson (), abstol= 1e-6 , maxiters= 100 )
151139
152- # Inline Newton-Raphson
153- u_new = float .(u_prev + sum (nu[:, j] * counts[j] for j in 1 : numjumps)) # Initial guess: explicit step
154- tol = 1e-6
155- maxiters = 100
156- for iter in 1 : maxiters
157- F = f (u_new)
158- if norm (F) < tol
159- return round .(Int, max .(u_new, 0.0 )) # Converged
160- end
161- J = compute_jacobian (u_new)
162- if abs (det (J)) < 1e-10 # Check for singular Jacobian
163- return nothing
164- end
165- delta = J \ F
166- u_new -= delta
167- if any (isnan .(u_new)) || any (isinf .(u_new))
168- return nothing
169- end
140+ if sol. retcode != ReturnCode. Success || any (isnan .(sol. u)) || any (isinf .(sol. u))
141+ return nothing
170142 end
171- return nothing # Failed to converge
143+ return round .(Int64, max .(sol . u, 0.0 ))
172144end
173145
174146function DiffEqBase. solve (jump_prob:: JumpProblem , alg:: SimpleImplicitTauLeaping ; seed= nothing , dtmin= 1e-10 , saveat= nothing )
@@ -211,7 +183,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
211183 while t[end ] < t_end
212184 u_prev = u[end ]
213185 t_prev = t[end ]
214- # Recompute stoichiometry
215186 for j in 1 : numjumps
216187 fill! (counts_temp, 0 )
217188 counts_temp[j] = 1
@@ -221,11 +192,10 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
221192 rate (rate_cache, u_prev, p, t_prev)
222193 tau_prime = compute_tau_explicit (u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate)
223194 tau_double_prime = compute_tau_implicit (u_prev, rate_cache, nu, p, t_prev, rate)
224- # Cao et al. (2007): Use tau_prime for explicit, tau_double_prime for implicit
225195 use_implicit = false
226- tau = tau_prime # Default to explicit
227- if tau_double_prime < tau_prime && any (u_prev .< 10 ) # Implicit if populations are low
228- tau = tau_double_prime
196+ tau = tau_prime
197+ if any (u_prev .< 10 )
198+ tau = min ( tau_double_prime, tau_prime) # Tighter cap for accuracy
229199 use_implicit = true
230200 end
231201 tau = max (tau, dtmin)
@@ -241,11 +211,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
241211 if use_implicit
242212 u_new = implicit_tau_step (u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
243213 if u_new === nothing || any (u_new .< 0 )
244- tau /= 2 # Halve tau if implicit fails or produces negative populations
214+ tau /= 2
245215 continue
246216 end
247217 elseif any (u_new .< 0 )
248- tau /= 2 # Halve tau if explicit produces negative populations
218+ tau /= 2
249219 continue
250220 end
251221 u_new = max .(u_new, 0 )
0 commit comments