Skip to content

Commit 2c9bb29

Browse files
SimpleAdaptiveTauLeaping is done
1 parent 0f1bcf5 commit 2c9bb29

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

src/simple_regular_solve.jl

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,60 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
5050
interp = DiffEqBase.ConstantInterpolation(t, u))
5151
end
5252

53+
struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm
54+
epsilon::Float64 # Error control parameter
55+
end
56+
57+
SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon)
58+
59+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing)
60+
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
61+
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
62+
prob = jump_prob.prob
63+
rng = DEFAULT_RNG
64+
(seed !== nothing) && seed!(rng, seed)
65+
66+
rj = jump_prob.regular_jump
67+
rate = rj.rate
68+
numjumps = rj.numjumps
69+
c = rj.c
70+
u0 = copy(prob.u0)
71+
tspan = prob.tspan
72+
p = prob.p
73+
74+
u = [copy(u0)]
75+
t = [tspan[1]]
76+
rate_cache = zeros(Float64, numjumps)
77+
counts = zeros(Int, numjumps)
78+
du = similar(u0)
79+
t_end = tspan[2]
80+
epsilon = alg.epsilon
81+
82+
nu = compute_stoichiometry(c, u0, numjumps, p, t[1])
83+
84+
while t[end] < t_end
85+
u_prev = u[end]
86+
t_prev = t[end]
87+
rate(rate_cache, u_prev, p, t_prev)
88+
tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate)
89+
tau = min(tau, t_end - t_prev)
90+
counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0))
91+
c(du, u_prev, p, t_prev, counts, nothing)
92+
u_new = u_prev + du
93+
if any(u_new .< 0)
94+
tau /= 2
95+
continue
96+
end
97+
push!(u, u_new)
98+
push!(t, t_prev + tau)
99+
end
100+
101+
sol = DiffEqBase.build_solution(prob, alg, t, u,
102+
calculate_error=false,
103+
interp=DiffEqBase.ConstantInterpolation(t, u))
104+
return sol
105+
end
106+
53107
struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm
54108
epsilon::Float64 # Error control parameter
55109
nc::Int # Critical reaction threshold
@@ -379,4 +433,4 @@ function EnsembleGPUKernel()
379433
EnsembleGPUKernel(nothing, 0.0)
380434
end
381435

382-
export SimpleTauLeaping, EnsembleGPUKernel, SimpleImplicitTauLeaping
436+
export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping

test/regular_jumps.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ let
3838
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
3939
mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims)
4040

41+
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
42+
mean_adaptive = mean(sol.u[i][1,end] for i in 1:Nsims)
43+
4144
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
45+
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
4246
end
4347

4448

@@ -78,5 +82,9 @@ let
7882
sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
7983
mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims)
8084

85+
sol = solve(EnsembleProblem(jump_prob), SimpleAdaptiveTauLeaping(), EnsembleSerial(); trajectories = Nsims)
86+
mean_adaptive = mean(sol.u[i][end,end] for i in 1:Nsims)
87+
8188
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
89+
@test isapprox(mean_simple, mean_adaptive, rtol=0.05)
8290
end

0 commit comments

Comments
 (0)