Skip to content
Merged
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand Down Expand Up @@ -48,11 +49,12 @@ Logging = "1.6"
MuladdMacro = "0.2.1"
NLsolve = "4"
OrdinaryDiffEq = "6.87"
OrdinaryDiffEqCore = "1.12.1"
Random = "1.6"
RandomNumbers = "1.5.3"
RecursiveArrayTools = "2, 3"
Reexport = "0.2, 1.0"
SciMLBase = "2.59.2"
SciMLBase = "2.65"
SciMLOperators = "0.2.9, 0.3"
SparseArrays = "1.6"
SparseDiffTools = "2"
Expand Down
5 changes: 4 additions & 1 deletion src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ using DocStringExtensions
import DiffEqBase: step!, initialize!, DEAlgorithm,
AbstractSDEAlgorithm, AbstractRODEAlgorithm, DEIntegrator, AbstractDiffEqInterpolation,
DECache, AbstractSDEIntegrator, AbstractRODEIntegrator, AbstractContinuousCallback,
Tableau
Tableau, AbstractSDDEIntegrator

# Integrator Interface
import DiffEqBase: resize!,deleteat!,addat!,full_cache,user_cache,u_cache,du_cache,
Expand All @@ -58,6 +58,8 @@ using OrdinaryDiffEq: nlsolvefail, isnewton, set_new_W!, get_W, _vec, _reshape

using OrdinaryDiffEq: NLSolver

import OrdinaryDiffEqCore

if isdefined(OrdinaryDiffEq,:FastConvergence)
using OrdinaryDiffEq:
FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence
Expand Down Expand Up @@ -119,6 +121,7 @@ end
include("cache_utils.jl")
include("integrators/integrator_interface.jl")
include("iterator_interface.jl")
include("initialize_dae.jl")
include("solve.jl")
include("initdt.jl")
include("perform_step/low_order.jl")
Expand Down
3 changes: 3 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ SciMLBase.forwarddiffs_model(alg::Union{StochasticDiffEqNewtonAlgorithm,
StochasticDiffEqNewtonAdaptiveAlgorithm,StochasticDiffEqJumpNewtonAdaptiveAlgorithm,
StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm}) = OrdinaryDiffEq.alg_autodiff(alg)

# Required for initialization, because ODECore._initialize_dae! calls it during
# OverrideInit
OrdinaryDiffEqCore.has_autodiff(::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm,StochasticDiffEqJumpAlgorithm}) = false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't true though, some of the implicit methods have the same autodiff args as the ode solver


isadaptive(alg::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm}) = false
isadaptive(alg::Union{StochasticDiffEqAdaptiveAlgorithm,StochasticDiffEqRODEAdaptiveAlgorithm,StochasticDiffEqJumpAdaptiveAlgorithm,StochasticDiffEqJumpDiffusionAdaptiveAlgorithm}) = true
Expand Down
13 changes: 13 additions & 0 deletions src/initialize_dae.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
struct SDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end

function DiffEqBase.initialize_dae!(integrator::Union{AbstractSDEIntegrator, AbstractSDDEIntegrator}, initializealg = integrator.initializealg)
OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob)))
end

function OrdinaryDiffEqCore._initialize_dae!(integrator::Union{AbstractSDEIntegrator, AbstractSDDEIntegrator}, prob, ::SDEDefaultInit, isinplace)
if SciMLBase.has_initializeprob(prob.f)
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace)
elseif SciMLBase.__has_mass_matrix(prob.f)
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace)
end
end
3 changes: 2 additions & 1 deletion src/integrators/type.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenType,tTypeNoUnits,uEltypeNoUnits,randType,randType2,rateType,solType,cacheType,F4,F5,F6,OType,noiseType,EventErrorType,CallbackCacheType,RCs} <: AbstractSDEIntegrator{algType,IIP,uType,tType}
mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenType,tTypeNoUnits,uEltypeNoUnits,randType,randType2,rateType,solType,cacheType,F4,F5,F6,OType,noiseType,EventErrorType,CallbackCacheType,RCs,IA} <: AbstractSDEIntegrator{algType,IIP,uType,tType}
f::F4
g::F5
c::F6
Expand Down Expand Up @@ -43,4 +43,5 @@ mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenTy
qold::tTypeNoUnits
q11::tTypeNoUnits
stats::DiffEqBase.Stats
initializealg::IA
end
7 changes: 5 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ function DiffEqBase.__init(
userdata=nothing,
initialize_integrator=true,
seed = UInt64(0), alias_u0=false, alias_jumps = Threads.threadid()==1,
initializealg = SDEDefaultInit(),
kwargs...) where recompile_flag

prob = concrete_prob(_prob)
Expand Down Expand Up @@ -587,7 +588,8 @@ function DiffEqBase.__init(
uBottomEltype,tType,typeof(tdir),typeof(p),
typeof(eigen_est),QT,
uEltypeNoUnits,typeof(W),typeof(P),rateType,typeof(sol),typeof(cache),
FType,GType,CType,typeof(opts),typeof(noise),typeof(last_event_error),typeof(callback_cache),typeof(rate_constants)}(
FType,GType,CType,typeof(opts),typeof(noise),typeof(last_event_error),typeof(callback_cache),typeof(rate_constants),
typeof(initializealg)}(
f,g,c,noise,uprev,tprev,t,u,p,tType(dt),tType(dt),tType(dt),dtcache,tspan[2],tdir,
just_hit_tstop,do_error_check,isout,event_last_time,
vector_event_last_time,last_event_error,accept_step,
Expand All @@ -597,9 +599,10 @@ function DiffEqBase.__init(
alg,sol,
cache,callback_cache,tType(dt),W,P,rate_constants,
opts,iter,success_iter,eigen_est,EEst,q,
QT(qoldinit),q11,stats)
QT(qoldinit),q11,stats,initializealg)

if initialize_integrator
DiffEqBase.initialize_dae!(integrator)
initialize_callbacks!(integrator, initialize_save)
initialize!(integrator,integrator.cache)
save_start && alg isa Union{StochasticDiffEqCompositeAlgorithm,
Expand Down
Loading