Skip to content

Commit 6d3d900

Browse files
saveat optimization
1 parent 10f4ce3 commit 6d3d900

File tree

1 file changed

+28
-31
lines changed

1 file changed

+28
-31
lines changed

src/simple_regular_solve.jl

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,9 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
127127
tspan = prob.tspan
128128
p = prob.p
129129

130-
u = [copy(u0)]
131-
t = [tspan[1]]
130+
# Initialize output vectors
131+
u_out = [copy(u0)]
132+
t_out = [tspan[1]]
132133
rate_cache = zeros(Float64, numjumps)
133134
counts = zeros(Int, numjumps)
134135
du = similar(u0)
@@ -155,16 +156,16 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
155156

156157
save_idx = 1
157158

158-
while t[end] < t_end
159-
u_prev = u[end]
160-
t_prev = t[end]
161-
rate(rate_cache, u_prev, p, t_prev)
162-
tau = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin)
163-
tau = min(tau, t_end - t_prev)
164-
if !isempty(saveat_times)
165-
if save_idx <= length(saveat_times) && t_prev + tau > saveat_times[save_idx]
166-
tau = saveat_times[save_idx] - t_prev
167-
end
159+
# Current state for timestepping
160+
u_current = copy(u0)
161+
t_current = tspan[1]
162+
163+
while t_current < t_end
164+
rate(rate_cache, u_current, p, t_current)
165+
tau = compute_tau_explicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin)
166+
tau = min(tau, t_end - t_current)
167+
if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx]
168+
tau = saveat_times[save_idx] - t_current
168169
end
169170
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
170171
du .= 0
@@ -173,35 +174,31 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
173174
du[spec_idx] += stoich * counts[j]
174175
end
175176
end
176-
u_new = u_prev + du
177+
u_new = u_current + du
177178
if any(<(0), u_new)
178179
# Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
179180
tau /= 2
180181
continue
181182
end
182-
u_new = max.(u_new, 0) # Ensure non-negative states
183-
push!(u, u_new)
184-
push!(t, t_prev + tau)
185-
if !isempty(saveat_times) && save_idx <= length(saveat_times) && t[end] >= saveat_times[save_idx]
186-
save_idx += 1
183+
u_new = max.(u_new, 0)
184+
t_new = t_current + tau
185+
186+
# Save state if at a saveat time or if saveat is empty
187+
if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx])
188+
push!(u_out, copy(u_new))
189+
push!(t_out, t_new)
190+
if !isempty(saveat_times) && t_new >= saveat_times[save_idx]
191+
save_idx += 1
192+
end
187193
end
188-
end
189194

190-
# Interpolate to saveat times if specified
191-
if !isempty(saveat_times)
192-
t_out = saveat_times
193-
u_out = [u[end]]
194-
for t_save in saveat_times
195-
idx = findlast(ti -> ti <= t_save, t)
196-
push!(u_out, u[idx])
197-
end
198-
t = t_out
199-
u = u_out[2:end]
195+
u_current = u_new
196+
t_current = t_new
200197
end
201198

202-
sol = DiffEqBase.build_solution(prob, alg, t, u,
199+
sol = DiffEqBase.build_solution(prob, alg, t_out, u_out,
203200
calculate_error=false,
204-
interp=DiffEqBase.ConstantInterpolation(t, u))
201+
interp=DiffEqBase.ConstantInterpolation(t_out, u_out))
205202
return sol
206203
end
207204

0 commit comments

Comments
 (0)