@@ -127,8 +127,9 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
127127 tspan = prob. tspan
128128 p = prob. p
129129
130- u = [copy (u0)]
131- t = [tspan[1 ]]
130+ # Initialize output vectors
131+ u_out = [copy (u0)]
132+ t_out = [tspan[1 ]]
132133 rate_cache = zeros (Float64, numjumps)
133134 counts = zeros (Int, numjumps)
134135 du = similar (u0)
@@ -155,16 +156,16 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
155156
156157 save_idx = 1
157158
158- while t[ end ] < t_end
159- u_prev = u[ end ]
160- t_prev = t[ end ]
161- rate (rate_cache, u_prev, p, t_prev)
162- tau = compute_tau_explicit (u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin)
163- tau = min (tau, t_end - t_prev )
164- if ! isempty (saveat_times )
165- if save_idx <= length (saveat_times) && t_prev + tau > saveat_times[save_idx]
166- tau = saveat_times[save_idx] - t_prev
167- end
159+ # Current state for timestepping
160+ u_current = copy (u0)
161+ t_current = tspan[ 1 ]
162+
163+ while t_current < t_end
164+ rate (rate_cache, u_current, p, t_current )
165+ tau = compute_tau_explicit (u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin )
166+ tau = min ( tau, t_end - t_current)
167+ if ! isempty (saveat_times) && save_idx <= length (saveat_times) && t_current + tau > saveat_times[save_idx]
168+ tau = saveat_times[save_idx] - t_current
168169 end
169170 counts .= pois_rand .(rng, max .(rate_cache * tau, 0.0 ))
170171 du .= 0
@@ -173,35 +174,31 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
173174 du[spec_idx] += stoich * counts[j]
174175 end
175176 end
176- u_new = u_prev + du
177+ u_new = u_current + du
177178 if any (< (0 ), u_new)
178179 # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
179180 tau /= 2
180181 continue
181182 end
182- u_new = max .(u_new, 0 ) # Ensure non-negative states
183- push! (u, u_new)
184- push! (t, t_prev + tau)
185- if ! isempty (saveat_times) && save_idx <= length (saveat_times) && t[end ] >= saveat_times[save_idx]
186- save_idx += 1
183+ u_new = max .(u_new, 0 )
184+ t_new = t_current + tau
185+
186+ # Save state if at a saveat time or if saveat is empty
187+ if isempty (saveat_times) || (save_idx <= length (saveat_times) && t_new >= saveat_times[save_idx])
188+ push! (u_out, copy (u_new))
189+ push! (t_out, t_new)
190+ if ! isempty (saveat_times) && t_new >= saveat_times[save_idx]
191+ save_idx += 1
192+ end
187193 end
188- end
189194
190- # Interpolate to saveat times if specified
191- if ! isempty (saveat_times)
192- t_out = saveat_times
193- u_out = [u[end ]]
194- for t_save in saveat_times
195- idx = findlast (ti -> ti <= t_save, t)
196- push! (u_out, u[idx])
197- end
198- t = t_out
199- u = u_out[2 : end ]
195+ u_current = u_new
196+ t_current = t_new
200197 end
201198
202- sol = DiffEqBase. build_solution (prob, alg, t, u ,
199+ sol = DiffEqBase. build_solution (prob, alg, t_out, u_out ,
203200 calculate_error= false ,
204- interp= DiffEqBase. ConstantInterpolation (t, u ))
201+ interp= DiffEqBase. ConstantInterpolation (t_out, u_out ))
205202 return sol
206203end
207204
0 commit comments