@@ -128,8 +128,9 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
128128 tspan = prob. tspan
129129 p = prob. p
130130
131- u = [copy (u0)]
132- t = [tspan[1 ]]
131+ # Initialize output vectors
132+ u_out = [copy (u0)]
133+ t_out = [tspan[1 ]]
133134 rate_cache = zeros (Float64, numjumps)
134135 counts = zeros (Int, numjumps)
135136 du = similar (u0)
@@ -156,16 +157,16 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
156157
157158 save_idx = 1
158159
159- while t[ end ] < t_end
160- u_prev = u[ end ]
161- t_prev = t[ end ]
162- rate (rate_cache, u_prev, p, t_prev)
163- tau = compute_tau_explicit (u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin)
164- tau = min (tau, t_end - t_prev )
165- if ! isempty (saveat_times )
166- if save_idx <= length (saveat_times) && t_prev + tau > saveat_times[save_idx]
167- tau = saveat_times[save_idx] - t_prev
168- end
160+ # Current state for timestepping
161+ u_current = copy (u0)
162+ t_current = tspan[ 1 ]
163+
164+ while t_current < t_end
165+ rate (rate_cache, u_current, p, t_current )
166+ tau = compute_tau_explicit (u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin )
167+ tau = min ( tau, t_end - t_current)
168+ if ! isempty (saveat_times) && save_idx <= length (saveat_times) && t_current + tau > saveat_times[save_idx]
169+ tau = saveat_times[save_idx] - t_current
169170 end
170171 counts .= pois_rand .(rng, max .(rate_cache * tau, 0.0 ))
171172 du .= 0
@@ -174,35 +175,31 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
174175 du[spec_idx] += stoich * counts[j]
175176 end
176177 end
177- u_new = u_prev + du
178+ u_new = u_current + du
178179 if any (< (0 ), u_new)
179180 # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
180181 tau /= 2
181182 continue
182183 end
183- u_new = max .(u_new, 0 ) # Ensure non-negative states
184- push! (u, u_new)
185- push! (t, t_prev + tau)
186- if ! isempty (saveat_times) && save_idx <= length (saveat_times) && t[end ] >= saveat_times[save_idx]
187- save_idx += 1
184+ u_new = max .(u_new, 0 )
185+ t_new = t_current + tau
186+
187+ # Save state if at a saveat time or if saveat is empty
188+ if isempty (saveat_times) || (save_idx <= length (saveat_times) && t_new >= saveat_times[save_idx])
189+ push! (u_out, copy (u_new))
190+ push! (t_out, t_new)
191+ if ! isempty (saveat_times) && t_new >= saveat_times[save_idx]
192+ save_idx += 1
193+ end
188194 end
189- end
190195
191- # Interpolate to saveat times if specified
192- if ! isempty (saveat_times)
193- t_out = saveat_times
194- u_out = [u[end ]]
195- for t_save in saveat_times
196- idx = findlast (ti -> ti <= t_save, t)
197- push! (u_out, u[idx])
198- end
199- t = t_out
200- u = u_out[2 : end ]
196+ u_current = u_new
197+ t_current = t_new
201198 end
202199
203- sol = DiffEqBase. build_solution (prob, alg, t, u ,
200+ sol = DiffEqBase. build_solution (prob, alg, t_out, u_out ,
204201 calculate_error= false ,
205- interp= DiffEqBase. ConstantInterpolation (t, u ))
202+ interp= DiffEqBase. ConstantInterpolation (t_out, u_out ))
206203 return sol
207204end
208205
0 commit comments