Skip to content

Commit bd4a452

Browse files
refactor
1 parent 31562bb commit bd4a452

File tree

3 files changed

+7
-279
lines changed

3 files changed

+7
-279
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1919
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2020
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2121
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
22-
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2322
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2423
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
2524
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

src/simple_regular_solve.jl

Lines changed: 7 additions & 270 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ end
6868

6969
SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)
7070

71-
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing)
71+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
72+
seed = nothing,
73+
dtmin = 1e-10)
7274
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
7375
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
7476
prob = jump_prob.prob
@@ -97,7 +99,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
9799
u_prev = u[end]
98100
t_prev = t[end]
99101
rate(rate_cache, u_prev, p, t_prev)
100-
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
102+
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, dtmin)
101103
tau = min(tau, t_end - t_prev)
102104
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
103105
c(du, u_prev, p, t_prev, counts, nothing)
@@ -116,16 +118,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
116118
return sol
117119
end
118120

119-
struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
120-
epsilon::Float64 # Error control parameter
121-
nc::Int # Critical reaction threshold
122-
nstiff::Float64 # Stiffness threshold for switching
123-
delta::Float64 # Partial equilibrium threshold
124-
end
125-
126-
SimpleImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) =
127-
SimpleImplicitTauLeaping(epsilon, nc, nstiff, delta)
128-
129121
# Compute stoichiometry matrix from c function
130122
function compute_stoichiometry(c, u, numjumps, p, t)
131123
nu = zeros(Int, length(u), numjumps)
@@ -139,19 +131,6 @@ function compute_stoichiometry(c, u, numjumps, p, t)
139131
return nu
140132
end
141133

142-
# Detect reversible reaction pairs
143-
function find_reversible_pairs(nu)
144-
pairs = Vector{Tuple{Int,Int}}()
145-
for j in 1:size(nu, 2)
146-
for k in (j+1):size(nu, 2)
147-
if nu[:, j] == -nu[:, k]
148-
push!(pairs, (j, k))
149-
end
150-
end
151-
end
152-
return pairs
153-
end
154-
155134
# Compute g_i (approximation from Cao et al., 2006)
156135
function compute_gi(u, nu, i, rate, rate_cache, p, t)
157136
max_order = 1.0
@@ -171,7 +150,7 @@ function compute_gi(u, nu, i, rate, rate_cache, p, t)
171150
end
172151

