Skip to content

Commit 2c03d67

Browse files
refactor
1 parent 6d3d900 commit 2c03d67

File tree

1 file changed

+26
-21
lines changed

1 file changed

+26
-21
lines changed

src/simple_regular_solve.jl

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -118,33 +118,38 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
118118
maj = jump_prob.massaction_jump
119119
numjumps = get_num_majumps(maj)
120120
# Extract rates
121-
rate = (out, u, p, t) -> begin
122-
for j in 1:get_num_majumps(maj)
123-
out[j] = evalrxrate(u, j, maj)
121+
rate = jump_prob.regular_jump !== nothing ? jump_prob.regular_jump.rate :
122+
(out, u, p, t) -> begin
123+
for j in 1:numjumps
124+
out[j] = evalrxrate(u, j, maj)
125+
end
124126
end
125-
end
127+
c = jump_prob.regular_jump !== nothing ? jump_prob.regular_jump.c : nothing
126128
u0 = copy(prob.u0)
127129
tspan = prob.tspan
128130
p = prob.p
129131

130-
# Initialize output vectors
131-
u_out = [copy(u0)]
132-
t_out = [tspan[1]]
133-
rate_cache = zeros(Float64, numjumps)
134-
counts = zeros(Int, numjumps)
132+
# Initialize current state and saved history
133+
u_current = copy(u0)
134+
t_current = tspan[1]
135+
usave = [copy(u0)]
136+
tsave = [tspan[1]]
137+
rate_cache = zeros(float(eltype(u0)), numjumps)
138+
counts = zero(rate_cache)
135139
du = similar(u0)
136140
t_end = tspan[2]
137141
epsilon = alg.epsilon
138142

139143
# Extract stoichiometry once from MassActionJump
140-
nu = zeros(Int, length(u0), numjumps)
144+
nu = zeros(float(eltype(u0)), length(u0), numjumps)
141145
for j in 1:numjumps
142146
for (spec_idx, stoich) in maj.net_stoch[j]
143147
nu[spec_idx, j] = stoich
144148
end
145149
end
146150
hor = compute_hor(nu)
147151

152+
# Set up saveat_times
148153
saveat_times = nothing
149154
if isnothing(saveat)
150155
saveat_times = Vector{typeof(tspan[1])}()
@@ -156,10 +161,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
156161

157162
save_idx = 1
158163

159-
# Current state for timestepping
160-
u_current = copy(u0)
161-
t_current = tspan[1]
162-
163164
while t_current < t_end
164165
rate(rate_cache, u_current, p, t_current)
165166
tau = compute_tau_explicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin)
@@ -169,9 +170,13 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
169170
end
170171
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
171172
du .= 0
172-
for j in 1:numjumps
173-
for (spec_idx, stoich) in maj.net_stoch[j]
174-
du[spec_idx] += stoich * counts[j]
173+
if c !== nothing
174+
c(du, u_current, p, t_current, counts, nothing)
175+
else
176+
for j in 1:numjumps
177+
for (spec_idx, stoich) in maj.net_stoch[j]
178+
du[spec_idx] += stoich * counts[j]
179+
end
175180
end
176181
end
177182
u_new = u_current + du
@@ -185,8 +190,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
185190

186191
# Save state if at a saveat time or if saveat is empty
187192
if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx])
188-
push!(u_out, copy(u_new))
189-
push!(t_out, t_new)
193+
push!(usave, u_new)
194+
push!(tsave, t_new)
190195
if !isempty(saveat_times) && t_new >= saveat_times[save_idx]
191196
save_idx += 1
192197
end
@@ -196,9 +201,9 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
196201
t_current = t_new
197202
end
198203

199-
sol = DiffEqBase.build_solution(prob, alg, t_out, u_out,
204+
sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
200205
calculate_error=false,
201-
interp=DiffEqBase.ConstantInterpolation(t_out, u_out))
206+
interp=DiffEqBase.ConstantInterpolation(tsave, usave))
202207
return sol
203208
end
204209

0 commit comments

Comments
 (0)