Skip to content

Add default SDE algorithm with type-stable approach #626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FastPower = "a4df4552-cc26-4903-aec0-212e50a0e84b"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down Expand Up @@ -41,6 +42,7 @@ DataStructures = "0.18, 0.19"
DiffEqBase = "6.154"
DiffEqNoiseProcess = "5.13"
DocStringExtensions = "0.8, 0.9"
EnumX = "1"
FastPower = "1"
FiniteDiff = "2"
ForwardDiff = "0.10.3, 1"
Expand Down
2 changes: 2 additions & 0 deletions src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using Reexport
@reexport using DiffEqBase

import ADTypes
using EnumX

import OrdinaryDiffEqCore
import OrdinaryDiffEqCore: default_controller, isstandard, ispredictive,
Expand Down Expand Up @@ -158,6 +159,7 @@ include("iterated_integrals.jl")
include("SROCK_utils.jl")
include("composite_algs.jl")
include("weak_utils.jl")
include("default_alg.jl")

export StochasticDiffEqAlgorithm, StochasticDiffEqAdaptiveAlgorithm,
StochasticCompositeAlgorithm
Expand Down
158 changes: 158 additions & 0 deletions src/default_alg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""
Default algorithm selection for StochasticDiffEq.jl
Based on the logic from DifferentialEquations.jl but using a type-stable approach
similar to OrdinaryDiffEq.jl's DefaultODEAlgorithm
"""

EnumX.@enumx DefaultSDESolverChoice begin
SOSRI = 1
SOSRA = 2
RKMilCommute = 3
ImplicitRKMil = 4
SKenCarp = 5
LambaEM = 6
ISSEM = 7
LambaEulerHeun = 8
ImplicitEulerHeun = 9
end

const NUM_NONSTIFF_SDE = 3 # SOSRI, SOSRA, RKMilCommute
const NUM_STIFF_SDE = 2 # ImplicitRKMil, SKenCarp

"""
DefaultSDEAlgorithm(; kwargs...)

Constructs a default SDE algorithm that automatically switches between different
solvers based on problem characteristics. This provides a type-stable default
algorithm similar to OrdinaryDiffEq.jl's DefaultODEAlgorithm.

The algorithm selection is based on:

- Stiffness detection
- Noise characteristics (additive, commutative, general)
- Problem interpretation (Ito vs Stratonovich)

## Keyword Arguments

- `lazy::Bool=true`: Whether to use lazy interpolants
- `stiffalgfirst::Bool=false`: Whether to start with the stiff algorithm
- `autodiff::Union{Bool,ADTypes.AbstractADType}=true`: Automatic differentiation backend
- `kwargs...`: Additional keyword arguments passed to the algorithms
"""
function DefaultSDEAlgorithm(; lazy = true, stiffalgfirst = false,
autodiff = true, kwargs...)
# For now, we use a simpler approach with just two algorithms
# This can be expanded later to include more sophisticated selection
nonstiff = SOSRI()
stiff = ImplicitRKMil(autodiff = autodiff)

AutoAlgSwitch(nonstiff, stiff; stiffalgfirst = stiffalgfirst)
end

"""
DefaultAdaptiveSDEAlgorithm(; kwargs...)

Constructs a more sophisticated default SDE algorithm that adapts based on
problem characteristics including noise type and algorithm hints.

## Keyword Arguments

