@@ -118,33 +118,38 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
118118 maj = jump_prob. massaction_jump
119119 numjumps = get_num_majumps (maj)
120120 # Extract rates
121- rate = (out, u, p, t) -> begin
122- for j in 1 : get_num_majumps (maj)
123- out[j] = evalrxrate (u, j, maj)
121+ rate = jump_prob. regular_jump != = nothing ? jump_prob. regular_jump. rate :
122+ (out, u, p, t) -> begin
123+ for j in 1 : numjumps
124+ out[j] = evalrxrate (u, j, maj)
125+ end
124126 end
125- end
127+ c = jump_prob . regular_jump != = nothing ? jump_prob . regular_jump . c : nothing
126128 u0 = copy (prob. u0)
127129 tspan = prob. tspan
128130 p = prob. p
129131
130- # Initialize output vectors
131- u_out = [copy (u0)]
132- t_out = [tspan[1 ]]
133- rate_cache = zeros (Float64, numjumps)
134- counts = zeros (Int, numjumps)
132+ # Initialize current state and saved history
133+ u_current = copy (u0)
134+ t_current = tspan[1 ]
135+ usave = [copy (u0)]
136+ tsave = [tspan[1 ]]
137+ rate_cache = zeros (float (eltype (u0)), numjumps)
138+ counts = zero (rate_cache)
135139 du = similar (u0)
136140 t_end = tspan[2 ]
137141 epsilon = alg. epsilon
138142
139143 # Extract stoichiometry once from MassActionJump
140- nu = zeros (Int , length (u0), numjumps)
144+ nu = zeros (float ( eltype (u0)) , length (u0), numjumps)
141145 for j in 1 : numjumps
142146 for (spec_idx, stoich) in maj. net_stoch[j]
143147 nu[spec_idx, j] = stoich
144148 end
145149 end
146150 hor = compute_hor (nu)
147151
152+ # Set up saveat_times
148153 saveat_times = nothing
149154 if isnothing (saveat)
150155 saveat_times = Vector {typeof(tspan[1])} ()
@@ -156,10 +161,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
156161
157162 save_idx = 1
158163
159- # Current state for timestepping
160- u_current = copy (u0)
161- t_current = tspan[1 ]
162-
163164 while t_current < t_end
164165 rate (rate_cache, u_current, p, t_current)
165166 tau = compute_tau_explicit (u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin)
@@ -169,9 +170,13 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
169170 end
170171 counts .= pois_rand .(rng, max .(rate_cache * tau, 0.0 ))
171172 du .= 0
172- for j in 1 : numjumps
173- for (spec_idx, stoich) in maj. net_stoch[j]
174- du[spec_idx] += stoich * counts[j]
173+ if c != = nothing
174+ c (du, u_current, p, t_current, counts, nothing )
175+ else
176+ for j in 1 : numjumps
177+ for (spec_idx, stoich) in maj. net_stoch[j]
178+ du[spec_idx] += stoich * counts[j]
179+ end
175180 end
176181 end
177182 u_new = u_current + du
@@ -185,8 +190,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
185190
186191 # Save state if at a saveat time or if saveat is empty
187192 if isempty (saveat_times) || (save_idx <= length (saveat_times) && t_new >= saveat_times[save_idx])
188- push! (u_out, copy ( u_new) )
189- push! (t_out , t_new)
193+ push! (usave, u_new)
194+ push! (tsave , t_new)
190195 if ! isempty (saveat_times) && t_new >= saveat_times[save_idx]
191196 save_idx += 1
192197 end
@@ -196,9 +201,9 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
196201 t_current = t_new
197202 end
198203
199- sol = DiffEqBase. build_solution (prob, alg, t_out, u_out ,
204+ sol = DiffEqBase. build_solution (prob, alg, tsave, usave ,
200205 calculate_error= false ,
201- interp= DiffEqBase. ConstantInterpolation (t_out, u_out ))
206+ interp= DiffEqBase. ConstantInterpolation (tsave, usave ))
202207 return sol
203208end
204209
0 commit comments