Skip to content

Commit 69664d8

Browse files
nonlinearsolver is implemented
1 parent 6b03d4f commit 69664d8

File tree

2 files changed

+90
-95
lines changed

2 files changed

+90
-95
lines changed

src/simple_regular_solve.jl

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ end
6262
ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) =
6363
ImplicitTauLeaping(epsilon, nc, nstiff, delta)
6464

65-
function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed=nothing)
65+
function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed = nothing)
6666
# Boilerplate setup
6767
@assert isempty(jump_prob.jump_callback.continuous_callbacks)
6868
@assert isempty(jump_prob.jump_callback.discrete_callbacks)
@@ -155,7 +155,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed=
155155
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
156156
tau = min(tau, mu_term, sigma_term)
157157
end
158-
return max(tau, 1e-10) # Prevent zero or negative tau
158+
return max(tau, 1e-10)
159159
end
160160

161161
# Partial equilibrium check (Equation 13)
@@ -191,7 +191,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed=
191191
sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf
192192
tau = min(tau, mu_term, sigma_term)
193193
end
194-
return max(tau, 1e-10) # Prevent zero or negative tau
194+
return max(tau, 1e-10)
195195
end
196196

197197
# Identify critical reactions
@@ -213,46 +213,32 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed=
213213
return critical
214214
end
215215

216-
# Implicit tau-leaping step with Newton's method
216+
# Implicit tau-leaping step using NonlinearSolve
217217
function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p)
218-
u_new = copy(u_prev)
219-
rate_new = zeros(numjumps)
220-
tol = 1e-6
221-
max_iter = 50
222-
for iter in 1:max_iter
218+
# Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0
219+
function f(u_new, params)
220+
rate_new = similar(rate_cache, eltype(u_new))
223221
rate(rate_new, u_new, p, t_prev + tau)
224222
residual = u_new - u_prev
225223
for j in 1:numjumps
226224
residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j])
227225
end
228-
if norm(residual) < tol
229-
break
230-
end
231-
# Improved Jacobian approximation
232-
J = Diagonal(ones(length(u_new)))
233-
for j in 1:numjumps
234-
for i in 1:length(u_new)
235-
if rate_new[j] > 0 && u_new[i] > 0
236-
# Scale derivative to prevent overflow
237-
J[i, i] += nu[i, j] * tau * min(rate_new[j] / u_new[i], 1e3)
238-
end
239-
end
240-
end
241-
# Check for singular or ill-conditioned Jacobian
242-
if any(abs.(diag(J)) .< 1e-10)
243-
return u_prev # Revert to previous state if Jacobian is singular
244-
end
245-
delta_u = J \ residual
246-
# Limit step size to prevent overflow
247-
delta_u = clamp.(delta_u, -1e3, 1e3)
248-
u_new -= delta_u
249-
u_new = max.(u_new, 0.0)
250-
# Check for numerical overflow
251-
if any(isnan.(u_new)) || any(isinf.(u_new))
252-
return u_prev
253-
end
226+
return residual
254227
end
255-
return round.(Int, max.(u_new, 0.0))
228+
229+
# Initial guess
230+
u_new = copy(u_prev)
231+
232+
# Solve the nonlinear system
233+
prob = NonlinearProblem(f, u_new, nothing)
234+
sol = solve(prob, NewtonRaphson())
235+
236+
# Check for convergence and numerical stability
237+
if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u))
238+
return round.(Int, max.(u_prev, 0.0)) # Revert to previous state
239+
end
240+
241+
return round.(Int, max.(sol.u, 0.0))
256242
end
257243

258244
# Down-shifting condition (Equation 19)
@@ -375,9 +361,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed=
375361

376362
# Build solution
377363
sol = DiffEqBase.build_solution(prob, alg, t, u,
378-
calculate_error=false,
379-
interp=DiffEqBase.ConstantInterpolation(t, u))
380-
return sol
364+
calculate_error = false,
365+
interp = DiffEqBase.ConstantInterpolation(t, u))
381366
end
382367

383368
struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm

test/regular_jumps.jl

Lines changed: 66 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,82 @@
11
using JumpProcesses, DiffEqBase
2-
using Test, LinearAlgebra
2+
using Test, LinearAlgebra, Statistics
33
using StableRNGs
44
rng = StableRNG(12345)
55

6-
function regular_rate(out, u, p, t)
7-
out[1] = (0.1 / 1000.0) * u[1] * u[2]
8-
out[2] = 0.01u[2]
9-
end
6+
Nsims = 8000
107

11-
function regular_c(dc, u, p, t, mark)
12-
dc[1, 1] = -1
13-
dc[2, 1] = 1
14-
dc[2, 2] = -1
15-
dc[3, 2] = 1
16-
end
8+
# SIR model with influx
9+
let
10+
β = 0.1 / 1000.0
11+
ν = 0.01
12+
influx_rate = 1.0
13+
p = (β, ν, influx_rate)
1714

