Skip to content

Commit 41c6b2c

Browse files
ImplicitTauLeaping setup done for jump problem solver
1 parent 3601205 commit 41c6b2c

File tree

2 files changed

+334
-10
lines changed

2 files changed

+334
-10
lines changed

src/simple_regular_solve.jl

Lines changed: 295 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,301 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
5050
interp = DiffEqBase.ConstantInterpolation(t, u))
5151
end
5252

53+
# Define ImplicitTauLeaping algorithm
54+
struct ImplicitTauLeaping <: DiffEqBase.DEAlgorithm
55+
epsilon::Float64 # Error control parameter
56+
nc::Int # Critical reaction threshold
57+
nstiff::Int # Stiffness threshold multiplier
58+
delta::Float64 # Partial equilibrium threshold
59+
end
60+
61+
ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100, delta=0.05) = ImplicitTauLeaping(epsilon, nc, nstiff, delta)
62+
63+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping;
64+
seed = nothing,
65+
dt = error("dt is required for ImplicitTauLeaping."),
66+
kwargs...)
67+
68+
# Boilerplate from SimpleTauLeaping
69+
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
70+
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
71+
prob = jump_prob.prob
72+
rng = DEFAULT_RNG
73+
(seed !== nothing) && seed!(rng, seed)
74+
75+
rj = jump_prob.regular_jump
76+
rate = rj.rate # rate(out, u, p, t)
77+
numjumps = rj.numjumps
78+
c = rj.c # c(dc, u, p, t, counts, mark)
79+
reversible_pairs = get(kwargs, :reversible_pairs, Tuple{Int,Int}[])
80+
81+
if !isnothing(rj.mark_dist)
82+
error("Mark distributions are currently not supported in ImplicitTauLeaping")
83+
end
84+
85+
# Initialize state and buffers
86+
u0 = copy(prob.u0)
87+
p = prob.p
88+
tspan = prob.tspan
89+
state_dim = length(u0)
90+
dt = Float64(dt)
91+
92+
# Compute stoichiometry matrix
93+
nu = zeros(Int, state_dim, numjumps)
94+
for j in 1:numjumps
95+
dc = zeros(state_dim)
96+
c(dc, u0, p, 0.0, [i == j ? 1 : 0 for i in 1:numjumps], nothing)
97+
nu[:, j] = dc
98+
end
99+
100+
# Initialize solution arrays
101+
n = Int((tspan[2] - tspan[1]) / dt) + 1
102+
u = Vector{typeof(u0)}(undef, n)
103+
u[1] = u0
104+
t = range(tspan[1], tspan[2], length=n)
105+
106+
# Buffers for iteration
107+
current_u = copy(u0)
108+
rate_cache = zeros(Float64, numjumps)
109+
counts = zeros(Float64, numjumps)
110+
local_dc = zeros(Float64, state_dim)
111+
I_rs = 1:state_dim
112+
g = ones(state_dim) # Scaling factor for tau-leaping
113+
114+
function compute_tau_explicit(u, rate, nu, num_jumps, epsilon, g, J_ncr, I_rs, p)
115+
rate_cache = zeros(eltype(u), num_jumps)
116+
rate(rate_cache, u, p, 0.0)
117+
118+
mu = zeros(eltype(u), length(u))
119+
sigma2 = zeros(eltype(u), length(u))
120+
121+
for i in I_rs
122+
mu[i] = sum(nu[i,j] * rate_cache[j] for j in J_ncr; init=0.0)
123+
sigma2[i] = sum(nu[i,j]^2 * rate_cache[j] for j in J_ncr; init=0.0)
124+
end
125+
126+
tau = Inf
127+
for i in I_rs
128+
denom_mu = max(epsilon * u[i] / g[i], 1.0)
129+
denom_sigma = denom_mu^2
130+
if abs(mu[i]) > 0
131+
tau = min(tau, denom_mu / abs(mu[i]))
132+
end
133+
if sigma2[i] > 0
134+
tau = min(tau, denom_sigma / sigma2[i])
135+
end
136+
end
137+
return tau
138+
end
139+
140+
function compute_tau_implicit(u, rate, nu, num_jumps, epsilon, g, J_necr, I_rs, p)
141+
rate_cache = zeros(eltype(u), num_jumps)
142+
rate(rate_cache, u, p, 0.0)
143+
144+
mu = zeros(eltype(u), length(u))
145+
sigma2 = zeros(eltype(u), length(u))
146+
147+
for i in I_rs
148+
mu[i] = sum(nu[i,j] * rate_cache[j] for j in J_necr; init=0.0)
149+
sigma2[i] = sum(nu[i,j]^2 * rate_cache[j] for j in J_necr; init=0.0)
150+
end
151+
152+
tau = Inf
153+
for i in I_rs
154+
denom_mu = max(epsilon * u[i] / g[i], 1.0)
155+
denom_sigma = denom_mu^2
156+
if abs(mu[i]) > 0
157+
tau = min(tau, denom_mu / abs(mu[i]))
158+
end
159+
if sigma2[i] > 0
160+
tau = min(tau, denom_sigma / sigma2[i])
161+
end
162+
end
163+
return isinf(tau) ? 1e6 : tau
164+
end
165+
166+
function identify_critical_reactions(u, nu, num_jumps, nc)
167+
L = zeros(Int, num_jumps)
168+
J_critical = Int[]
169+
170+
for j in 1:num_jumps
171+
min_val = Inf
172+
for i in 1:length(u)
173+
if nu[i,j] < 0
174+
val = floor(u[i] / abs(nu[i,j]))
175+
min_val = min(min_val, val)
176+
end
177+
end
178+
L[j] = min_val == Inf ? typemax(Int) : Int(min_val)
179+
if L[j] < nc
180+
push!(J_critical, j)
181+
end
182+
end
183+
J_ncr = setdiff(1:num_jumps, J_critical)
184+
return J_critical, J_ncr
185+
end
186+
187+
function check_partial_equilibrium(rate_cache, reversible_pairs, delta)
188+
J_equilibrium = Int[]
189+
for (j_plus, j_minus) in reversible_pairs
190+
a_plus = rate_cache[j_plus]
191+
a_minus = rate_cache[j_minus]
192+
if abs(a_plus - a_minus) <= delta * min(a_plus, a_minus)
193+
push!(J_equilibrium, j_plus, j_minus)
194+
end
195+
end
196+
return J_equilibrium
197+
end
198+
199+
function newton_solve!(x_new, x, rate, nu, rate_cache, counts, p, t, tau, max_iter=10, tol=1e-6)
200+
state_dim = length(x)
201+
num_jumps = length(counts)
202+
203+
for iter in 1:max_iter
204+
rate(rate_cache, x_new, p, t)
205+
rate_cache .*= tau
206+
207+
residual = x_new .- x
208+
for j in 1:num_jumps
209+
residual .-= nu[:,j] * (counts[j] - rate_cache[j] + tau * rate_cache[j])
210+
end
211+
212+
if norm(residual) < tol
213+
break
214+
end
215+
216+
J = zeros(eltype(x), state_dim, state_dim)
217+
for j in 1:num_jumps
218+
for i in 1:state_dim
219+
for k in 1:state_dim
220+
J[i,k] += nu[i,j] * nu[k,j] * rate_cache[j]
221+
end
222+
end
223+
end
224+
J = I - tau * J
225+
226+
delta_x = J \ residual
227+
x_new .-= delta_x
228+
229+
if norm(delta_x) < tol
230+
break
231+
end
232+
end
233+
return x_new
234+
end
235+
236+
# Main solver loop
237+
for i in 2:n
238+
tprev = t[i - 1]
239+
J_critical, J_ncr = identify_critical_reactions(current_u, nu, numjumps, alg.nc)
240+
241+
rate(rate_cache, current_u, p, tprev)
242+
a0_critical = sum(rate_cache[j] for j in J_critical; init=0.0)
243+
244+
J_equilibrium = check_partial_equilibrium(rate_cache, reversible_pairs, alg.delta)
245+
J_necr = setdiff(J_ncr, J_equilibrium)
246+
247+
tau_ex = compute_tau_explicit(current_u, rate, nu, numjumps, alg.epsilon, g, J_ncr, I_rs, p)
248+
tau_im = compute_tau_implicit(current_u, rate, nu, numjumps, alg.epsilon, g, J_necr, I_rs, p)
249+
250+
tau2 = a0_critical > 0 ? -log(rand(rng)) / a0_critical : Inf
251+
use_implicit = tau_im > alg.nstiff * tau_ex
252+
tau1 = use_implicit ? tau_im : tau_ex
253+
254+
if tau1 < 10 / sum(rate_cache; init=0.0)
255+
a0 = sum(rate_cache; init=0.0)
256+
if a0 > 0
257+
tau = -log(rand(rng)) / a0
258+
r = rand(rng) * a0
259+
cumsum_a = 0.0
260+
jc = 1
261+
for k in 1:numjumps
262+
cumsum_a += rate_cache[k]
263+
if cumsum_a > r
264+
jc = k
265+
break
266+
end
267+
end
268+
current_u .+= nu[:,jc]
269+
else
270+
tau = dt
271+
end
272+
else
273+
tau = min(tau1, tau2, dt)
274+
if tau == tau2
275+
if a0_critical > 0
276+
r = rand(rng) * a0_critical
277+
cumsum_a = 0.0
278+
jc = !isempty(J_critical) ? J_critical[1] : 1
279+
for k in J_critical
280+
cumsum_a += rate_cache[k]
281+
if cumsum_a > r
282+
jc = k
283+
break
284+
end
285+
end
286+
counts .= 0
287+
counts[jc] = 1
288+
if use_implicit && tau > tau_ex
289+
for k in J_ncr
290+
counts[k] = pois_rand(rng, rate_cache[k] * tau)
291+
end
292+
c(local_dc, current_u, p, tprev, counts, nothing)
293+
current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau)
294+
else
295+
for k in J_ncr
296+
counts[k] = pois_rand(rng, rate_cache[k] * tau)
297+
end
298+
c(local_dc, current_u, p, tprev, counts, nothing)
299+
current_u .+= local_dc
300+
end
301+
else
302+
tau = tau1
303+
if use_implicit
304+
for k in 1:numjumps
305+
counts[k] = pois_rand(rng, rate_cache[k] * tau)
306+
end
307+
c(local_dc, current_u, p, tprev, counts, nothing)
308+
current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau)
309+
else
310+
for k in 1:numjumps
311+
counts[k] = pois_rand(rng, rate_cache[k] * tau)
312+
end
313+
c(local_dc, current_u, p, tprev, counts, nothing)
314+
current_u .+= local_dc
315+
end
316+
end
317+
else
318+
counts .= 0
319+
if use_implicit
320+
for k in J_ncr
321+
counts[k] = pois_rand(rng, rate_cache[k] * tau)
322+
end
323+
c(local_dc, current_u, p, tprev, counts, nothing)
324+
current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau)
325+
else
326+
for k in J_ncr
327+
counts[k] = pois_rand(rng, rate_cache[k] * tau)
328+
end
329+
c(local_dc, current_u, p, tprev, counts, nothing)
330+
current_u .+= local_dc
331+
end
332+
end
333+
end
334+
335+
if any(current_u .< 0)
336+
tau1 /= 2
337+
continue
338+
end
339+
340+
u[i] = copy(current_u)
341+
end
342+
343+
sol = DiffEqBase.build_solution(prob, alg, t, u,
344+
calculate_error = false,
345+
interp = DiffEqBase.ConstantInterpolation(t, u))
346+
end
347+
53348
struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
54349
backend::Backend
55350
cpu_offload::Float64
@@ -63,14 +358,4 @@ function EnsembleGPUKernel()
63358
EnsembleGPUKernel(nothing, 0.0)
64359
end
65360

