Skip to content

Commit 2682289

Browse files
Move SDE default algorithm from DifferentialEquations.jl to StochasticDiffEq.jl
This PR moves the default SDE solver implementation from DifferentialEquations.jl to StochasticDiffEq.jl, following the pattern established in SciML/DelayDiffEq.jl#326 and SciML/DelayDiffEq.jl#334. ## Changes - Added `src/default_sde_alg.jl` containing the default algorithm selection logic - Implemented `__init` and `__solve` dispatches for `SDEProblem` with `Nothing` algorithm - Added `get_alg_hints` helper function for extracting algorithm hints from kwargs - Added comprehensive tests in `test/default_solver_test.jl` - Updated module to include the new default algorithm file ## Default Algorithm Behavior When no algorithm is specified, the solver now automatically selects: - SOSRI() as the standard default - RKMilCommute() for commutative noise - ImplicitRKMil() for stiff problems or non-identity mass matrices - RKMil() for Stratonovich interpretation - LambaEM() / LambaEulerHeun() for non-diagonal noise - ISSEM() / ImplicitEulerHeun() for stiff non-diagonal problems - SOSRA() / SKenCarp() for additive noise ## Test Plan - [x] Added tests verifying default solver dispatch - [x] Tests verify correct algorithm selection for various problem types - [x] All tests pass locally This is part of the ongoing effort to modularize DifferentialEquations.jl by moving default solvers to their respective packages. 🤖 Generated with [Claude Code](https://claude.ai/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent e23a84d commit 2682289

File tree

4 files changed

+145
-0
lines changed

4 files changed

+145
-0
lines changed

src/StochasticDiffEq.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ include("iterated_integrals.jl")
158158
include("SROCK_utils.jl")
159159
include("composite_algs.jl")
160160
include("weak_utils.jl")
161+
include("default_sde_alg.jl")
161162

162163
export StochasticDiffEqAlgorithm, StochasticDiffEqAdaptiveAlgorithm,
163164
StochasticCompositeAlgorithm

src/default_sde_alg.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Default algorithm selection for SDEs
2+
# Moved from DifferentialEquations.jl as part of modularization effort
3+
4+
using LinearAlgebra: I
5+
6+
# Helper function to extract alg_hints from keyword arguments
7+
function get_alg_hints(o)
8+
:alg_hints keys(o) ? alg_hints = o[:alg_hints] : alg_hints = Symbol[:auto]
9+
end
10+
11+
function default_algorithm(
12+
prob::DiffEqBase.AbstractSDEProblem{uType, tType, isinplace, ND};
13+
kwargs...) where {uType, tType, isinplace, ND}
14+
o = Dict{Symbol, Any}(kwargs)
15+
extra_kwargs = Any[]
16+
alg = SOSRI() # Standard default
17+
uEltype = eltype(prob.u0)
18+
19+
alg_hints = get_alg_hints(o)
20+
21+
if :commutative alg_hints
22+
alg = RKMilCommute()
23+
end
24+
25+
is_stiff = :stiff alg_hints
26+
is_stratonovich = :stratonovich alg_hints
27+
if is_stiff || prob.f.mass_matrix !== I
28+
alg = ImplicitRKMil(autodiff = false)
29+
end
30+
31+
if is_stratonovich
32+
if is_stiff || prob.f.mass_matrix !== I
33+
alg = ImplicitRKMil(autodiff = false,
34+
interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich)
35+
else
36+
alg = RKMil(interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich)
37+
end
38+
end
39+
40+
if prob.noise_rate_prototype != nothing || prob.noise != nothing
41+
if is_stratonovich
42+
if is_stiff || prob.f.mass_matrix !== I
43+
alg = ImplicitEulerHeun(autodiff = false)
44+
else
45+
alg = LambaEulerHeun()
46+
end
47+
else
48+
if is_stiff || prob.f.mass_matrix !== I
49+
alg = ISSEM(autodiff = false)
50+
else
51+
alg = LambaEM()
52+
end
53+
end
54+
end
55+
56+
if :additive alg_hints
57+
if is_stiff || prob.f.mass_matrix !== I
58+
alg = SKenCarp(autodiff = false)
59+
else
60+
alg = SOSRA()
61+
end
62+
end
63+
64+
# If adaptivity is not set and the tType is not a float, turn off adaptivity
65+
# Bad interaction with ForwardDiff
66+
#!(tType <: AbstractFloat) && (:adaptive ∉ keys(o)) && push!(extra_kwargs,:adaptive=>false)
67+
68+
alg, extra_kwargs
69+
end
70+
71+
# Dispatch for __init with Nothing algorithm - use default
72+
function DiffEqBase.__init(
73+
prob::DiffEqBase.AbstractSDEProblem, ::Nothing, args...; kwargs...)
74+
alg, extra_kwargs = default_algorithm(prob; kwargs...)
75+
DiffEqBase.__init(prob, alg, args...; extra_kwargs..., kwargs...)
76+
end
77+
78+
# Dispatch for __solve with Nothing algorithm - use default
79+
function DiffEqBase.__solve(
80+
prob::DiffEqBase.AbstractSDEProblem, ::Nothing, args...; kwargs...)
81+
alg, extra_kwargs = default_algorithm(prob; kwargs...)
82+
DiffEqBase.__solve(prob, alg, args...; extra_kwargs..., kwargs...)
83+
end

