From 82306d99393bec2bf3b42f8080f976f7da20b920 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Wed, 3 Sep 2025 13:10:56 -0400 Subject: [PATCH] fix: comprehensive TauLeaping integration with DiscreteProblem for JumpProcesses MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This fixes the test failures in Interface1 and Interface2 by properly implementing TauLeaping algorithm support for DiscreteProblem when used with JumpProcesses. The issue was that TauLeaping algorithms were not properly integrated with DiscreteProblem, causing multiple dispatch failures. The fixes include: 1. **Method dispatch**: Extended __init signature to handle DiscreteProblem with StochasticDiffEqJumpAdaptiveAlgorithm 2. **Algorithm compatibility**: Added alg_compatible method for DiscreteProblem with TauLeaping algorithms 3. **Assertion fixes**: Updated assertions to accept both JumpProblem and DiscreteProblem 4. **Cache initialization**: Handle case where jump_rate_prototype is Nothing for discrete problems 5. **Initial timestep**: Skip noise function access for DiscreteProblem (no 'g' field) 6. **Noise handling**: Handle cases where jump process P is Nothing in perform_step! 7. **Adaptive stepping**: Add guards for adaptive stepping when P is Nothing These changes restore the callback saving functionality from PR #629 while properly fixing the underlying TauLeaping integration issues that were exposed by dependency updates. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- Project.toml | 2 +- src/alg_utils.jl | 3 +++ src/caches/tau_caches.jl | 10 ++++++++-- src/initdt.jl | 5 +++++ src/perform_step/tau_leaping.jl | 12 ++++++++---- src/solve.jl | 8 ++++---- 6 files changed, 29 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index a17f4997d..955c7e28e 100644 --- a/Project.toml +++ b/Project.toml @@ -59,7 +59,7 @@ Random = "1.6" RandomNumbers = "1.5.3" RecursiveArrayTools = "2, 3" Reexport = "0.2, 1.0" -SciMLBase = "2.115" +SciMLBase = "2.116" SciMLOperators = "0.2.9, 0.3, 0.4, 1" SparseArrays = "1.6" StaticArrays = "0.11, 0.12, 1.0" diff --git a/src/alg_utils.jl b/src/alg_utils.jl index e953acdbe..fcc6dadf8 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -345,6 +345,9 @@ function alg_compatible(prob::DiffEqBase.AbstractSDEProblem, end alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::BAOAB) = is_diagonal_noise(prob) +# TauLeaping algorithms are compatible with DiscreteProblem (for JumpProcesses integration) +alg_compatible(prob::DiscreteProblem, alg::StochasticDiffEqJumpAdaptiveAlgorithm) = true + function alg_compatible(prob::JumpProblem, alg::Union{StochasticDiffEqJumpAdaptiveAlgorithm, StochasticDiffEqJumpAlgorithm}) prob.prob isa DiscreteProblem diff --git a/src/caches/tau_caches.jl b/src/caches/tau_caches.jl index 6d7939d91..2ac89e055 100644 --- a/src/caches/tau_caches.jl +++ b/src/caches/tau_caches.jl @@ -20,8 +20,14 @@ function alg_cache(alg::TauLeaping, prob, u, ΔW, ΔZ, p, rate_prototype, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tmp = zero(u) - newrate = zero(jump_rate_prototype) - EEstcache = zero(jump_rate_prototype) + # Handle case where jump_rate_prototype is Nothing (for DiscreteProblem with TauLeaping) + if jump_rate_prototype === nothing + newrate = similar(u, 0) # Empty array for discrete problems without jumps + EEstcache = similar(u, 0) + else + newrate = zero(jump_rate_prototype) + EEstcache = zero(jump_rate_prototype) + end TauLeapingCache(u, uprev, tmp, newrate, EEstcache) end diff --git a/src/initdt.jl b/src/initdt.jl index c85129625..5e5fee9e1 100644 --- a/src/initdt.jl +++ b/src/initdt.jl @@ -5,6 +5,11 @@ function sde_determine_initdt(u0::uType, t::tType, tdir, dtmax, abstol, reltol, return tdir*dtmax/1e6 end + # Handle DiscreteProblem case (no noise function g) + if prob isa DiscreteProblem + return tdir*dtmax/1e6 + end + f = prob.f g = prob.g p = prob.p diff --git a/src/perform_step/tau_leaping.jl b/src/perform_step/tau_leaping.jl index d6f4c3e81..d23b1c4f3 100644 --- a/src/perform_step/tau_leaping.jl +++ b/src/perform_step/tau_leaping.jl @@ -1,9 +1,11 @@ @muladd function perform_step!(integrator, cache::TauLeapingConstantCache) @unpack t, dt, uprev, u, W, p, P, c = integrator - tmp = c(uprev, p, t, P.dW, nothing) + # Handle case where P is Nothing (for pure discrete problems) + dW = P === nothing ? nothing : P.dW + tmp = c(uprev, p, t, dW, nothing) integrator.u = uprev .+ tmp - if integrator.opts.adaptive + if integrator.opts.adaptive && P !== nothing if integrator.alg isa TauLeaping oldrate = P.cache.currate newrate = P.cache.rate(integrator.u, p, t+dt) @@ -22,10 +24,12 @@ end @muladd function perform_step!(integrator, cache::TauLeapingCache) @unpack t, dt, uprev, u, W, p, P, c = integrator @unpack tmp, newrate, EEstcache = cache - c(tmp, uprev, p, t, P.dW, nothing) + # Handle case where P is Nothing (for pure discrete problems) + dW = P === nothing ? nothing : P.dW + c(tmp, uprev, p, t, dW, nothing) @.. u = uprev + tmp - if integrator.opts.adaptive + if integrator.opts.adaptive && P !== nothing if integrator.alg isa TauLeaping oldrate = P.cache.currate P.cache.rate(newrate, u, p, t+dt) diff --git a/src/solve.jl b/src/solve.jl index 841d7476a..0acbf6cd4 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -18,8 +18,8 @@ concrete_prob(prob) = prob concrete_prob(prob::JumpProblem) = prob.prob function DiffEqBase.__init( - _prob::Union{DiffEqBase.AbstractRODEProblem, JumpProblem}, - alg::Union{AbstractRODEAlgorithm, AbstractSDEAlgorithm}, timeseries_init = typeof(_prob.u0)[], + _prob::Union{DiffEqBase.AbstractRODEProblem, JumpProblem, DiscreteProblem}, + alg::Union{AbstractRODEAlgorithm, AbstractSDEAlgorithm, StochasticDiffEqJumpAdaptiveAlgorithm}, timeseries_init = typeof(_prob.u0)[], ts_init = eltype(concrete_prob(_prob).tspan)[], ks_init = nothing, recompile::Type{Val{recompile_flag}} = Val{true}; @@ -527,8 +527,8 @@ function DiffEqBase.__init( elseif W.curt != t error("Starting time in the noise process is not the starting time of the simulation. The noise process should be re-initialized for repeated use") end - else # Only a jump problem - @assert _prob isa JumpProblem + else # Only a jump problem or discrete problem + @assert _prob isa Union{JumpProblem, DiscreteProblem} W = nothing end