66-
# Define ImplicitTauLeaping algorithm
67-
struct ImplicitTauLeaping <: DiffEqBase.DEAlgorithm
68-
epsilon::Float64 # Error control parameter
69-
nc::Int # Critical reaction threshold
70-
nstiff::Int # Stiffness threshold multiplier
71-
delta::Float64 # Partial equilibrium threshold
72-
end
73-
74-
ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100, delta=0.05) = ImplicitTauLeaping(epsilon, nc, nstiff, delta)
75-
76361
export SimpleTauLeaping, EnsembleGPUKernel, ImplicitTauLeaping

test/regular_jumps.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,42 @@ jumps = JumpSet(rj)
3939
prob = DiscreteProblem([999, 1, 0], (0.0, 250.0))
4040
jump_prob = JumpProblem(prob, Direct(), rj; rng = rng)
4141
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
42+
43+
# Parameters
44+
c1 = 1.0 # S1 -> 0
45+
c2 = 10.0 # S1 + S1 <- S2
46+
c3 = 1000.0 # S1 + S1 -> S2
47+
c4 = 0.1 # S2 -> S3
48+
p = (c1, c2, c3, c4)
49+
50+
# Propensity functions
51+
regular_rate_itu = (out, u, p, t) -> begin
52+
out[1] = p[1] * u[1] # S1 -> 0
53+
out[2] = p[2] * u[2] # S1 + S1 <- S2
54+
out[3] = p[3] * u[1] * (u[1] - 1) / 2 # S1 + S1 -> S2
55+
out[4] = p[4] * u[2] # S2 -> S3
56+
end
57+
58+
# State change function
59+
regular_c_itu = (dc, u, p, t, counts, mark) -> begin
60+
dc .= 0.0
61+
dc[1] = -counts[1] - 2 * counts[3] + 2 * counts[2] # S1: -decay - 2*forward + 2*backward
62+
dc[2] = counts[3] - counts[2] - counts[4] # S2: +forward - backward - decay
63+
dc[3] = counts[4] # S3: +decay
64+
end
65+
66+
# Initial condition
67+
u0 = [10000.0, 0.0, 0.0] # S1, S2, S3
68+
tspan = (0.0, 4.0)
69+
70+
# Define reversible reaction pairs (R2 and R3 are reversible: S1 + S1 <-> S2)
71+
reversible_pairs = [(2, 3)]
72+
73+
# Create JumpProblem with proper parameter passing
74+
prob_disc = DiscreteProblem(u0, tspan, p)
75+
rj = RegularJump(regular_rate_itu, regular_c_itu, 4)
76+
jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
77+
78+
# Solve using ImplicitTauLeaping
79+
alg = ImplicitTauLeaping(epsilon=0.05, nc=10, nstiff=100, delta=0.05)
80+
sol = solve(jump_prob, alg; dt=0.01, reversible_pairs=reversible_pairs)

0 commit comments

Comments
 (0)