Skip to content

Commit 64ccbcd

Browse files
refactor
1 parent e3ea56a commit 64ccbcd

File tree

4 files changed

+7
-280
lines changed

4 files changed

+7
-280
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
1919
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2020
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2121
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
22-
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2322
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2423
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
2524
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

src/JumpProcesses.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ import DataStructures: update!
3232
import Graphs: neighbors, outdegree
3333
import RecursiveArrayTools: recursivecopy!
3434
import SymbolicIndexingInterface as SII
35-
using SimpleNonlinearSolve
3635

3736
# Import additional types and functions from DiffEqBase and SciMLBase
3837
using DiffEqBase: DiffEqBase, CallbackSet, ContinuousCallback, DAEFunction,

src/simple_regular_solve.jl

Lines changed: 7 additions & 270 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ end
5656

5757
SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)
5858

59-
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing)
59+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
60+
seed = nothing,
61+
dtmin = 1e-10)
6062
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
6163
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
6264
prob = jump_prob.prob
@@ -85,7 +87,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
8587
u_prev = u[end]
8688
t_prev = t[end]
8789
rate(rate_cache, u_prev, p, t_prev)
88-
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
90+
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, dtmin)
8991
tau = min(tau, t_end - t_prev)
9092
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
9193
c(du, u_prev, p, t_prev, counts, nothing)
@@ -104,16 +106,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
104106
return sol
105107
end
106108

107-
struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
108-
epsilon::Float64 # Error control parameter
109-
nc::Int # Critical reaction threshold
110-
nstiff::Float64 # Stiffness threshold for switching
111-
delta::Float64 # Partial equilibrium threshold
112-
end
113-
114-
SimpleImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) =
115-
SimpleImplicitTauLeaping(epsilon, nc, nstiff, delta)
116-
117109
# Compute stoichiometry matrix from c function
118110
function compute_stoichiometry(c, u, numjumps, p, t)
119111
nu = zeros(Int, length(u), numjumps)
@@ -127,19 +119,6 @@ function compute_stoichiometry(c, u, numjumps, p, t)
127119
return nu
128120
end
129121

130-
# Detect reversible reaction pairs
131-
function find_reversible_pairs(nu)
132-
pairs = Vector{Tuple{Int,Int}}()
133-
for j in 1:size(nu, 2)
134-
for k in (j+1):size(nu, 2)
135-
if nu[:, j] == -nu[:, k]
136-
push!(pairs, (j, k))
137-
end
138-
end
139-
end
140-
return pairs
141-
end
142-
143122
# Compute g_i (approximation from Cao et al., 2006)
144123
function compute_gi(u, nu, i, rate, rate_cache, p, t)
145124
max_order = 1.0
@@ -159,7 +138,7 @@ function compute_gi(u, nu, i, rate, rate_cache, p, t)
159138
end
160139

