Skip to content

Commit 4699990

Browse files
Implemented SimpleAdaptiveTauLeaping and SimpleImplicitTauLeaping
1 parent 782c430 commit 4699990

File tree

3 files changed

+393
-90
lines changed

3 files changed

+393
-90
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1212
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1313
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15+
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
16+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
1517
PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab"
1618
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1719
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"

src/simple_regular_solve.jl

Lines changed: 372 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,376 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
6262
interp = DiffEqBase.ConstantInterpolation(t, u))
6363
end
6464

65+
struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm
66+
epsilon::Float64 # Error control parameter
67+
end
68+
69+
SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)
70+
71+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing)
72+
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
73+
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
74+
prob = jump_prob.prob
75+
rng = DEFAULT_RNG
76+
(seed !== nothing) && seed!(rng, seed)
77+
78+
rj = jump_prob.regular_jump
79+
rate = rj.rate
80+
numjumps = rj.numjumps
81+
c = rj.c
82+
u0 = copy(prob.u0)
83+
tspan = prob.tspan
84+
p = prob.p
85+
86+
u = [copy(u0)]
87+
t = [tspan[1]]
88+
rate_cache = zeros(Float64, numjumps)
89+
counts = zeros(Int, numjumps)
90+
du = similar(u0)
91+
t_end = tspan[2]
92+
epsilon = alg.epsilon
93+
94+
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
95+
96+
while t[end] < t_end
97+
u_prev = u[end]
98+
t_prev = t[end]
99+
rate(rate_cache, u_prev, p, t_prev)
100+
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
101+
tau = min(tau, t_end - t_prev)
102+
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
103+
c(du, u_prev, p, t_prev, counts, nothing)
104+
u_new = u_prev + du
105+
if any(u_new .< 0)
106+
tau /= 2
107+
continue
108+
end
109+
push!(u, u_new)
110+
push!(t, t_prev + tau)
111+
end
112+
113+
sol = DiffEqBase.build_solution(prob, alg, t, u,
114+
calculate_error=false,
115+
interp=DiffEqBase.ConstantInterpolation(t, u))
116+
return sol
117+
end
118+
119+
struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
120+
epsilon::Float64 # Error control parameter
121+
nc::Int # Critical reaction threshold
122+
nstiff::Float64 # Stiffness threshold for switching
123+
delta::Float64 # Partial equilibrium threshold
124+
end
125+
126+
SimpleImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) =
127+
SimpleImplicitTauLeaping(epsilon, nc, nstiff, delta)
128+
129+
# Compute stoichiometry matrix from c function
130+
function compute_stoichiometry(c, u, numjumps, p, t)
131+
nu = zeros(Int, length(u), numjumps)
132+
for j in 1:numjumps
133+
counts = zeros(numjumps)
134+
counts[j] = 1
135+
du = similar(u)
136+
c(du, u, p, t, counts, nothing)
137+
nu[:, j] = round.(Int, du)
138+
end
139+
return nu
140+
end
141+
142+
# Detect reversible reaction pairs
143+
function find_reversible_pairs(nu)
144+
pairs = Vector{Tuple{Int,Int}}()
145+
for j in 1:size(nu, 2)
146+
for k in (j+1):size(nu, 2)
147+
if nu[:, j] == -nu[:, k]
148+
push!(pairs, (j, k))
149+
end
150+
end
151+
end
152+
return pairs
153+
end
154+
155+
# Compute g_i (approximation from Cao et al., 2006)
156+
function compute_gi(u, nu, i, rate, rate_cache, p, t)
157+
max_order = 1.0
158+
for j in 1:size(nu, 2)
159+
if abs(nu[i, j]) > 0
160+
rate(rate_cache, u, p, t)
161+
if rate_cache[j] > 0
162+
order = 1.0
163+
if sum(abs.(nu[:, j])) > abs(nu[i, j])
164+
order = 2.0
165+
end
166+
max_order = max(max_order, order)
167+
end
168+
end
169+
end
170+
return max_order
171+
end
172+
173+
# Tau-selection for explicit method (Equation 8)
174+
function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate)
175+
rate(rate_cache, u, p, t)
176+
mu = zeros(length(u))
177+
sigma2 = zeros(length(u))
178+
tau = Inf
179+
for i in 1:length(u)
180+
for j in 1:size(nu, 2)
181+
mu[i] += nu[i, j] * rate_cache[j]
182+
sigma2[i] += nu[i, j]^2 * rate_cache[j]
183+
end
184+
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
185+
bound = max(epsilon * u[i] / gi, 1.0)
186+
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
187+
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
188+
tau = min(tau, mu_term, sigma_term)
189+
end
190+
return max(tau, 1e-10)
191+
end
192+
193+
# Partial equilibrium check (Equation 13)
194+
function is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
195+
a_plus = rate_cache[j_plus]
196+
a_minus = rate_cache[j_minus]
197+
return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus)
198+
end
199+
200+
# Tau-selection for implicit method (Equation 14)
201+
function compute_tau_implicit(u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
202+
rate(rate_cache, u, p, t)
203+
mu = zeros(length(u))
204+
sigma2 = zeros(length(u))
205+
non_equilibrium = trues(size(nu, 2))
206+
for (j_plus, j_minus) in equilibrium_pairs
207+
if is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
208+
non_equilibrium[j_plus] = false
209+
non_equilibrium[j_minus] = false
210+
end
211+
end
212+
tau = Inf
213+
for i in 1:length(u)
214+
for j in 1:size(nu, 2)
215+
if non_equilibrium[j]
216+
mu[i] += nu[i, j] * rate_cache[j]
217+
sigma2[i] += nu[i, j]^2 * rate_cache[j]
218+
end
219+
end
220+
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
221+
bound = max(epsilon * u[i] / gi, 1.0)
222+
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
223+
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
224+
tau = min(tau, mu_term, sigma_term)
225+
end
226+
return max(tau, 1e-10)
227+
end
228+
229+
# Identify critical reactions
230+
function identify_critical_reactions(u, rate_cache, nu, nc)
231+
critical = falses(size(nu, 2))
232+
for j in 1:size(nu, 2)
233+
if rate_cache[j] > 0
234+
Lj = Inf
235+
for i in 1:length(u)
236+
if nu[i, j] < 0
237+
Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j])))
238+
end
239+
end
240+
if Lj < nc
241+
critical[j] = true
242+
end
243+
end
244+
end
245+
return critical
246+
end
247+
248+
# Implicit tau-leaping step using NonlinearSolve
249+
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
250+
# Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
251+
function f(u_new, params)
252+
rate_new = zeros(eltype(u_new), numjumps)
253+
rate(rate_new, u_new, p, t_prev + tau)
254+
residual = u_new - u_prev
255+
for j in 1:numjumps
256+
residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
257+
end
258+
return residual
259+
end
260+
261+
# Initial guess
262+
u_new = copy(u_prev)
263+
264+
# Solve the nonlinear system
265+
prob = NonlinearProblem(f, u_new, nothing)
266+
sol = solve(prob, NewtonRaphson())
267+
268+
# Check for convergence and numerical stability
269+
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
270+
return round.(Int, max.(u_prev, 0.0)) # Revert to previous state
271+
end
272+
273+
return round.(Int, max.(sol.u, 0.0))
274+
end
275+
276+
# Down-shifting condition (Equation 19)
277+
function use_down_shifting(t, tau_im, tau_ex, a0, t_end)
278+
return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
279+
end
280+
281+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing)
282+
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
283+
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
284+
prob = jump_prob.prob
285+
rng = DEFAULT_RNG
286+
(seed !== nothing) && seed!(rng, seed)
287+
288+
rj = jump_prob.regular_jump
289+
rate = rj.rate
290+
numjumps = rj.numjumps
291+
c = rj.c
292+
u0 = copy(prob.u0)
293+
tspan = prob.tspan
294+
p = prob.p
295+
296+
# Initialize storage
297+
rate_cache = zeros(Float64, numjumps)
298+
counts = zeros(Int, numjumps)
299+
du = similar(u0)
300+
u = [copy(u0)]
301+
t = [tspan[1]]
302+
303+
# Algorithm parameters
304+
epsilon = alg.epsilon
305+
nc = alg.nc
306+
nstiff = alg.nstiff
307+
delta = alg.delta
308+
t_end = tspan[2]
309+
310+
# Compute stoichiometry matrix
311+
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
312+
313+
# Detect reversible reaction pairs
314+
equilibrium_pairs = find_reversible_pairs(nu)
315+
316+
# Main simulation loop
317+
while t[end] < t_end
318+
u_prev = u[end]
319+
t_prev = t[end]
320+
321+
# Compute propensities
322+
rate(rate_cache, u_prev, p, t_prev)
323+
324+
# Identify critical reactions
325+
critical = identify_critical_reactions(u_prev, rate_cache, nu, nc)
326+
327+
# Compute tau values
328+
tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
329+
tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta)
330+
331+
# Compute critical propensity sum
332+
ac0 = sum(rate_cache[critical])
333+
tau2 = ac0 > 0 ? randexp(rng) / ac0 : Inf
334+
335+
# Choose method and stepsize
336+
a0 = sum(rate_cache)
337+
use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && !use_down_shifting(t_prev, tau_im, tau_ex, a0, t_end)
338+
tau1 = use_implicit ? tau_im : tau_ex
339+
method = use_implicit ? :implicit : :explicit
340+
341+
# Cap tau to prevent large updates
342+
tau1 = min(tau1, 1.0)
343+
344+
# Check if tau1 is too small
345+
if a0 > 0 && tau1 < 10 / a0
346+
# Use SSA for a few steps
347+
steps = method == :implicit ? 10 : 100
348+
for _ in 1:steps
349+
if t_prev >= t_end
350+
break
351+
end
352+
rate(rate_cache, u_prev, p, t_prev)
353+
a0 = sum(rate_cache)
354+
if a0 == 0
355+
break
356+
end
357+
tau = randexp(rng) / a0
358+
r = rand(rng) * a0
359+
cumsum_rate = 0.0
360+
for j in 1:numjumps
361+
cumsum_rate += rate_cache[j]
362+
if cumsum_rate > r
363+
u_prev += nu[:, j]
364+
break
365+
end
366+
end
367+
t_prev += tau
368+
push!(u, copy(u_prev))
369+
push!(t, t_prev)
370+
end
371+
continue
372+
end
373+
374+
# Choose stepsize and compute firings
375+
if tau2 > tau1
376+
tau = min(tau1, t_end - t_prev)
377+
counts .= 0
378+
for j in 1:numjumps
379+
if !critical[j]
380+
counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0))
381+
end
382+
end
383+
if method == :implicit
384+
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
385+
else
386+
c(du, u_prev, p, t_prev, counts, nothing)
387+
u_new = u_prev + du
388+
end
389+
else
390+
tau = min(tau2, t_end - t_prev)
391+
counts .= 0
392+
if ac0 > 0
393+
r = rand(rng) * ac0
394+
cumsum_rate = 0.0
395+
for j in 1:numjumps
396+
if critical[j]
397+
cumsum_rate += rate_cache[j]
398+
if cumsum_rate > r
399+
counts[j] = 1
400+
break
401+
end
402+
end
403+
end
404+
end
405+
for j in 1:numjumps
406+
if !critical[j]
407+
counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0))
408+
end
409+
end
410+
if method == :implicit && tau > tau_ex
411+
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
412+
else
413+
c(du, u_prev, p, t_prev, counts, nothing)
414+
u_new = u_prev + du
415+
end
416+
end
417+
418+
# Check for negative populations
419+
if any(u_new .< 0)
420+
tau1 /= 2
421+
continue
422+
end
423+
424+
# Update state and time
425+
push!(u, u_new)
426+
push!(t, t_prev + tau)
427+
end
428+
429+
# Build solution
430+
sol = DiffEqBase.build_solution(prob, alg, t, u,
431+
calculate_error = false,
432+
interp = DiffEqBase.ConstantInterpolation(t, u))
433+
end
434+
65435
struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
66436
backend::Backend
67437
cpu_offload::Float64
@@ -74,3 +444,5 @@ end
74444
function EnsembleGPUKernel()
75445
EnsembleGPUKernel(nothing, 0.0)
76446
end
447+
448+
export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping

0 commit comments

Comments
 (0)