Skip to content

Commit 341b370

Browse files
saveat optimization
1 parent c561fa4 commit 341b370

File tree

1 file changed

+28
-31
lines changed

1 file changed

+28
-31
lines changed

src/simple_regular_solve.jl

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,9 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
128128
tspan = prob.tspan
129129
p = prob.p
130130

131-
u = [copy(u0)]
132-
t = [tspan[1]]
131+
# Initialize output vectors
132+
u_out = [copy(u0)]
133+
t_out = [tspan[1]]
133134
rate_cache = zeros(Float64, numjumps)
134135
counts = zeros(Int, numjumps)
135136
du = similar(u0)
@@ -156,16 +157,16 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
156157

157158
save_idx = 1
158159

159-
while t[end] < t_end
160-
u_prev = u[end]
161-
t_prev = t[end]
162-
rate(rate_cache, u_prev, p, t_prev)
163-
tau = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin)
164-
tau = min(tau, t_end - t_prev)
165-
if !isempty(saveat_times)
166-
if save_idx <= length(saveat_times) && t_prev + tau > saveat_times[save_idx]
167-
tau = saveat_times[save_idx] - t_prev
168-
end
160+
# Current state for timestepping
161+
u_current = copy(u0)
162+
t_current = tspan[1]
163+
164+
while t_current < t_end
165+
rate(rate_cache, u_current, p, t_current)
166+
tau = compute_tau_explicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin)
167+
tau = min(tau, t_end - t_current)
168+
if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx]
169+
tau = saveat_times[save_idx] - t_current
169170
end
170171
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
171172
du .= 0
@@ -174,35 +175,31 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
174175
du[spec_idx] += stoich * counts[j]
175176
end
176177
end
177-
u_new = u_prev + du
178+
u_new = u_current + du
178179
if any(<(0), u_new)
179180
# Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
180181
tau /= 2
181182
continue
182183
end
183-
u_new = max.(u_new, 0) # Ensure non-negative states
184-
push!(u, u_new)
185-
push!(t, t_prev + tau)
186-
if !isempty(saveat_times) && save_idx <= length(saveat_times) && t[end] >= saveat_times[save_idx]
187-
save_idx += 1
184+
u_new = max.(u_new, 0)
185+
t_new = t_current + tau
186+
187+
# Save state if at a saveat time or if saveat is empty
188+
if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx])
189+
push!(u_out, copy(u_new))
190+
push!(t_out, t_new)
191+
if !isempty(saveat_times) && t_new >= saveat_times[save_idx]
192+
save_idx += 1
193+
end
188194
end
189-
end
190195

191-
# Interpolate to saveat times if specified
192-
if !isempty(saveat_times)
193-
t_out = saveat_times
194-
u_out = [u[end]]
195-
for t_save in saveat_times
196-
idx = findlast(ti -> ti <= t_save, t)
197-
push!(u_out, u[idx])
198-
end
199-
t = t_out
200-
u = u_out[2:end]
196+
u_current = u_new
197+
t_current = t_new
201198
end
202199

203-
sol = DiffEqBase.build_solution(prob, alg, t, u,
200+
sol = DiffEqBase.build_solution(prob, alg, t_out, u_out,
204201
calculate_error=false,
205-
interp=DiffEqBase.ConstantInterpolation(t, u))
202+
interp=DiffEqBase.ConstantInterpolation(t_out, u_out))
206203
return sol
207204
end
208205

0 commit comments

Comments
 (0)