Skip to content

Commit 82306d9

Browse files
fix: comprehensive TauLeaping integration with DiscreteProblem for JumpProcesses
This fixes the test failures in Interface1 and Interface2 by properly implementing TauLeaping algorithm support for DiscreteProblem when used with JumpProcesses. The issue was that TauLeaping algorithms were not properly integrated with DiscreteProblem, causing multiple dispatch failures. The fixes include: 1. **Method dispatch**: Extended __init signature to handle DiscreteProblem with StochasticDiffEqJumpAdaptiveAlgorithm 2. **Algorithm compatibility**: Added alg_compatible method for DiscreteProblem with TauLeaping algorithms 3. **Assertion fixes**: Updated assertions to accept both JumpProblem and DiscreteProblem 4. **Cache initialization**: Handle case where jump_rate_prototype is Nothing for discrete problems 5. **Initial timestep**: Skip noise function access for DiscreteProblem (no 'g' field) 6. **Noise handling**: Handle cases where jump process P is Nothing in perform_step! 7. **Adaptive stepping**: Add guards for adaptive stepping when P is Nothing These changes restore the callback saving functionality from PR SciML#629 while properly fixing the underlying TauLeaping integration issues that were exposed by dependency updates. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 0f7073f commit 82306d9

File tree

6 files changed

+29
-11
lines changed

6 files changed

+29
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Random = "1.6"
5959
RandomNumbers = "1.5.3"
6060
RecursiveArrayTools = "2, 3"
6161
Reexport = "0.2, 1.0"
62-
SciMLBase = "2.115"
62+
SciMLBase = "2.116"
6363
SciMLOperators = "0.2.9, 0.3, 0.4, 1"
6464
SparseArrays = "1.6"
6565
StaticArrays = "0.11, 0.12, 1.0"

src/alg_utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ function alg_compatible(prob::DiffEqBase.AbstractSDEProblem,
345345
end
346346
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::BAOAB) = is_diagonal_noise(prob)
347347

348+
# TauLeaping algorithms are compatible with DiscreteProblem (for JumpProcesses integration)
349+
alg_compatible(prob::DiscreteProblem, alg::StochasticDiffEqJumpAdaptiveAlgorithm) = true
350+
348351
function alg_compatible(prob::JumpProblem,
349352
alg::Union{StochasticDiffEqJumpAdaptiveAlgorithm, StochasticDiffEqJumpAlgorithm})
350353
prob.prob isa DiscreteProblem

src/caches/tau_caches.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,14 @@ function alg_cache(alg::TauLeaping, prob, u, ΔW, ΔZ, p, rate_prototype,
2020
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt,
2121
::Type{Val{true}}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
2222
tmp = zero(u)
23-
newrate = zero(jump_rate_prototype)
24-
EEstcache = zero(jump_rate_prototype)
23+
# Handle case where jump_rate_prototype is Nothing (for DiscreteProblem with TauLeaping)
24+
if jump_rate_prototype === nothing
25+
newrate = similar(u, 0) # Empty array for discrete problems without jumps
26+
EEstcache = similar(u, 0)
27+
else
28+
newrate = zero(jump_rate_prototype)
29+
EEstcache = zero(jump_rate_prototype)
30+
end
2531
TauLeapingCache(u, uprev, tmp, newrate, EEstcache)
2632
end
2733

src/initdt.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ function sde_determine_initdt(u0::uType, t::tType, tdir, dtmax, abstol, reltol,
55
return tdir*dtmax/1e6
66
end
77

8+
# Handle DiscreteProblem case (no noise function g)
9+
if prob isa DiscreteProblem
10+
return tdir*dtmax/1e6
11+
end
12+
813
f = prob.f
914
g = prob.g
1015
p = prob.p

src/perform_step/tau_leaping.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
@muladd function perform_step!(integrator, cache::TauLeapingConstantCache)
22
@unpack t, dt, uprev, u, W, p, P, c = integrator
3-
tmp = c(uprev, p, t, P.dW, nothing)
3+
# Handle case where P is Nothing (for pure discrete problems)
4+
dW = P === nothing ? nothing : P.dW
5+
tmp = c(uprev, p, t, dW, nothing)
46
integrator.u = uprev .+ tmp
57

6-
if integrator.opts.adaptive
8+
if integrator.opts.adaptive && P !== nothing
79
if integrator.alg isa TauLeaping
810
oldrate = P.cache.currate
911
newrate = P.cache.rate(integrator.u, p, t+dt)
@@ -22,10 +24,12 @@ end
2224
@muladd function perform_step!(integrator, cache::TauLeapingCache)
2325
@unpack t, dt, uprev, u, W, p, P, c = integrator
2426
@unpack tmp, newrate, EEstcache = cache
25-
c(tmp, uprev, p, t, P.dW, nothing)
27+
# Handle case where P is Nothing (for pure discrete problems)
28+
dW = P === nothing ? nothing : P.dW
29+
c(tmp, uprev, p, t, dW, nothing)
2630
@.. u = uprev + tmp
2731

28-
if integrator.opts.adaptive
32+
if integrator.opts.adaptive && P !== nothing
2933
if integrator.alg isa TauLeaping
3034
oldrate = P.cache.currate
3135
P.cache.rate(newrate, u, p, t+dt)

src/solve.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ concrete_prob(prob) = prob
1818
concrete_prob(prob::JumpProblem) = prob.prob
1919

2020
function DiffEqBase.__init(
21-
_prob::Union{DiffEqBase.AbstractRODEProblem, JumpProblem},
22-
alg::Union{AbstractRODEAlgorithm, AbstractSDEAlgorithm}, timeseries_init = typeof(_prob.u0)[],
21+
_prob::Union{DiffEqBase.AbstractRODEProblem, JumpProblem, DiscreteProblem},
22+
alg::Union{AbstractRODEAlgorithm, AbstractSDEAlgorithm, StochasticDiffEqJumpAdaptiveAlgorithm}, timeseries_init = typeof(_prob.u0)[],
2323
ts_init = eltype(concrete_prob(_prob).tspan)[],
2424
ks_init = nothing,
2525
recompile::Type{Val{recompile_flag}} = Val{true};
@@ -527,8 +527,8 @@ function DiffEqBase.__init(
527527
elseif W.curt != t
528528
error("Starting time in the noise process is not the starting time of the simulation. The noise process should be re-initialized for repeated use")
529529
end
530-
else # Only a jump problem
531-
@assert _prob isa JumpProblem
530+
else # Only a jump problem or discrete problem
531+
@assert _prob isa Union{JumpProblem, DiscreteProblem}
532532
W = nothing
533533
end
534534

0 commit comments

Comments
 (0)