Skip to content

Commit 5bec928

Browse files
Merge pull request #459 from SciML/levy_area
Switch to LevyArea.jl
2 parents c802f6e + e450613 commit 5bec928

File tree

10 files changed

+262
-798
lines changed

10 files changed

+262
-798
lines changed

.buildkite/pipeline.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ steps:
4141
agents:
4242
os: "linux"
4343
queue: "juliaecosystem"
44+
exclusive: true
4445
env:
4546
GROUP: 'SROCKC2WeakConvergence'
4647
timeout_in_minutes: 240

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1414
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1515
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1616
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
17+
LevyArea = "2d8b4e74-eb68-11e8-0fb9-d5eb67b50637"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1819
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1920
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
@@ -39,6 +40,7 @@ DocStringExtensions = "0.8"
3940
FillArrays = "0.6, 0.7, 0.8, 0.9, 0.10, 0.11, 0.12"
4041
FiniteDiff = "2"
4142
ForwardDiff = "0.10.3"
43+
LevyArea = "1.0.0"
4244
MuladdMacro = "0.2.1"
4345
NLsolve = "4"
4446
OrdinaryDiffEq = "6.4"

src/StochasticDiffEq.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ using DocStringExtensions
6363

6464
using SparseDiffTools: forwarddiff_color_jacobian!, ForwardColorJacCache
6565

66+
using LevyArea
6667

6768
const CompiledFloats = Union{Float32,Float64}
6869

@@ -164,7 +165,7 @@ using DocStringExtensions
164165

165166
export RandomEM
166167

167-
export IteratedIntegralApprox, IICommutative, IIWiktorsson
168+
export IteratedIntegralApprox, IICommutative, IILevyArea
168169

169170
#General Functions
170171
export solve, init, solve!, step!

src/alg_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ alg_interpretation(alg::LambaEulerHeun) = :Stratonovich
123123
alg_interpretation(alg::KomBurSROCK2) = :Stratonovich
124124
alg_interpretation(alg::RKMil{interpretation}) where {interpretation} = interpretation
125125
alg_interpretation(alg::SROCK1{interpretation,E}) where {interpretation,E} = interpretation
126-
alg_interpretation(alg::RKMilCommute{interpretation}) where {interpretation} = interpretation
126+
alg_interpretation(alg::RKMilCommute) = alg.interpretation
127127
alg_interpretation(alg::RKMilGeneral) = alg.interpretation
128128
alg_interpretation(alg::ImplicitRKMil{CS,AD,F,P,FDT,ST,CJ,N,T2,Controller,interpretation}) where {CS,AD,F,P,FDT,ST,CJ,N,T2,Controller,interpretation} = interpretation
129129