- `alg_hints::Vector{Symbol}=Symbol[]`: Algorithm hints like :additive, :commutative, :stiff, :stratonovich
- `autodiff::Union{Bool,ADTypes.AbstractADType}=true`: Automatic differentiation backend
- `kwargs...`: Additional keyword arguments passed to the algorithms
"""
function DefaultAdaptiveSDEAlgorithm(; alg_hints = Symbol[],
autodiff = true, kwargs...)
is_stiff = :stiff ∈ alg_hints
is_stratonovich = :stratonovich ∈ alg_hints
is_additive = :additive ∈ alg_hints
is_commutative = :commutative ∈ alg_hints

# Select appropriate algorithms based on hints
if is_additive
nonstiff = SOSRA()
stiff = ISSEM(autodiff = autodiff) # Use ISSEM for stiff additive noise
elseif is_commutative
nonstiff = RKMilCommute()
stiff = ImplicitRKMil(autodiff = autodiff)
else
nonstiff = SOSRI()
if is_stratonovich
stiff = ImplicitRKMil(autodiff = autodiff,
interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich)
else
stiff = ImplicitRKMil(autodiff = autodiff)
end
end

AutoAlgSwitch(nonstiff, stiff; stiffalgfirst = is_stiff)
end

"""
default_sde_alg_choice(prob::SDEProblem)