test/default_solver_test.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
using StochasticDiffEq, Test
2+
import SciMLBase
3+
using Random
4+
5+
# Additive SDE test problem
6+
f_additive(u, p, t) = @. p[2] / sqrt(1 + t) - u / (2 * (1 + t))
7+
σ_additive(u, p, t) = @. p[1] * p[2] / sqrt(1 + t)
8+
p = (0.1, 0.05)
9+
additive_analytic(u0, p, t, W) = @. u0 / sqrt(1 + t) + p[2] * (t + p[1] * W) / sqrt(1 + t)
10+
ff_additive = SDEFunction(f_additive, σ_additive, analytic = additive_analytic)
11+
prob_sde_additive = SDEProblem(ff_additive, σ_additive, 1.0, (0.0, 1.0), p)
12+
13+
Random.seed!(100)
14+
15+
# Test default (no algorithm specified) - should use SOSRI
16+
prob = prob_sde_additive
17+
sol = solve(prob, dt = 1 / 2^(3))
18+
@test sol.alg isa SOSRI
19+
20+
# Test with :additive hint - should use SOSRA
21+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:additive])
22+
@test sol.alg isa SOSRA
23+
24+
# Test with :stratonovich hint - should use RKMil with Stratonovich interpretation
25+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stratonovich])
26+
@test SciMLBase.alg_interpretation(sol.alg) ==
27+
SciMLBase.AlgorithmInterpretation.Stratonovich
28+
@test sol.alg isa RKMil
29+
30+
# Non-diagonal noise test problem
31+
f = (du, u, p, t) -> du .= 1.01u
32+
g = function (du, u, p, t)
33+
du[1, 1] = 0.3u[1]
34+
du[1, 2] = 0.6u[1]
35+
du[1, 3] = 0.9u[1]
36+
du[1, 4] = 0.12u[2]
37+
du[2, 1] = 1.2u[1]
38+
du[2, 2] = 0.2u[2]
39+
du[2, 3] = 0.3u[2]
40+
du[2, 4] = 1.8u[2]
41+
end
42+
prob = SDEProblem(f, g, ones(2), (0.0, 1.0), noise_rate_prototype = zeros(2, 4))
43+
44+
# Test default with non-diagonal noise - should use LambaEM
45+
sol = solve(prob, dt = 1 / 2^(3))
46+
@test sol.alg isa LambaEM
47+
48+
# Test with :stiff hint - should use ISSEM
49+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stiff])
50+
@test sol.alg isa ISSEM
51+
52+
# Test with :additive hint - should still use SOSRA (overrides non-diagonal)
53+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:additive])
54+
@test sol.alg isa SOSRA
55+
56+
# Test with :stratonovich hint - should use LambaEulerHeun
57+
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stratonovich])
58+
@test sol.alg isa LambaEulerHeun

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR")
1515

1616
@time begin
1717
if GROUP == "All" || GROUP == "Interface1"
18+
@time @safetestset "Default Solver Tests" begin
19+
include("default_solver_test.jl")
20+
end
1821
@time @safetestset "First Rand Tests" begin
1922
include("first_rand_test.jl")
2023
end

0 commit comments

Comments
 (0)