Skip to content

Commit 75a733f

Browse files
added saveat in SimpleAdaptiveTauLeaping
1 parent 683a2c0 commit 75a733f

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

src/simple_regular_solve.jl

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ end
9999

100100
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
101101
seed = nothing,
102-
dtmin = 1e-10)
102+
dtmin = 1e-10,
103+
saveat = nothing)
103104
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
104105
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
105106
prob = jump_prob.prob
@@ -124,40 +125,65 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
124125

125126
# Compute initial stoichiometry and HOR
126127
nu = zeros(Int, length(u0), numjumps)
128+
counts_temp = zeros(Int, numjumps)
127129
for j in 1:numjumps
128-
counts_temp = zeros(numjumps)
130+
fill!(counts_temp, 0)
129131
counts_temp[j] = 1
130132
c(du, u0, p, t[1], counts_temp, nothing)
131133
nu[:, j] = du
132134
end
133-
134135
hor = zeros(Int, size(nu, 2))
135136
for j in 1:size(nu, 2)
136137
hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1
137138
end
138139

140+
saveat_times = isnothing(saveat) ? Float64[] : saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat)
141+
save_idx = 1
142+
139143
while t[end] < t_end
140144
u_prev = u[end]
141145
t_prev = t[end]
142146
# Recompute stoichiometry
143147
for j in 1:numjumps
144-
counts_temp = zeros(numjumps)
148+
fill!(counts_temp, 0)
145149
counts_temp[j] = 1
146150
c(du, u_prev, p, t_prev, counts_temp, nothing)
147151
nu[:, j] = du
148152
end
149153
rate(rate_cache, u_prev, p, t_prev)
150154
tau = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate, dtmin)
151155
tau = min(tau, t_end - t_prev)
156+
if !isempty(saveat_times)
157+
if save_idx <= length(saveat_times) && t_prev + tau > saveat_times[save_idx]
158+
tau = saveat_times[save_idx] - t_prev
159+
end
160+
end
152161
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
153162
c(du, u_prev, p, t_prev, counts, nothing)
154-
u_new = max.(u_prev + du, 0)
163+
u_new = u_prev + du
155164
if any(u_new .< 0)
165+
# Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468)
156166
tau /= 2
157167
continue
158168
end
169+
u_new = max.(u_new, 0) # Ensure non-negative states
159170
push!(u, u_new)
160171
push!(t, t_prev + tau)
172+
if !isempty(saveat_times) && save_idx <= length(saveat_times) && t[end] >= saveat_times[save_idx]
173+
save_idx += 1
174+
end
175+
end
176+
177+
# Interpolate to saveat times if specified
178+
if !isempty(saveat_times)
179+
t_out = saveat_times
180+
u_out = [u[end]]
181+
for t_save in saveat_times
182+
idx = findlast(ti -> ti <= t_save, t)
183+
push!(u_out, u[idx])
184+
end
185+
t = t_out
186+
u = u_out[2:end]
161187
end
162188

163189
sol = DiffEqBase.build_solution(prob, alg, t, u,

test/regular_jumps.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Nsims = 1000
4848
sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1)
4949

5050
# Solve with SimpleAdaptiveTauLeaping
51-
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims)
51+
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat = 1.0)
5252

5353
# Compute mean trajectories at t = 0, 1, ..., 250
5454
t_points = 0:1.0:250.0
@@ -106,7 +106,7 @@ end
106106
sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories=Nsims, dt=0.1)
107107

108108
# Solve with SimpleAdaptiveTauLeaping
109-
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims)
109+
sol_adaptive = solve(EnsembleProblem(jump_prob_tau), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat = 1.0)
110110

111111
# Compute mean trajectories at t = 0, 1, ..., 250
112112
t_points = 0:1.0:250.0

0 commit comments

Comments
 (0)