Skip to content

Commit 1466917

Browse files
Implemented SimpleAdaptiveTauLeaping and SimpleImplicitTauLeaping
1 parent 48c6bf9 commit 1466917

File tree

4 files changed

+451
-30
lines changed

4 files changed

+451
-30
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1313
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
16+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
1617
PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab"
1718
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1819
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"

src/JumpProcesses.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Reexport
44
@reexport using DiffEqBase
55

66
using LinearAlgebra, Markdown, DocStringExtensions
7+
using NonlinearSolve
78
using DataStructures, PoissonRandom, Random, ArrayInterface
89
using FunctionWrappers, UnPack
910
using Graphs

src/simple_regular_solve.jl

Lines changed: 371 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,376 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
5050
interp = DiffEqBase.ConstantInterpolation(t, u))
5151
end
5252

53+
struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm
54+
epsilon::Float64 # Error control parameter
55+
end
56+
57+
SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)
58+
59+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing)
60+
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
61+
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
62+
prob = jump_prob.prob
63+
rng = DEFAULT_RNG
64+
(seed !== nothing) && seed!(rng, seed)
65+
66+
rj = jump_prob.regular_jump
67+
rate = rj.rate
68+
numjumps = rj.numjumps
69+
c = rj.c
70+
u0 = copy(prob.u0)
71+
tspan = prob.tspan
72+
p = prob.p
73+
74+
u = [copy(u0)]
75+
t = [tspan[1]]
76+
rate_cache = zeros(Float64, numjumps)
77+
counts = zeros(Int, numjumps)
78+
du = similar(u0)
79+
t_end = tspan[2]
80+
epsilon = alg.epsilon
81+
82+
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
83+
84+
while t[end] < t_end
85+
u_prev = u[end]
86+
t_prev = t[end]
87+
rate(rate_cache, u_prev, p, t_prev)
88+
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
89+
tau = min(tau, t_end - t_prev)
90+
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
91+
c(du, u_prev, p, t_prev, counts, nothing)
92+
u_new = u_prev + du
93+
if any(u_new .< 0)
94+
tau /= 2
95+
continue
96+
end
97+
push!(u, u_new)
98+
push!(t, t_prev + tau)
99+
end
100+
101+
sol = DiffEqBase.build_solution(prob, alg, t, u,
102+
calculate_error=false,
103+
interp=DiffEqBase.ConstantInterpolation(t, u))
104+
return sol
105+
end
106+
107+
struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
108+
epsilon::Float64 # Error control parameter
109+
nc::Int # Critical reaction threshold
110+
nstiff::Float64 # Stiffness threshold for switching
111+
delta::Float64 # Partial equilibrium threshold
112+
end
113+
114+
SimpleImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) =
115+
SimpleImplicitTauLeaping(epsilon, nc, nstiff, delta)
116+
117+
# Compute stoichiometry matrix from c function
118+
function compute_stoichiometry(c, u, numjumps, p, t)
119+
nu = zeros(Int, length(u), numjumps)
120+
for j in 1:numjumps
121+
counts = zeros(numjumps)
122+
counts[j] = 1
123+
du = similar(u)
124+
c(du, u, p, t, counts, nothing)
125+
nu[:, j] = round.(Int, du)
126+
end
127+
return nu
128+
end
129+
130+
# Detect reversible reaction pairs
131+
function find_reversible_pairs(nu)
132+
pairs = Vector{Tuple{Int,Int}}()
133+
for j in 1:size(nu, 2)
134+
for k in (j+1):size(nu, 2)
135+
if nu[:, j] == -nu[:, k]
136+
push!(pairs, (j, k))
137+
end
138+
end
139+
end
140+
return pairs
141+
end
142+
143+
# Compute g_i (approximation from Cao et al., 2006)
144+
function compute_gi(u, nu, i, rate, rate_cache, p, t)
145+
max_order = 1.0
146+
for j in 1:size(nu, 2)
147+
if abs(nu[i, j]) > 0
148+
rate(rate_cache, u, p, t)
149+
if rate_cache[j] > 0
150+
order = 1.0
151+
if sum(abs.(nu[:, j])) > abs(nu[i, j])
152+
order = 2.0
153+
end
154+
max_order = max(max_order, order)
155+
end
156+
end
157+
end
158+
return max_order
159+
end
160+
161+
# Tau-selection for explicit method (Equation 8)
162+
function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate)
163+
rate(rate_cache, u, p, t)
164+
mu = zeros(length(u))
165+
sigma2 = zeros(length(u))
166+
tau = Inf
167+
for i in 1:length(u)
168+
for j in 1:size(nu, 2)
169+
mu[i] += nu[i, j] * rate_cache[j]
170+
sigma2[i] += nu[i, j]^2 * rate_cache[j]
171+
end
172+
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
173+
bound = max(epsilon * u[i] / gi, 1.0)
174+
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
175+
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
176+
tau = min(tau, mu_term, sigma_term)
177+
end
178+
return max(tau, 1e-10)
179+
end
180+
181+
# Partial equilibrium check (Equation 13)
182+
function is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
183+
a_plus = rate_cache[j_plus]
184+
a_minus = rate_cache[j_minus]
185+
return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus)
186+
end
187+
188+
# Tau-selection for implicit method (Equation 14)
189+
function compute_tau_implicit(u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
190+
rate(rate_cache, u, p, t)
191+
mu = zeros(length(u))
192+
sigma2 = zeros(length(u))
193+
non_equilibrium = trues(size(nu, 2))
194+
for (j_plus, j_minus) in equilibrium_pairs
195+
if is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
196+
non_equilibrium[j_plus] = false
197+
non_equilibrium[j_minus] = false
198+
end
199+
end
200+
tau = Inf
201+
for i in 1:length(u)
202+
for j in 1:size(nu, 2)
203+
if non_equilibrium[j]
204+
mu[i] += nu[i, j] * rate_cache[j]
205+
sigma2[i] += nu[i, j]^2 * rate_cache[j]
206+
end
207+
end
208+
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
209+
bound = max(epsilon * u[i] / gi, 1.0)
210+
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
211+
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
212+
tau = min(tau, mu_term, sigma_term)
213+
end
214+
return max(tau, 1e-10)
215+
end
216+
217+
# Identify critical reactions
218+
function identify_critical_reactions(u, rate_cache, nu, nc)
219+
critical = falses(size(nu, 2))
220+
for j in 1:size(nu, 2)
221+
if rate_cache[j] > 0
222+
Lj = Inf
223+
for i in 1:length(u)
224+
if nu[i, j] < 0
225+
Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j])))
226+
end
227+
end
228+
if Lj < nc
229+
critical[j] = true
230+
end
231+
end
232+
end
233+
return critical
234+
end
235+
236+
# Implicit tau-leaping step using NonlinearSolve
237+
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
238+
# Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
239+
function f(u_new, params)
240+
rate_new = zeros(eltype(u_new), numjumps)
241+
rate(rate_new, u_new, p, t_prev + tau)
242+
residual = u_new - u_prev
243+
for j in 1:numjumps
244+
residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
245+
end
246+
return residual
247+
end
248+
249+
# Initial guess
250+
u_new = copy(u_prev)
251+
252+
# Solve the nonlinear system
253+
prob = NonlinearProblem(f, u_new, nothing)
254+
sol = solve(prob, NewtonRaphson())
255+
256+
# Check for convergence and numerical stability
257+
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
258+
return round.(Int, max.(u_prev, 0.0)) # Revert to previous state
259+
end
260+
261+
return round.(Int, max.(sol.u, 0.0))
262+
end
263+
264+
# Down-shifting condition (Equation 19)
265+
function use_down_shifting(t, tau_im, tau_ex, a0, t_end)
266+
return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
267+
end
268+
269+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing)
270+
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
271+
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
272+
prob = jump_prob.prob
273+
rng = DEFAULT_RNG
274+
(seed !== nothing) && seed!(rng, seed)
275+
276+
rj = jump_prob.regular_jump
277+
rate = rj.rate
278+
numjumps = rj.numjumps
279+
c = rj.c
280+
u0 = copy(prob.u0)
281+
tspan = prob.tspan
282+
p = prob.p
283+
284+
# Initialize storage
285+
rate_cache = zeros(Float64, numjumps)
286+
counts = zeros(Int, numjumps)
287+
du = similar(u0)
288+
u = [copy(u0)]
289+
t = [tspan[1]]
290+
291+
# Algorithm parameters
292+
epsilon = alg.epsilon
293+
nc = alg.nc
294+
nstiff = alg.nstiff
295+
delta = alg.delta
296+
t_end = tspan[2]
297+
298+
# Compute stoichiometry matrix
299+
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
300+
301+
# Detect reversible reaction pairs
302+
equilibrium_pairs = find_reversible_pairs(nu)
303+
304+
# Main simulation loop
305+
while t[end] < t_end
306+
u_prev = u[end]
307+
t_prev = t[end]
308+
309+
# Compute propensities
310+
rate(rate_cache, u_prev, p, t_prev)
311+
312+
# Identify critical reactions
313+
critical = identify_critical_reactions(u_prev, rate_cache, nu, nc)
314+
315+
# Compute tau values
316+
tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
317+
tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta)
318+
319+
# Compute critical propensity sum
320+
ac0 = sum(rate_cache[critical])
321+
tau2 = ac0 > 0 ? randexp(rng) / ac0 : Inf
322+
323+
# Choose method and stepsize
324+
a0 = sum(rate_cache)
325+
use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && !use_down_shifting(t_prev, tau_im, tau_ex, a0, t_end)
326+
tau1 = use_implicit ? tau_im : tau_ex
327+
method = use_implicit ? :implicit : :explicit
328+
329+
# Cap tau to prevent large updates
330+
tau1 = min(tau1, 1.0)
331+
332+
# Check if tau1 is too small
333+
if a0 > 0 && tau1 < 10 / a0
334+
# Use SSA for a few steps
335+
steps = method == :implicit ? 10 : 100
336+
for _ in 1:steps
337+
if t_prev >= t_end
338+
break
339+
end
340+
rate(rate_cache, u_prev, p, t_prev)
341+
a0 = sum(rate_cache)
342+
if a0 == 0
343+
break
344+
end
345+
tau = randexp(rng) / a0
346+
r = rand(rng) * a0
347+
cumsum_rate = 0.0
348+
for j in 1:numjumps
349+
cumsum_rate += rate_cache[j]
350+
if cumsum_rate > r
351+
u_prev += nu[:, j]
352+
break
353+
end
354+
end
355+
t_prev += tau
356+
push!(u, copy(u_prev))
357+
push!(t, t_prev)
358+
end
359+
continue
360+
end
361+
362+
# Choose stepsize and compute firings
363+
if tau2 > tau1
364+
tau = min(tau1, t_end - t_prev)
365+
counts .= 0
366+
for j in 1:numjumps
367+
if !critical[j]
368+
counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0))
369+
end
370+
end
371+
if method == :implicit
372+
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
373+
else
374+
c(du, u_prev, p, t_prev, counts, nothing)
375+
u_new = u_prev + du
376+
end
377+
else
378+
tau = min(tau2, t_end - t_prev)
379+
counts .= 0
380+
if ac0 > 0
381+
r = rand(rng) * ac0
382+
cumsum_rate = 0.0
383+
for j in 1:numjumps
384+
if critical[j]
385+
cumsum_rate += rate_cache[j]
386+
if cumsum_rate > r
387+
counts[j] = 1
388+
break
389+
end
390+
end
391+
end
392+
end
393+
for j in 1:numjumps
394+
if !critical[j]
395+
counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0))
396+
end
397+
end
398+
if method == :implicit && tau > tau_ex
399+
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
400+
else
401+
c(du, u_prev, p, t_prev, counts, nothing)
402+
u_new = u_prev + du
403+
end
404+
end
405+
406+
# Check for negative populations
407+
if any(u_new .< 0)
408+
tau1 /= 2
409+
continue
410+
end
411+
412+
# Update state and time
413+
push!(u, u_new)
414+
push!(t, t_prev + tau)
415+
end
416+
417+
# Build solution
418+
sol = DiffEqBase.build_solution(prob, alg, t, u,
419+
calculate_error = false,
420+
interp = DiffEqBase.ConstantInterpolation(t, u))
421+
end
422+
53423
struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
54424
backend::Backend
55425
cpu_offload::Float64
@@ -63,4 +433,4 @@ function EnsembleGPUKernel()
63433
EnsembleGPUKernel(nothing, 0.0)
64434
end
65435

66-
export SimpleTauLeaping, EnsembleGPUKernel
436+
export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping

0 commit comments

Comments
 (0)