173152
# Tau-selection for explicit method (Equation 8)
174-
function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate)
153+
function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate, dtmin)
175154
rate(rate_cache, u, p, t)
176155
mu = zeros(length(u))
177156
sigma2 = zeros(length(u))
@@ -187,249 +166,7 @@ function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate)
187166
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
188167
tau = min(tau, mu_term, sigma_term)
189168
end
190-
return max(tau, 1e-10)
191-
end
192-
193-
# Partial equilibrium check (Equation 13)
194-
function is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
195-
a_plus = rate_cache[j_plus]
196-
a_minus = rate_cache[j_minus]
197-
return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus)
198-
end
199-
200-
# Tau-selection for implicit method (Equation 14)
201-
function compute_tau_implicit(u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
202-
rate(rate_cache, u, p, t)
203-
mu = zeros(length(u))
204-
sigma2 = zeros(length(u))
205-
non_equilibrium = trues(size(nu, 2))
206-
for (j_plus, j_minus) in equilibrium_pairs
207-
if is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
208-
non_equilibrium[j_plus] = false
209-
non_equilibrium[j_minus] = false
210-
end
211-
end
212-
tau = Inf
213-
for i in 1:length(u)
214-
for j in 1:size(nu, 2)
215-
if non_equilibrium[j]
216-
mu[i] += nu[i, j] * rate_cache[j]
217-
sigma2[i] += nu[i, j]^2 * rate_cache[j]
218-
end
219-
end
220-
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
221-
bound = max(epsilon * u[i] / gi, 1.0)
222-
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
223-
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
224-
tau = min(tau, mu_term, sigma_term)
225-
end
226-
return max(tau, 1e-10)
227-
end
228-
229-
# Identify critical reactions
230-
function identify_critical_reactions(u, rate_cache, nu, nc)
231-
critical = falses(size(nu, 2))
232-
for j in 1:size(nu, 2)
233-
if rate_cache[j] > 0
234-
Lj = Inf
235-
for i in 1:length(u)
236-
if nu[i, j] < 0
237-
Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j])))
238-
end
239-
end
240-
if Lj < nc
241-
critical[j] = true
242-
end
243-
end
244-
end
245-
return critical
246-
end
247-
248-
# Implicit tau-leaping step using NonlinearSolve
249-
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
250-
# Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
251-
function f(u_new, params)
252-
rate_new = zeros(eltype(u_new), numjumps)
253-
rate(rate_new, u_new, p, t_prev + tau)
254-
residual = u_new - u_prev
255-
for j in 1:numjumps
256-
residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
257-
end
258-
return residual
259-
end
260-
261-
# Initial guess
262-
u_new = copy(u_prev)
263-
264-
# Solve the nonlinear system
265-
prob = NonlinearProblem(f, u_new, nothing)
266-
sol = solve(prob, SimpleNewtonRaphson())
267-
268-
# Check for convergence and numerical stability
269-
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
270-
return round.(Int, max.(u_prev, 0.0)) # Revert to previous state
271-
end
272-
273-
return round.(Int, max.(sol.u, 0.0))
274-
end
275-
276-
# Down-shifting condition (Equation 19)
277-
function use_down_shifting(t, tau_im, tau_ex, a0, t_end)
278-
return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
279-
end
280-
281-
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing)
282-
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
283-
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
284-
prob = jump_prob.prob
285-
rng = DEFAULT_RNG
286-
(seed !== nothing) && seed!(rng, seed)
287-
288-
rj = jump_prob.regular_jump
289-
rate = rj.rate
290-
numjumps = rj.numjumps
291-
c = rj.c
292-
u0 = copy(prob.u0)
293-
tspan = prob.tspan
294-
p = prob.p
295-
296-
# Initialize storage
297-
rate_cache = zeros(Float64, numjumps)
298-
counts = zeros(Int, numjumps)
299-
du = similar(u0)
300-
u = [copy(u0)]
301-
t = [tspan[1]]
302-
303-
# Algorithm parameters
304-
epsilon = alg.epsilon
305-
nc = alg.nc
306-
nstiff = alg.nstiff
307-
delta = alg.delta
308-
t_end = tspan[2]
309-
310-
# Compute stoichiometry matrix
311-
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
312-
313-
# Detect reversible reaction pairs
314-
equilibrium_pairs = find_reversible_pairs(nu)
315-
316-
# Main simulation loop
317-
while t[end] < t_end
318-
u_prev = u[end]
319-
t_prev = t[end]
320-
321-
# Compute propensities
322-
rate(rate_cache, u_prev, p, t_prev)
323-
324-
# Identify critical reactions
325-
critical = identify_critical_reactions(u_prev, rate_cache, nu, nc)
326-
327-
# Compute tau values
328-
tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
329-
tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta)
330-
331-
# Compute critical propensity sum
332-
ac0 = sum(rate_cache[critical])
333-
tau2 = ac0 > 0 ? randexp(rng) / ac0 : Inf
334-
335-
# Choose method and stepsize
336-
a0 = sum(rate_cache)
337-
use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && !use_down_shifting(t_prev, tau_im, tau_ex, a0, t_end)
338-
tau1 = use_implicit ? tau_im : tau_ex
339-
method = use_implicit ? :implicit : :explicit
340-
341-
# Cap tau to prevent large updates
342-
tau1 = min(tau1, 1.0)
343-
344-
# Check if tau1 is too small
345-
if a0 > 0 && tau1 < 10 / a0
346-
# Use SSA for a few steps
347-
steps = method == :implicit ? 10 : 100
348-
for _ in 1:steps
349-
if t_prev >= t_end
350-
break
351-
end
352-
rate(rate_cache, u_prev, p, t_prev)
353-
a0 = sum(rate_cache)
354-
if a0 == 0
355-
break
356-
end
357-
tau = randexp(rng) / a0
358-
r = rand(rng) * a0
359-
cumsum_rate = 0.0
360-
for j in 1:numjumps
361-
cumsum_rate += rate_cache[j]
362-
if cumsum_rate > r
363-
u_prev += nu[:, j]
364-
break
365-
end
366-
end
367-
t_prev += tau
368-
push!(u, copy(u_prev))
369-
push!(t, t_prev)
370-
end
371-
continue
372-
end
373-
374-
# Choose stepsize and compute firings
375-
if tau2 > tau1
376-
tau = min(tau1, t_end - t_prev)
377-
counts .= 0
378-
for j in 1:numjumps
379-
if !critical[j]
380-
counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0))
381-
end
382-
end
383-
if method == :implicit
384-
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
385-
else
386-
c(du, u_prev, p, t_prev, counts, nothing)
387-
u_new = u_prev + du
388-
end
389-
else
390-
tau = min(tau2, t_end - t_prev)
391-
counts .= 0
392-
if ac0 > 0
393-
r = rand(rng) * ac0
394-
cumsum_rate = 0.0
395-
for j in 1:numjumps
396-
if critical[j]
397-
cumsum_rate += rate_cache[j]
398-
if cumsum_rate > r
399-
counts[j] = 1
400-
break
401-
end
402-
end
403-
end
404-
end
405-
for j in 1:numjumps
406-
if !critical[j]
407-
counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0))
408-
end
409-
end
410-
if method == :implicit && tau > tau_ex
411-
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
412-
else
413-
c(du, u_prev, p, t_prev, counts, nothing)
414-
u_new = u_prev + du
415-
end
416-
end
417-
418-
# Check for negative populations
419-
if any(u_new .< 0)
420-
tau1 /= 2
421-
continue
422-
end
423-
424-
# Update state and time
425-
push!(u, u_new)
426-
push!(t, t_prev + tau)
427-
end
428-
429-
# Build solution
430-
sol = DiffEqBase.build_solution(prob, alg, t, u,
431-
calculate_error = false,
432-
interp = DiffEqBase.ConstantInterpolation(t, u))
169+
return max(tau, dtmin)
433170
end
434171

435172
struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
@@ -445,4 +182,4 @@ function EnsembleGPUKernel()
445182
EnsembleGPUKernel(nothing, 0.0)
446183
end
447184

448-
export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping
185+
export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping

test/regular_jumps.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,9 @@ let
3535
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
3636
mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims)
3737

38-
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
39-
mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims)
40-
4138
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
4239
mean_adaptive = mean(sol.u[i][1,end] for i in 1:Nsims)
4340

44-
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
4541
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
4642
end
4743

@@ -79,12 +75,8 @@ let
7975
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
8076
mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims)
8177

82-
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
83-
mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims)
84-
8578
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
8679
mean_adaptive = mean(sol.u[i][end,end] for i in 1:Nsims)
8780

88-
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
8981
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
9082
end

0 commit comments

Comments
 (0)