18-
dc = zeros(3, 2)
15+
regular_rate = (out, u, p, t) -> begin
16+
out[1] = p[1] * u[1] * u[2] # β*S*I (infection)
17+
out[2] = p[2] * u[2] # ν*I (recovery)
18+
out[3] = p[3] # influx_rate
19+
end
1920

20-
rj = RegularJump(regular_rate, regular_c, dc; constant_c = true)
21-
jumps = JumpSet(rj)
21+
regular_c = (dc, u, p, t, counts, mark) -> begin
22+
dc .= 0.0
23+
dc[1] = -counts[1] + counts[3] # S: -infection + influx
24+
dc[2] = counts[1] - counts[2] # I: +infection - recovery
25+
dc[3] = counts[2] # R: +recovery
26+
end
2227

23-
prob = DiscreteProblem([999.0, 1.0, 0.0], (0.0, 250.0))
24-
jump_prob = JumpProblem(prob, Direct(), rj; rng = rng)
25-
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
28+
u0 = [999.0, 10.0, 0.0] # S, I, R
29+
tspan = (0.0, 250.0)
2630

27-
const _dc = zeros(3, 2)
28-
dc[1, 1] = -1
29-
dc[2, 1] = 1
30-
dc[2, 2] = -1
31-
dc[3, 2] = 1
31+
prob_disc = DiscreteProblem(u0, tspan, p)
32+
rj = RegularJump(regular_rate, regular_c, 3)
33+
jump_prob = JumpProblem(prob_disc, Direct(), rj)
3234

33-
function regular_c(du, u, p, t, counts, mark)
34-
mul!(du, dc, counts)
35-
end
35+
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
36+
mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims)
3637

37-
rj = RegularJump(regular_rate, regular_c, 2)
38-
jumps = JumpSet(rj)
39-
prob = DiscreteProblem([999, 1, 0], (0.0, 250.0))
40-
jump_prob = JumpProblem(prob, Direct(), rj; rng = rng)
41-
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
42-
43-
# Decaying Dimerization Model
44-
# Parameters
45-
c1 = 1.0 # S1 -> 0
46-
c2 = 10.0 # S1 + S1 <- S2
47-
c3 = 1000.0 # S1 + S1 -> S2
48-
c4 = 0.1 # S2 -> S3
49-
p_dim = (c1, c2, c3, c4)
50-
51-
regular_rate_dim = (out, u, p, t) -> begin
52-
out[1] = p[1] * u[1] # S1 -> 0
53-
out[2] = p[2] * u[2] # S1 + S1 <- S2
54-
out[3] = p[3] * u[1] * (u[1] - 1) / 2 # S1 + S1 -> S2
55-
out[4] = p[4] * u[2] # S2 -> S3
56-
end
38+
sol = solve(EnsembleProblem(jump_prob), ImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
39+
mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims)
5740

58-
regular_c_dim = (du, u, p, t, counts, mark) -> begin
59-
du .= 0.0
60-
du[1] = -counts[1] - 2 * counts[3] + 2 * counts[2] # S1: -decay - 2*forward + 2*backward
61-
du[2] = counts[3] - counts[2] - counts[4] # S2: +forward - backward - decay
62-
du[3] = counts[4] # S3: +decay
41+
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
6342
end
6443

65-
u0_dim = [10000.0, 0.0, 0.0] # S1, S2, S3
66-
tspan_dim = (0.0, 4.0)
6744

68-
prob_disc_dim = DiscreteProblem(u0_dim, tspan_dim, p_dim)
69-
rj_dim = RegularJump(regular_rate_dim, regular_c_dim, 4)
70-
jump_prob_dim = JumpProblem(prob_disc_dim, Direct(), rj_dim; rng=rng)
45+
# SEIR model with exposed compartment
46+
let
47+
β = 0.3 / 1000.0
48+
σ = 0.2
49+
ν = 0.01
50+
p = (β, σ, ν)
51+
52+
regular_rate = (out, u, p, t) -> begin
53+
out[1] = p[1] * u[1] * u[3] # β*S*I (infection)
54+
out[2] = p[2] * u[2] # σ*E (progression)
55+
out[3] = p[3] * u[3] # ν*I (recovery)
56+
end
57+
58+
regular_c = (dc, u, p, t, counts, mark) -> begin
59+
dc .= 0.0
60+
dc[1] = -counts[1] # S: -infection
61+
dc[2] = counts[1] - counts[2] # E: +infection - progression
62+
dc[3] = counts[2] - counts[3] # I: +progression - recovery
63+
dc[4] = counts[3] # R: +recovery
64+
end
7165

72-
sol = solve(jump_prob_dim, ImplicitTauLeaping())
66+
# Initial state
67+
u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R
68+
tspan = (0.0, 250.0)
69+
70+
# Create JumpProblem
71+
prob_disc = DiscreteProblem(u0, tspan, p)
72+
rj = RegularJump(regular_rate, regular_c, 3)
73+
jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345))
74+
75+
sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0)
76+
mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims)
77+
78+
sol = solve(EnsembleProblem(jump_prob), ImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims)
79+
mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims)
80+
81+
@test isapprox(mean_simple, mean_implicit, rtol=0.05)
82+
end

0 commit comments

Comments
 (0)