diff --git a/Project.toml b/Project.toml index d71f7fd4..f25b485c 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" FastPower = "a4df4552-cc26-4903-aec0-212e50a0e84b" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -41,6 +42,7 @@ DataStructures = "0.18, 0.19" DiffEqBase = "6.154" DiffEqNoiseProcess = "5.13" DocStringExtensions = "0.8, 0.9" +EnumX = "1" FastPower = "1" FiniteDiff = "2" ForwardDiff = "0.10.3, 1" diff --git a/src/StochasticDiffEq.jl b/src/StochasticDiffEq.jl index 80904221..2f6a4acc 100644 --- a/src/StochasticDiffEq.jl +++ b/src/StochasticDiffEq.jl @@ -10,6 +10,7 @@ using Reexport @reexport using DiffEqBase import ADTypes +using EnumX import OrdinaryDiffEqCore import OrdinaryDiffEqCore: default_controller, isstandard, ispredictive, @@ -158,6 +159,7 @@ include("iterated_integrals.jl") include("SROCK_utils.jl") include("composite_algs.jl") include("weak_utils.jl") +include("default_alg.jl") export StochasticDiffEqAlgorithm, StochasticDiffEqAdaptiveAlgorithm, StochasticCompositeAlgorithm diff --git a/src/default_alg.jl b/src/default_alg.jl new file mode 100644 index 00000000..56ee6b66 --- /dev/null +++ b/src/default_alg.jl @@ -0,0 +1,158 @@ +""" +Default algorithm selection for StochasticDiffEq.jl +Based on the logic from DifferentialEquations.jl but using a type-stable approach +similar to OrdinaryDiffEq.jl's DefaultODEAlgorithm +""" + +EnumX.@enumx DefaultSDESolverChoice begin + SOSRI = 1 + SOSRA = 2 + RKMilCommute = 3 + ImplicitRKMil = 4 + SKenCarp = 5 + LambaEM = 6 + ISSEM = 7 + LambaEulerHeun = 8 + ImplicitEulerHeun = 9 +end + +const NUM_NONSTIFF_SDE = 3 # SOSRI, SOSRA, RKMilCommute +const NUM_STIFF_SDE = 2 # ImplicitRKMil, SKenCarp + +""" + DefaultSDEAlgorithm(; kwargs...) + +Constructs a default SDE algorithm that automatically switches between different +solvers based on problem characteristics. This provides a type-stable default +algorithm similar to OrdinaryDiffEq.jl's DefaultODEAlgorithm. + +The algorithm selection is based on: + + - Stiffness detection + - Noise characteristics (additive, commutative, general) + - Problem interpretation (Ito vs Stratonovich) + +## Keyword Arguments + + - `lazy::Bool=true`: Whether to use lazy interpolants + - `stiffalgfirst::Bool=false`: Whether to start with the stiff algorithm + - `autodiff::Union{Bool,ADTypes.AbstractADType}=true`: Automatic differentiation backend + - `kwargs...`: Additional keyword arguments passed to the algorithms +""" +function DefaultSDEAlgorithm(; lazy = true, stiffalgfirst = false, + autodiff = true, kwargs...) + # For now, we use a simpler approach with just two algorithms + # This can be expanded later to include more sophisticated selection + nonstiff = SOSRI() + stiff = ImplicitRKMil(autodiff = autodiff) + + AutoAlgSwitch(nonstiff, stiff; stiffalgfirst = stiffalgfirst) +end + +""" + DefaultAdaptiveSDEAlgorithm(; kwargs...) + +Constructs a more sophisticated default SDE algorithm that adapts based on +problem characteristics including noise type and algorithm hints. + +## Keyword Arguments + + - `alg_hints::Vector{Symbol}=Symbol[]`: Algorithm hints like :additive, :commutative, :stiff, :stratonovich + - `autodiff::Union{Bool,ADTypes.AbstractADType}=true`: Automatic differentiation backend + - `kwargs...`: Additional keyword arguments passed to the algorithms +""" +function DefaultAdaptiveSDEAlgorithm(; alg_hints = Symbol[], + autodiff = true, kwargs...) + is_stiff = :stiff ∈ alg_hints + is_stratonovich = :stratonovich ∈ alg_hints + is_additive = :additive ∈ alg_hints + is_commutative = :commutative ∈ alg_hints + + # Select appropriate algorithms based on hints + if is_additive + nonstiff = SOSRA() + stiff = ISSEM(autodiff = autodiff) # Use ISSEM for stiff additive noise + elseif is_commutative + nonstiff = RKMilCommute() + stiff = ImplicitRKMil(autodiff = autodiff) + else + nonstiff = SOSRI() + if is_stratonovich + stiff = ImplicitRKMil(autodiff = autodiff, + interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich) + else + stiff = ImplicitRKMil(autodiff = autodiff) + end + end + + AutoAlgSwitch(nonstiff, stiff; stiffalgfirst = is_stiff) +end + +""" + default_sde_alg_choice(prob::SDEProblem) + +Determines the default algorithm choice for an SDE problem based on its characteristics. +Returns algorithm hints that can be used to select appropriate solvers. +""" +function default_sde_alg_choice(prob::SDEProblem) + alg_hints = Symbol[] + + # Check for mass matrix (implies stiffness) + if hasproperty(prob.f, :mass_matrix) && prob.f.mass_matrix !== I + push!(alg_hints, :stiff) + end + + # Check for additive noise (diagonal noise with no state dependence) + if prob.noise === nothing || isa(prob.noise, DiffEqNoiseProcess.NoiseProcess) + if DiffEqBase.is_diagonal_noise(prob) + # Simple heuristic for additive noise + push!(alg_hints, :additive) + end + end + + # Check problem interpretation + if hasproperty(prob, :interpretation) && + prob.interpretation == SciMLBase.AlgorithmInterpretation.Stratonovich + push!(alg_hints, :stratonovich) + end + + return alg_hints +end + +# Hook into the solve interface +function SciMLBase.__init(prob::SDEProblem, ::Nothing, args...; kwargs...) + alg_hints = default_sde_alg_choice(prob) + alg = if isempty(alg_hints) + DefaultSDEAlgorithm(; kwargs...) + else + DefaultAdaptiveSDEAlgorithm(; alg_hints = alg_hints, kwargs...) + end + SciMLBase.__init(prob, alg, args...; kwargs...) +end + +function SciMLBase.__solve(prob::SDEProblem, ::Nothing, args...; kwargs...) + alg_hints = default_sde_alg_choice(prob) + alg = if isempty(alg_hints) + DefaultSDEAlgorithm(; kwargs...) + else + DefaultAdaptiveSDEAlgorithm(; alg_hints = alg_hints, kwargs...) + end + SciMLBase.__solve(prob, alg, args...; kwargs...) +end + +# Mark default algorithms for special handling +function isdefaultalg(alg::StochasticCompositeAlgorithm) + # Check if it's one of our default algorithm combinations + if isa(alg.algs, Tuple{SOSRI, ImplicitRKMil}) || + isa(alg.algs, Tuple{SOSRA, ISSEM}) || + isa(alg.algs, Tuple{RKMilCommute, ImplicitRKMil}) + return true + end + return false +end + +# Also provide a fallback for other algorithms +isdefaultalg(alg) = false + +export DefaultSDEAlgorithm, DefaultAdaptiveSDEAlgorithm, default_sde_alg_choice, + isdefaultalg diff --git a/test/default_alg_test.jl b/test/default_alg_test.jl new file mode 100644 index 00000000..b74906bd --- /dev/null +++ b/test/default_alg_test.jl @@ -0,0 +1,128 @@ +using StochasticDiffEq, DiffEqNoiseProcess, Random, LinearAlgebra +using Test + +Random.seed!(100) + +# Simple SDE problem for testing +f(u, p, t) = 1.01 * u +g(u, p, t) = 0.87 * u +u0 = 1.0 +tspan = (0.0, 1.0) + +@testset "Default SDE Algorithm Tests" begin + @testset "Basic DefaultSDEAlgorithm" begin + prob = SDEProblem(f, g, u0, tspan) + + # Test with explicit DefaultSDEAlgorithm + alg = DefaultSDEAlgorithm() + sol = solve(prob, alg) + @test sol.retcode == ReturnCode.Success + @test length(sol.t) > 2 + + # Test with nothing (should use default) + sol2 = solve(prob, nothing) + @test sol2.retcode == ReturnCode.Success + @test length(sol2.t) > 2 + + # Test solve without algorithm (should use default) + sol3 = solve(prob) + @test sol3.retcode == ReturnCode.Success + @test length(sol3.t) > 2 + end + + @testset "DefaultAdaptiveSDEAlgorithm with hints" begin + prob = SDEProblem(f, g, u0, tspan) + + # Test with additive noise hint + alg = DefaultAdaptiveSDEAlgorithm(alg_hints = [:additive]) + sol = solve(prob, alg) + @test sol.retcode == ReturnCode.Success + + # Test with stiff hint + alg = DefaultAdaptiveSDEAlgorithm(alg_hints = [:stiff]) + sol = solve(prob, alg) + @test sol.retcode == ReturnCode.Success + + # Test with commutative hint + alg = DefaultAdaptiveSDEAlgorithm(alg_hints = [:commutative]) + sol = solve(prob, alg) + @test sol.retcode == ReturnCode.Success + end + + @testset "Additive noise problems" begin + # Additive noise SDE + f_add(u, p, t) = -0.5 * u + g_add(u, p, t) = 0.1 + prob_add = SDEProblem(f_add, g_add, u0, tspan) + + sol = solve(prob_add, nothing) + @test sol.retcode == ReturnCode.Success + @test length(sol.t) > 2 + end + + @testset "System of SDEs" begin + # System of SDEs + f_sys(du, u, p, t) = begin + du[1] = 1.01 * u[1] + du[2] = -0.5 * u[2] + end + g_sys(du, u, p, t) = begin + du[1] = 0.3 * u[1] + du[2] = 0.2 * u[2] + end + u0_sys = [1.0, 2.0] + prob_sys = SDEProblem(f_sys, g_sys, u0_sys, tspan) + + sol = solve(prob_sys, nothing) + @test sol.retcode == ReturnCode.Success + @test length(sol.t) > 2 + @test size(sol.u[end]) == size(u0_sys) + end + + @testset "Stiff SDE problems" begin + # Stiff problem + f_stiff(u, p, t) = -100.0 * u + g_stiff(u, p, t) = 0.1 * u + prob_stiff = SDEProblem(f_stiff, g_stiff, u0, (0.0, 0.1)) + + alg = DefaultAdaptiveSDEAlgorithm(alg_hints = [:stiff]) + sol = solve(prob_stiff, alg) + @test sol.retcode == ReturnCode.Success + end + + @testset "Algorithm selection logic" begin + prob = SDEProblem(f, g, u0, tspan) + + # Test default_sde_alg_choice + hints = default_sde_alg_choice(prob) + @test isa(hints, Vector{Symbol}) + + # Test with mass matrix (should add :stiff hint) + function f_mass(du, u, p, t) + du[1] = 1.01 * u[1] + du[2] = -0.5 * u[2] + end + function g_mass(du, u, p, t) + du[1] = 0.3 * u[1] + du[2] = 0.2 * u[2] + end + M = [1.0 0.0; 0.0 2.0] + f_mass_func = ODEFunction(f_mass, mass_matrix = M) + prob_mass = SDEProblem(f_mass_func, g_mass, [1.0, 2.0], tspan) + + hints_mass = default_sde_alg_choice(prob_mass) + @test :stiff in hints_mass + end + + @testset "isdefaultalg check" begin + alg1 = DefaultSDEAlgorithm() + @test StochasticDiffEq.isdefaultalg(alg1) == true + + alg2 = DefaultAdaptiveSDEAlgorithm(alg_hints = [:additive]) + @test StochasticDiffEq.isdefaultalg(alg2) == true + + # Non-default algorithm + alg3 = EM() + @test StochasticDiffEq.isdefaultalg(alg3) == false + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ebfecba1..b20a8f07 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -91,6 +91,9 @@ const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR") @time @safetestset "Composite Tests" begin include("composite_algorithm_test.jl") end + @time @safetestset "Default Algorithm Tests" begin + include("default_alg_test.jl") + end @time @safetestset "Events Tests" begin include("events_test.jl") end