Determines the default algorithm choice for an SDE problem based on its characteristics.
Returns algorithm hints that can be used to select appropriate solvers.
"""
function default_sde_alg_choice(prob::SDEProblem)
alg_hints = Symbol[]

# Check for mass matrix (implies stiffness)
if hasproperty(prob.f, :mass_matrix) && prob.f.mass_matrix !== I
push!(alg_hints, :stiff)
end

# Check for additive noise (diagonal noise with no state dependence)
if prob.noise === nothing || isa(prob.noise, DiffEqNoiseProcess.NoiseProcess)
if DiffEqBase.is_diagonal_noise(prob)
# Simple heuristic for additive noise
push!(alg_hints, :additive)
end
end

# Check problem interpretation
if hasproperty(prob, :interpretation) &&
prob.interpretation == SciMLBase.AlgorithmInterpretation.Stratonovich
push!(alg_hints, :stratonovich)
end

return alg_hints
end

# Hook into the solve interface
function SciMLBase.__init(prob::SDEProblem, ::Nothing, args...; kwargs...)
alg_hints = default_sde_alg_choice(prob)
alg = if isempty(alg_hints)
DefaultSDEAlgorithm(; kwargs...)
else
DefaultAdaptiveSDEAlgorithm(; alg_hints = alg_hints, kwargs...)
end
SciMLBase.__init(prob, alg, args...; kwargs...)
end

function SciMLBase.__solve(prob::SDEProblem, ::Nothing, args...; kwargs...)
alg_hints = default_sde_alg_choice(prob)
alg = if isempty(alg_hints)
DefaultSDEAlgorithm(; kwargs...)
else
DefaultAdaptiveSDEAlgorithm(; alg_hints = alg_hints, kwargs...)
end
SciMLBase.__solve(prob, alg, args...; kwargs...)
end

# Mark default algorithms for special handling
function isdefaultalg(alg::StochasticCompositeAlgorithm)
# Check if it's one of our default algorithm combinations
if isa(alg.algs, Tuple{SOSRI, ImplicitRKMil}) ||
isa(alg.algs, Tuple{SOSRA, ISSEM}) ||
isa(alg.algs, Tuple{RKMilCommute, ImplicitRKMil})
return true
end
return false
end

# Also provide a fallback for other algorithms
isdefaultalg(alg) = false

export DefaultSDEAlgorithm, DefaultAdaptiveSDEAlgorithm, default_sde_alg_choice,
isdefaultalg
128 changes: 128 additions & 0 deletions test/default_alg_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
using StochasticDiffEq, DiffEqNoiseProcess, Random, LinearAlgebra
using Test

Random.seed!(100)

# Simple SDE problem for testing
f(u, p, t) = 1.01 * u
g(u, p, t) = 0.87 * u
u0 = 1.0
tspan = (0.0, 1.0)

@testset "Default SDE Algorithm Tests" begin
@testset "Basic DefaultSDEAlgorithm" begin
prob = SDEProblem(f, g, u0, tspan)

# Test with explicit DefaultSDEAlgorithm
alg = DefaultSDEAlgorithm()
sol = solve(prob, alg)
@test sol.retcode == ReturnCode.Success
@test length(sol.t) > 2

# Test with nothing (should use default)
sol2 = solve(prob, nothing)
@test sol2.retcode == ReturnCode.Success
@test length(sol2.t) > 2

# Test solve without algorithm (should use default)
sol3 = solve(prob)
@test sol3.retcode == ReturnCode.Success
@test length(sol3.t) > 2
end

@testset "DefaultAdaptiveSDEAlgorithm with hints" begin
prob = SDEProblem(f, g, u0, tspan)

# Test with additive noise hint
alg = DefaultAdaptiveSDEAlgorithm(alg_hints = [:additive])
sol = solve(prob, alg)
@test sol.retcode == ReturnCode.Success

# Test with stiff hint
alg = DefaultAdaptiveSDEAlgorithm(alg_hints = [:stiff])
sol = solve(prob, alg)
@test sol.retcode == ReturnCode.Success

# Test with commutative hint
alg = DefaultAdaptiveSDEAlgorithm(alg_hints = [:commutative])
sol = solve(prob, alg)
@test sol.retcode == ReturnCode.Success
end

@testset "Additive noise problems" begin
# Additive noise SDE
f_add(u, p, t) = -0.5 * u
g_add(u, p, t) = 0.1
prob_add = SDEProblem(f_add, g_add, u0, tspan)

sol = solve(prob_add, nothing)
@test sol.retcode == ReturnCode.Success
@test length(sol.t) > 2
end

@testset "System of SDEs" begin
# System of SDEs
f_sys(du, u, p, t) = begin
du[1] = 1.01 * u[1]
du[2] = -0.5 * u[2]
end
g_sys(du, u, p, t) = begin
du[1] = 0.3 * u[1]
du[2] = 0.2 * u[2]
end
u0_sys = [1.0, 2.0]
prob_sys = SDEProblem(f_sys, g_sys, u0_sys, tspan)

sol = solve(prob_sys, nothing)
@test sol.retcode == ReturnCode.Success
@test length(sol.t) > 2
@test size(sol.u[end]) == size(u0_sys)
end

@testset "Stiff SDE problems" begin
# Stiff problem
f_stiff(u, p, t) = -100.0 * u
g_stiff(u, p, t) = 0.1 * u
prob_stiff = SDEProblem(f_stiff, g_stiff, u0, (0.0, 0.1))

alg = DefaultAdaptiveSDEAlgorithm(alg_hints = [:stiff])
sol = solve(prob_stiff, alg)
@test sol.retcode == ReturnCode.Success
end

@testset "Algorithm selection logic" begin
prob = SDEProblem(f, g, u0, tspan)

# Test default_sde_alg_choice
hints = default_sde_alg_choice(prob)
@test isa(hints, Vector{Symbol})

# Test with mass matrix (should add :stiff hint)
function f_mass(du, u, p, t)
du[1] = 1.01 * u[1]
du[2] = -0.5 * u[2]
end
function g_mass(du, u, p, t)
du[1] = 0.3 * u[1]
du[2] = 0.2 * u[2]
end
M = [1.0 0.0; 0.0 2.0]
f_mass_func = ODEFunction(f_mass, mass_matrix = M)
prob_mass = SDEProblem(f_mass_func, g_mass, [1.0, 2.0], tspan)

hints_mass = default_sde_alg_choice(prob_mass)
@test :stiff in hints_mass
end

@testset "isdefaultalg check" begin
alg1 = DefaultSDEAlgorithm()
@test StochasticDiffEq.isdefaultalg(alg1) == true

alg2 = DefaultAdaptiveSDEAlgorithm(alg_hints = [:additive])
@test StochasticDiffEq.isdefaultalg(alg2) == true

# Non-default algorithm
alg3 = EM()
@test StochasticDiffEq.isdefaultalg(alg3) == false
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR")
@time @safetestset "Composite Tests" begin
include("composite_algorithm_test.jl")
end
@time @safetestset "Default Algorithm Tests" begin
include("default_alg_test.jl")
end
@time @safetestset "Events Tests" begin
include("events_test.jl")
end
Expand Down
Loading