src/algorithms.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ abstract type StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{CS,AD,FDT,ST,
1919

2020
abstract type IteratedIntegralApprox end
2121
struct IICommutative <: IteratedIntegralApprox end
22-
struct IIWiktorsson <: IteratedIntegralApprox end
22+
struct IILevyArea <: IteratedIntegralApprox end
2323

2424
################################################################################
2525

@@ -91,31 +91,40 @@ RKMilCommute: Nonstiff Method
9191
An explicit Runge-Kutta discretization of the strong order 1.0 Milstein method for commutative noise problems.
9292
Defaults to solving the Ito problem, but RKMilCommute(interpretation=:Stratonovich) makes it solve the Stratonovich problem.
9393
Uses a 1.5/2.0 error estimate for adaptive time stepping.
94+
Default: ii_approx=IICommutative() does not approximate the Levy area.
9495
"""
95-
struct RKMilCommute{interpretation} <: StochasticDiffEqAdaptiveAlgorithm end
96-
RKMilCommute(;interpretation=:Ito) = RKMilCommute{interpretation}()
96+
struct RKMilCommute{T} <: StochasticDiffEqAdaptiveAlgorithm
97+
interpretation::Symbol
98+
ii_approx::T
99+
end
100+
RKMilCommute(;interpretation=:Ito, ii_approx=IICommutative()) = RKMilCommute(interpretation,ii_approx)
97101

98102
"""
99103
Kloeden, P.E., Platen, E., Numerical Solution of Stochastic Differential Equations.
100104
Springer. Berlin Heidelberg (2011)
101105
102106
RKMilGeneral: Nonstiff Method
103-
RKMilGeneral(;interpretation=:Ito, ii_approx=IIWiktorsson()
107+
RKMilGeneral(;interpretation=:Ito, ii_approx=IILevyArea()
104108
An explicit Runge-Kutta discretization of the strong order 1.0 Milstein method for general non-commutative noise problems.
105109
Allows for a choice of interpretation between :Ito and :Stratonovich.
106110
Allows for a choice of iterated integral approximation.
111+
Default: ii_approx=IILevyArea() uses LevyArea.jl to choose optimal algorithm. See
112+
Kastner, F. and Rößler, A., arXiv: 2201.08424
113+
Kastner, F. and Rößler, A., LevyArea.jl, 10.5281/ZENODO.5883748, https://github.com/stochastics-uni-luebeck/LevyArea.jl
107114
"""
108-
struct RKMilGeneral{T<:IteratedIntegralApprox, TruncationType} <: StochasticDiffEqAdaptiveAlgorithm
115+
struct RKMilGeneral{T, TruncationType} <: StochasticDiffEqAdaptiveAlgorithm
109116
interpretation::Symbol
110117
ii_approx::T
111118
c::Int
112119
p::TruncationType
113120
end
114-
function RKMilGeneral(;interpretation=:Ito, ii_approx=IIWiktorsson(), c=1, p=nothing, dt=nothing)
121+
122+
function RKMilGeneral(;interpretation=:Ito,ii_approx=IILevyArea(), c=1, p=nothing, dt=nothing)
115123
γ = 1//1
116124
p==true && (p = Int(floor(c*dt^(1//1-2//1*γ)) + 1))
117-
RKMilGeneral(interpretation, ii_approx, c, p)
125+
RKMilGeneral{typeof(ii_approx), typeof(p)}(interpretation, ii_approx, c, p)
118126
end
127+
119128
"""
120129
WangLi3SMil_A: Nonstiff Method
121130
Fixed step-size explicit 3-stage Milstein methods for Ito problem with strong and weak order 1.0

src/caches/basic_method_caches.jl

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -115,43 +115,44 @@ function alg_cache(alg::RKMil,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototy
115115
RKMilCache(u,uprev,du1,du2,K,tmp,L)
116116
end
117117

118-
struct RKMilCommuteConstantCache{WikType} <: StochasticDiffEqConstantCache
119-
WikJ::WikType
118+
struct RKMilCommuteConstantCache{JalgType} <: StochasticDiffEqConstantCache
119+
Jalg::JalgType
120120
end
121-
@cache struct RKMilCommuteCache{uType,rateType,rateNoiseType,WikType} <: StochasticDiffEqMutableCache
121+
@cache struct RKMilCommuteCache{uType,rateType,rateNoiseType,JalgType} <: StochasticDiffEqMutableCache
122122
u::uType
123123
uprev::uType
124124
du1::rateType
125125
du2::rateType
126126
K::rateType
127127
gtmp::rateNoiseType
128128
L::rateNoiseType
129-
WikJ::WikType
129+
Jalg::JalgType
130130
mil_correction::rateType
131131
Kj::uType
132132
Dgj::rateNoiseType
133133
tmp::uType
134134
end
135135

136-
function alg_cache(alg::RKMilCommute,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} WikJ = WikJ = get_WikJ(ΔW,prob,alg)
137-
RKMilCommuteConstantCache{typeof(WikJ)}(WikJ)
136+
function alg_cache(alg::RKMilCommute,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
137+
Jalg = get_Jalg(ΔW,dt,prob,alg)
138+
RKMilCommuteConstantCache{typeof(Jalg)}(Jalg)
138139
end
139140

140141
function alg_cache(alg::RKMilCommute,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
141142
du1 = zero(rate_prototype); du2 = zero(rate_prototype)
142143
K = zero(rate_prototype); gtmp = zero(noise_rate_prototype);
143144
L = zero(noise_rate_prototype); tmp = zero(rate_prototype)
144-
WikJ = get_WikJ(ΔW,prob,alg)
145+
Jalg = get_Jalg(ΔW,dt,prob,alg)
145146
mil_correction = zero(rate_prototype)
146147
Kj = zero(u); Dgj = zero(noise_rate_prototype)
147-
RKMilCommuteCache(u,uprev,du1,du2,K,gtmp,L,WikJ,mil_correction,Kj,Dgj,tmp)
148+
RKMilCommuteCache(u,uprev,du1,du2,K,gtmp,L,Jalg,mil_correction,Kj,Dgj,tmp)
148149
end
149150

150-
struct RKMilGeneralConstantCache{WikType} <: StochasticDiffEqConstantCache
151-
WikJ::WikType
151+
struct RKMilGeneralConstantCache{JalgType} <: StochasticDiffEqConstantCache
152+
Jalg::JalgType
152153
end
153154

154-
@cache struct RKMilGeneralCache{uType, rateType, rateNoiseType, WikType} <: StochasticDiffEqMutableCache
155+
@cache struct RKMilGeneralCache{uType, rateType, rateNoiseType, JalgType} <: StochasticDiffEqMutableCache
155156
u::uType
156157
uprev::uType
157158
tmp::uType
@@ -161,12 +162,12 @@ end
161162
L::rateNoiseType
162163
mil_correction::uType
163164
ggprime::rateNoiseType
164-
WikJ::WikType
165+
Jalg::JalgType
165166
end
166167

167-
function alg_cache(alg::RKMilGeneral,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
168-
WikJ = get_WikJ(ΔW,prob,alg)
169-
RKMilGeneralConstantCache{typeof(WikJ)}(WikJ)
168+
function alg_cache(alg::RKMilGeneral, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
169+
Jalg = get_Jalg(ΔW, dt, prob, alg)
170+
RKMilGeneralConstantCache{typeof(Jalg)}(Jalg)
170171
end
171172

172173
function alg_cache(alg::RKMilGeneral,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
@@ -177,6 +178,6 @@ function alg_cache(alg::RKMilGeneral,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_
177178
L = zero(noise_rate_prototype)
178179
mil_correction = zero(u)
179180
ggprime = zero(noise_rate_prototype)
180-
WikJ = get_WikJ(ΔW,prob,alg)
181-
RKMilGeneralCache{typeof(u), typeof(rate_prototype), typeof(noise_rate_prototype), typeof(WikJ)}(u, uprev, tmp, du₁, du₂, K, L, mil_correction, ggprime, WikJ)
181+
Jalg = get_Jalg(ΔW,dt,prob,alg)
182+
RKMilGeneralCache{typeof(u), typeof(rate_prototype), typeof(noise_rate_prototype), typeof(Jalg)}(u, uprev, tmp, du₁, du₂, K, L, mil_correction, ggprime, Jalg)
182183
end

0 commit comments

Comments
 (0)