161140
# Tau-selection for explicit method (Equation 8)
162-
function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate)
141+
function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate, dtmin)
163142
rate(rate_cache, u, p, t)
164143
mu = zeros(length(u))
165144
sigma2 = zeros(length(u))
@@ -175,249 +154,7 @@ function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate)
175154
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
176155
tau = min(tau, mu_term, sigma_term)
177156
end
178-
return max(tau, 1e-10)
179-
end
180-
181-
# Partial equilibrium check (Equation 13)
182-
function is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
183-
a_plus = rate_cache[j_plus]
184-
a_minus = rate_cache[j_minus]
185-
return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus)
186-
end
187-
188-
# Tau-selection for implicit method (Equation 14)
189-
function compute_tau_implicit(u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta)
190-
rate(rate_cache, u, p, t)
191-
mu = zeros(length(u))
192-
sigma2 = zeros(length(u))
193-
non_equilibrium = trues(size(nu, 2))
194-
for (j_plus, j_minus) in equilibrium_pairs
195-
if is_partial_equilibrium(rate_cache, j_plus, j_minus, delta)
196-
non_equilibrium[j_plus] = false
197-
non_equilibrium[j_minus] = false
198-
end
199-
end
200-
tau = Inf
201-
for i in 1:length(u)
202-
for j in 1:size(nu, 2)
203-
if non_equilibrium[j]
204-
mu[i] += nu[i, j] * rate_cache[j]
205-
sigma2[i] += nu[i, j]^2 * rate_cache[j]
206-
end
207-
end
208-
gi = compute_gi(u, nu, i, rate, rate_cache, p, t)
209-
bound = max(epsilon * u[i] / gi, 1.0)
210-
mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf
211-
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
212-
tau = min(tau, mu_term, sigma_term)
213-
end
214-
return max(tau, 1e-10)
215-
end
216-
217-
# Identify critical reactions
218-
function identify_critical_reactions(u, rate_cache, nu, nc)
219-
critical = falses(size(nu, 2))
220-
for j in 1:size(nu, 2)
221-
if rate_cache[j] > 0
222-
Lj = Inf
223-
for i in 1:length(u)
224-
if nu[i, j] < 0
225-
Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j])))
226-
end
227-
end
228-
if Lj < nc
229-
critical[j] = true
230-
end
231-
end
232-
end
233-
return critical
234-
end
235-
236-
# Implicit tau-leaping step using NonlinearSolve
237-
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
238-
# Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
239-
function f(u_new, params)
240-
rate_new = zeros(eltype(u_new), numjumps)
241-
rate(rate_new, u_new, p, t_prev + tau)
242-
residual = u_new - u_prev
243-
for j in 1:numjumps
244-
residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
245-
end
246-
return residual
247-
end
248-
249-
# Initial guess
250-
u_new = copy(u_prev)
251-
252-
# Solve the nonlinear system
253-
prob = NonlinearProblem(f, u_new, nothing)
254-
sol = solve(prob, SimpleNewtonRaphson())
255-
256-
# Check for convergence and numerical stability
257-
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
258-
return round.(Int, max.(u_prev, 0.0)) # Revert to previous state
259-
end
260-
261-
return round.(Int, max.(sol.u, 0.0))
262-
end
263-
264-
# Down-shifting condition (Equation 19)
265-
function use_down_shifting(t, tau_im, tau_ex, a0, t_end)
266-
return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0)
267-
end
268-
269-
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing)
270-
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
271-
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
272-
prob = jump_prob.prob
273-
rng = DEFAULT_RNG
274-
(seed !== nothing) && seed!(rng, seed)
275-
276-
rj = jump_prob.regular_jump
277-
rate = rj.rate
278-
numjumps = rj.numjumps
279-
c = rj.c
280-
u0 = copy(prob.u0)
281-
tspan = prob.tspan
282-
p = prob.p
283-
284-
# Initialize storage
285-
rate_cache = zeros(Float64, numjumps)
286-
counts = zeros(Int, numjumps)
287-
du = similar(u0)
288-
u = [copy(u0)]
289-
t = [tspan[1]]
290-
291-
# Algorithm parameters
292-
epsilon = alg.epsilon
293-
nc = alg.nc
294-
nstiff = alg.nstiff
295-
delta = alg.delta
296-
t_end = tspan[2]
297-
298-
# Compute stoichiometry matrix
299-
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
300-
301-
# Detect reversible reaction pairs
302-
equilibrium_pairs = find_reversible_pairs(nu)
303-
304-
# Main simulation loop
305-
while t[end] < t_end
306-
u_prev = u[end]
307-
t_prev = t[end]
308-
309-
# Compute propensities
310-
rate(rate_cache, u_prev, p, t_prev)
311-
312-
# Identify critical reactions
313-
critical = identify_critical_reactions(u_prev, rate_cache, nu, nc)
314-
315-
# Compute tau values
316-
tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
317-
tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta)
318-
319-
# Compute critical propensity sum
320-
ac0 = sum(rate_cache[critical])
321-
tau2 = ac0 > 0 ? randexp(rng) / ac0 : Inf
322-
323-
# Choose method and stepsize
324-
a0 = sum(rate_cache)
325-
use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && !use_down_shifting(t_prev, tau_im, tau_ex, a0, t_end)
326-
tau1 = use_implicit ? tau_im : tau_ex
327-
method = use_implicit ? :implicit : :explicit
328-
329-
# Cap tau to prevent large updates
330-
tau1 = min(tau1, 1.0)
331-
332-
# Check if tau1 is too small
333-
if a0 > 0 && tau1 < 10 / a0
334-
# Use SSA for a few steps
335-
steps = method == :implicit ? 10 : 100
336-
for _ in 1:steps
337-
if t_prev >= t_end
338-
break
339-
end
340-
rate(rate_cache, u_prev, p, t_prev)
341-
a0 = sum(rate_cache)
342-
if a0 == 0
343-
break
344-
end
345-
tau = randexp(rng) / a0
346-
r = rand(rng) * a0
347-
cumsum_rate = 0.0
348-
for j in 1:numjumps
349-
cumsum_rate += rate_cache[j]
350-
if cumsum_rate > r
351-
u_prev += nu[:, j]
352-
break
353-
end
354-
end
355-
t_prev += tau
356-
push!(u, copy(u_prev))
357-
push!(t, t_prev)
358-
end
359-
continue
360-
end
361-
362-
# Choose stepsize and compute firings
363-
if tau2 > tau1
364-
tau = min(tau1, t_end - t_prev)
365-
counts .= 0
366-
for j in 1:numjumps
367-
if !critical[j]
368-
counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0))
369-
end
370-
end
371-
if method == :implicit
372-
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
373-
else
374-
c(du, u_prev, p, t_prev, counts, nothing)
375-
u_new = u_prev + du
376-
end
377-
else
378-
tau = min(tau2, t_end - t_prev)
379-
counts .= 0
380-
if ac0 > 0
381-
r = rand(rng) * ac0
382-
cumsum_rate = 0.0
383-
for j in 1:numjumps
384-
if critical[j]
385-
cumsum_rate += rate_cache[j]
386-
if cumsum_rate > r
387-
counts[j] = 1
388-
break
389-
end
390-
end
391-
end
392-
end
393-
for j in 1:numjumps
394-
if !critical[j]
395-
counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0))
396-
end
397-
end
398-
if method == :implicit && tau > tau_ex
399-
u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps)
400-
else
401-
c(du, u_prev, p, t_prev, counts, nothing)
402-
u_new = u_prev + du
403-
end
404-
end
405-
406-
# Check for negative populations
407-
if any(u_new .< 0)
408-
tau1 /= 2
409-
continue
410-
end
411-
412-
# Update state and time
413-
push!(u, u_new)
414-
push!(t, t_prev + tau)
415-
end
416-
417-
# Build solution
418-
sol = DiffEqBase.build_solution(prob, alg, t, u,
419-
calculate_error = false,
420-
interp = DiffEqBase.ConstantInterpolation(t, u))
157+
return max(tau, dtmin)
421158
end
422159

423160
struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
@@ -433,4 +170,4 @@ function EnsembleGPUKernel()
433170
EnsembleGPUKernel(nothing, 0.0)
434171
end
435172

436-
export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping
173+
export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping

test/regular_jumps.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,9 @@ let
3535
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
3636
mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims)
3737

38-
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
39-
mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims)
40-
4138
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
4239
mean_adaptive = mean(sol.u[i][1,end] for i in 1:Nsims)
4340

44-
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
4541
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
4642
end
4743

@@ -79,12 +75,8 @@ let
7975
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
8076
mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims)
8177

82-
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
83-
mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims)
84-
8578
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
8679
mean_adaptive = mean(sol.u[i][end,end] for i in 1:Nsims)
8780

88-
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
8981
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
9082
end

0 commit comments

Comments
 (0)