diff --git a/Project.toml b/Project.toml index a17f4997d..955c7e28e 100644 --- a/Project.toml +++ b/Project.toml @@ -59,7 +59,7 @@ Random = "1.6" RandomNumbers = "1.5.3" RecursiveArrayTools = "2, 3" Reexport = "0.2, 1.0" -SciMLBase = "2.115" +SciMLBase = "2.116" SciMLOperators = "0.2.9, 0.3, 0.4, 1" SparseArrays = "1.6" StaticArrays = "0.11, 0.12, 1.0" diff --git a/src/alg_utils.jl b/src/alg_utils.jl index e953acdbe..fcc6dadf8 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -345,6 +345,9 @@ function alg_compatible(prob::DiffEqBase.AbstractSDEProblem, end alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::BAOAB) = is_diagonal_noise(prob) +# TauLeaping algorithms are compatible with DiscreteProblem (for JumpProcesses integration) +alg_compatible(prob::DiscreteProblem, alg::StochasticDiffEqJumpAdaptiveAlgorithm) = true + function alg_compatible(prob::JumpProblem, alg::Union{StochasticDiffEqJumpAdaptiveAlgorithm, StochasticDiffEqJumpAlgorithm}) prob.prob isa DiscreteProblem diff --git a/src/caches/tau_caches.jl b/src/caches/tau_caches.jl index 6d7939d91..2ac89e055 100644 --- a/src/caches/tau_caches.jl +++ b/src/caches/tau_caches.jl @@ -20,8 +20,14 @@ function alg_cache(alg::TauLeaping, prob, u, ΔW, ΔZ, p, rate_prototype, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tmp = zero(u) - newrate = zero(jump_rate_prototype) - EEstcache = zero(jump_rate_prototype) + # Handle case where jump_rate_prototype is Nothing (for DiscreteProblem with TauLeaping) + if jump_rate_prototype === nothing + newrate = similar(u, 0) # Empty array for discrete problems without jumps + EEstcache = similar(u, 0) + else + newrate = zero(jump_rate_prototype) + EEstcache = zero(jump_rate_prototype) + end TauLeapingCache(u, uprev, tmp, newrate, EEstcache) end diff --git a/src/initdt.jl b/src/initdt.jl index c85129625..5e5fee9e1 100644 --- a/src/initdt.jl +++ b/src/initdt.jl @@ -5,6 +5,11 @@ function sde_determine_initdt(u0::uType, t::tType, tdir, dtmax, abstol, reltol, return tdir*dtmax/1e6 end + # Handle DiscreteProblem case (no noise function g) + if prob isa DiscreteProblem + return tdir*dtmax/1e6 + end + f = prob.f g = prob.g p = prob.p diff --git a/src/perform_step/tau_leaping.jl b/src/perform_step/tau_leaping.jl index d6f4c3e81..d23b1c4f3 100644 --- a/src/perform_step/tau_leaping.jl +++ b/src/perform_step/tau_leaping.jl @@ -1,9 +1,11 @@ @muladd function perform_step!(integrator, cache::TauLeapingConstantCache) @unpack t, dt, uprev, u, W, p, P, c = integrator - tmp = c(uprev, p, t, P.dW, nothing) + # Handle case where P is Nothing (for pure discrete problems) + dW = P === nothing ? nothing : P.dW + tmp = c(uprev, p, t, dW, nothing) integrator.u = uprev .+ tmp - if integrator.opts.adaptive + if integrator.opts.adaptive && P !== nothing if integrator.alg isa TauLeaping oldrate = P.cache.currate newrate = P.cache.rate(integrator.u, p, t+dt) @@ -22,10 +24,12 @@ end @muladd function perform_step!(integrator, cache::TauLeapingCache) @unpack t, dt, uprev, u, W, p, P, c = integrator @unpack tmp, newrate, EEstcache = cache - c(tmp, uprev, p, t, P.dW, nothing) + # Handle case where P is Nothing (for pure discrete problems) + dW = P === nothing ? nothing : P.dW + c(tmp, uprev, p, t, dW, nothing) @.. u = uprev + tmp - if integrator.opts.adaptive + if integrator.opts.adaptive && P !== nothing if integrator.alg isa TauLeaping oldrate = P.cache.currate P.cache.rate(newrate, u, p, t+dt) diff --git a/src/solve.jl b/src/solve.jl index 841d7476a..0acbf6cd4 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -18,8 +18,8 @@ concrete_prob(prob) = prob concrete_prob(prob::JumpProblem) = prob.prob function DiffEqBase.__init( - _prob::Union{DiffEqBase.AbstractRODEProblem, JumpProblem}, - alg::Union{AbstractRODEAlgorithm, AbstractSDEAlgorithm}, timeseries_init = typeof(_prob.u0)[], + _prob::Union{DiffEqBase.AbstractRODEProblem, JumpProblem, DiscreteProblem}, + alg::Union{AbstractRODEAlgorithm, AbstractSDEAlgorithm, StochasticDiffEqJumpAdaptiveAlgorithm}, timeseries_init = typeof(_prob.u0)[], ts_init = eltype(concrete_prob(_prob).tspan)[], ks_init = nothing, recompile::Type{Val{recompile_flag}} = Val{true}; @@ -527,8 +527,8 @@ function DiffEqBase.__init( elseif W.curt != t error("Starting time in the noise process is not the starting time of the simulation. The noise process should be re-initialized for repeated use") end - else # Only a jump problem - @assert _prob isa JumpProblem + else # Only a jump problem or discrete problem + @assert _prob isa Union{JumpProblem, DiscreteProblem} W = nothing end