Skip to content

Commit bba3b55

Browse files
Merge pull request #447 from SciML/standardtag
add standardtag mechanism to StochasticDiffEq implicit methods
2 parents 803b641 + 6a22bc2 commit bba3b55

File tree

5 files changed

+43
-15
lines changed

5 files changed

+43
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ FiniteDiff = "2"
4141
ForwardDiff = "0.10.3"
4242
MuladdMacro = "0.2.1"
4343
NLsolve = "4"
44-
OrdinaryDiffEq = "5.64"
44+
OrdinaryDiffEq = "5.69"
4545
RandomNumbers = "1.5.3"
4646
RecursiveArrayTools = "2"
4747
Reexport = "0.2, 1.0"

src/alg_utils.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ alg_interpretation(alg::EulerHeun) = :Stratonovich
122122
alg_interpretation(alg::LambaEulerHeun) = :Stratonovich
123123
alg_interpretation(alg::KomBurSROCK2) = :Stratonovich
124124
alg_interpretation(alg::RKMil{interpretation}) where {interpretation} = interpretation
125-
alg_interpretation(alg::SROCK1{interpretation}) where {interpretation} = interpretation
125+
alg_interpretation(alg::SROCK1{interpretation,E}) where {interpretation,E} = interpretation
126126
alg_interpretation(alg::RKMilCommute{interpretation}) where {interpretation} = interpretation
127127
alg_interpretation(alg::RKMilGeneral) = alg.interpretation
128-
alg_interpretation(alg::ImplicitRKMil{CS,AD,F,S,N,T2,Controller,interpretation}) where {CS,AD,F,S,N,T2,Controller,interpretation} = interpretation
128+
alg_interpretation(alg::ImplicitRKMil{CS,AD,F,FDT,ST,N,T2,Controller,interpretation}) where {CS,AD,F,FDT,ST,N,T2,Controller,interpretation} = interpretation
129129

130130
alg_interpretation(alg::RS1) = :Stratonovich
131131
alg_interpretation(alg::RS2) = :Stratonovich
@@ -256,9 +256,15 @@ OrdinaryDiffEq.get_chunksize(alg::StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,
256256
OrdinaryDiffEq.get_chunksize(alg::StochasticDiffEqJumpNewtonAdaptiveAlgorithm{CS,AD,Controller}) where {CS,AD,Controller} = Val(CS)
257257
OrdinaryDiffEq.get_chunksize(alg::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{CS,AD,Controller}) where {CS,AD,Controller} = Val(CS)
258258

259+
@static if isdefined(OrdinaryDiffEq,:standardtag)
260+
OrdinaryDiffEq.standardtag(alg::Union{StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,Controller},
261+
StochasticDiffEqNewtonAlgorithm{CS,AD,FDT,ST,Controller}}
262+
) where {CS,AD,FDT,ST,Controller} = ST
263+
end
264+
259265
@static if isdefined(OrdinaryDiffEq,:alg_difftype)
260-
OrdinaryDiffEq.alg_difftype(alg::Union{StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,Controller},
261-
StochasticDiffEqNewtonAlgorithm{CS,AD,FDT,Controller}}) where {CS,AD,FDT,Controller} = FDT
266+
OrdinaryDiffEq.alg_difftype(alg::Union{StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,Controller},
267+
StochasticDiffEqNewtonAlgorithm{CS,AD,FDT,ST,Controller}}) where {CS,AD,FDT,ST,Controller} = FDT
262268
end
263269

264270
alg_mass_matrix_compatible(alg::StochasticDiffEqAlgorithm) = false

src/algorithms.jl

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@ abstract type StochasticDiffEqRODEAlgorithm <: AbstractRODEAlgorithm end
66
abstract type StochasticDiffEqRODEAdaptiveAlgorithm <: StochasticDiffEqRODEAlgorithm end
77
abstract type StochasticDiffEqRODECompositeAlgorithm <: StochasticDiffEqRODEAlgorithm end
88

9-
abstract type StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,Controller} <: StochasticDiffEqAdaptiveAlgorithm end
10-
abstract type StochasticDiffEqNewtonAlgorithm{CS,AD,FDT,Controller} <: StochasticDiffEqAlgorithm end
9+
abstract type StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,Controller} <: StochasticDiffEqAdaptiveAlgorithm end
10+
abstract type StochasticDiffEqNewtonAlgorithm{CS,AD,FDT,ST,Controller} <: StochasticDiffEqAlgorithm end
1111

1212
abstract type StochasticDiffEqJumpAlgorithm <: StochasticDiffEqAlgorithm end
1313
abstract type StochasticDiffEqJumpAdaptiveAlgorithm <: StochasticDiffEqAlgorithm end
14-
abstract type StochasticDiffEqJumpNewtonAdaptiveAlgorithm{CS,AD,FDT,Controller} <: StochasticDiffEqJumpAdaptiveAlgorithm end
14+
abstract type StochasticDiffEqJumpNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,Controller} <: StochasticDiffEqJumpAdaptiveAlgorithm end
1515

1616
abstract type StochasticDiffEqJumpDiffusionAlgorithm <: StochasticDiffEqAlgorithm end
1717
abstract type StochasticDiffEqJumpDiffusionAdaptiveAlgorithm <: StochasticDiffEqAlgorithm end
18-
abstract type StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{CS,AD,FDT,Controller} <: StochasticDiffEqJumpDiffusionAdaptiveAlgorithm end
18+
abstract type StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{CS,AD,FDT,ST,Controller} <: StochasticDiffEqJumpDiffusionAdaptiveAlgorithm end
1919

2020
abstract type IteratedIntegralApprox end
2121
struct IICommutative <: IteratedIntegralApprox end
@@ -625,7 +625,7 @@ This is a theta method which defaults to theta=1 or the Trapezoid method on the
625625
This method defaults to symplectic=false, but when true and theta=1/2 this is the implicit Midpoint method on the drift term and is symplectic in distribution.
626626
Can handle all forms of noise, including non-diagonal, scalar, and colored noise. Uses a 1.0/1.5 heuristic for adaptive time stepping.
627627
"""
628-
struct ImplicitEM{CS,AD,F,F2,FDT,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,Controller}
628+
struct ImplicitEM{CS,AD,F,F2,FDT,ST,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,Controller}
629629
linsolve::F
630630
nlsolve::F2
631631
theta::T2
@@ -634,13 +634,15 @@ struct ImplicitEM{CS,AD,F,F2,FDT,T2,Controller} <: StochasticDiffEqNewtonAdaptiv
634634
symplectic::Bool
635635
end
636636
ImplicitEM(;chunk_size=0,autodiff=true,diff_type=Val{:central},
637+
standardtag = Val{true}(),
637638
linsolve=DEFAULT_LINSOLVE,nlsolve=NLNewton(),
638639
extrapolant=:constant,
639640
theta = 1,symplectic=false,
640641
new_jac_conv_bound = 1e-3,
641642
controller = :Predictive) =
642643
ImplicitEM{chunk_size,autodiff,
643644
typeof(linsolve),typeof(nlsolve),diff_type,
645+
OrdinaryDiffEq._unwrap_val(standardtag),
644646
typeof(new_jac_conv_bound),controller}(
645647
linsolve,nlsolve,
646648
symplectic ? 1/2 : theta,
@@ -655,7 +657,7 @@ This is a theta method which defaults to theta=1/2 or the Trapezoid method on th
655657
This method defaults to symplectic=false, but when true and theta=1 this is the implicit Midpoint method on the drift term and is symplectic in distribution.
656658
Can handle all forms of noise, including non-diagonal, scalar, and colored noise. Uses a 1.0/1.5 heuristic for adaptive time stepping.
657659
"""
658-
struct ImplicitEulerHeun{CS,AD,F,FDT,N,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,Controller}
660+
struct ImplicitEulerHeun{CS,AD,F,FDT,ST,N,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,Controller}
659661
linsolve::F
660662
nlsolve::N
661663
theta::T2
@@ -664,13 +666,15 @@ struct ImplicitEulerHeun{CS,AD,F,FDT,N,T2,Controller} <: StochasticDiffEqNewtonA
664666
symplectic::Bool
665667
end
666668
ImplicitEulerHeun(;chunk_size=0,autodiff=true,diff_type=Val{:central},
669+
standardtag = Val{true}(),
667670
linsolve=DEFAULT_LINSOLVE,nlsolve=NLNewton(),
668671
extrapolant=:constant,
669672
theta = 1,symplectic = false,
670673
new_jac_conv_bound = 1e-3,
671674
controller = :Predictive) =
672675
ImplicitEulerHeun{chunk_size,autodiff,
673676
typeof(linsolve),diff_type,
677+
OrdinaryDiffEq._unwrap_val(standardtag),
674678
typeof(nlsolve),
675679
typeof(new_jac_conv_bound),controller}(
676680
linsolve,nlsolve,
@@ -686,7 +690,7 @@ Defaults to solving the Ito problem, but ImplicitRKMil(interpretation=:Stratonov
686690
This method defaults to symplectic=false, but when true and theta=1/2 this is the implicit Midpoint method on the drift term and is symplectic in distribution.
687691
Handles diagonal and scalar noise. Uses a 1.5/2.0 heuristic for adaptive time stepping.
688692
"""
689-
struct ImplicitRKMil{CS,AD,F,FDT,N,T2,Controller,interpretation} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,Controller}
693+
struct ImplicitRKMil{CS,AD,F,FDT,ST,N,T2,Controller,interpretation} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,Controller}
690694
linsolve::F
691695
nlsolve::N
692696
theta::T2
@@ -695,13 +699,15 @@ struct ImplicitRKMil{CS,AD,F,FDT,N,T2,Controller,interpretation} <: StochasticDi
695699
symplectic::Bool
696700
end
697701
ImplicitRKMil(;chunk_size=0,autodiff=true,diff_type=Val{:central},
702+
standardtag = Val{true}(),
698703
linsolve=DEFAULT_LINSOLVE,nlsolve=NLNewton(),
699704
extrapolant=:constant,
700705
theta = 1,symplectic = false,
701706
new_jac_conv_bound = 1e-3,
702707
controller = :Predictive,interpretation=:Ito) =
703708
ImplicitRKMil{chunk_size,autodiff,
704709
typeof(linsolve),diff_type,
710+
OrdinaryDiffEq._unwrap_val(standardtag),
705711
typeof(nlsolve),typeof(new_jac_conv_bound),
706712
controller,interpretation}(
707713
linsolve,nlsolve,
@@ -716,7 +722,7 @@ This is a theta method which defaults to theta=1 or the Trapezoid method on the
716722
This method defaults to symplectic=false, but when true and theta=1/2 this is the implicit Midpoint method on the drift term and is symplectic in distribution.
717723
Can handle all forms of noise, including non-diagonal, scalar, and colored noise. Uses a 1.0/1.5 heuristic for adaptive time stepping.
718724
"""
719-
struct ISSEM{CS,AD,F,FDT,N,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,Controller}
725+
struct ISSEM{CS,AD,F,FDT,ST,N,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,Controller}
720726
linsolve::F
721727
nlsolve::N
722728
theta::T2
@@ -725,13 +731,15 @@ struct ISSEM{CS,AD,F,FDT,N,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgor
725731
symplectic::Bool
726732
end
727733
ISSEM(;chunk_size=0,autodiff=true,diff_type=Val{:central},
734+
standardtag = Val{true}(),
728735
linsolve=DEFAULT_LINSOLVE,nlsolve=NLNewton(),
729736
extrapolant=:constant,
730737
theta = 1,symplectic=false,
731738
new_jac_conv_bound = 1e-3,
732739
controller = :Predictive) =
733740
ISSEM{chunk_size,autodiff,
734741
typeof(linsolve),diff_type,
742+
OrdinaryDiffEq._unwrap_val(standardtag),
735743
typeof(nlsolve),
736744
typeof(new_jac_conv_bound),controller}(
737745
linsolve,nlsolve,
@@ -746,7 +754,7 @@ This is a theta method which defaults to theta=1 or the Trapezoid method on the
746754
This method defaults to symplectic=false, but when true and theta=1/2 this is the implicit Midpoint method on the drift term and is symplectic in distribution.
747755
Can handle all forms of noise, including non-diagonal,Q scalar, and colored noise. Uses a 1.0/1.5 heuristic for adaptive time stepping.
748756
"""
749-
struct ISSEulerHeun{CS,AD,F,FDT,N,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,Controller}
757+
struct ISSEulerHeun{CS,AD,F,FDT,ST,N,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,Controller}
750758
linsolve::F
751759
nlsolve::N
752760
theta::T2
@@ -755,13 +763,15 @@ struct ISSEulerHeun{CS,AD,F,FDT,N,T2,Controller} <: StochasticDiffEqNewtonAdapti
755763
symplectic::Bool
756764
end
757765
ISSEulerHeun(;chunk_size=0,autodiff=true,diff_type=Val{:central},
766+
standardtag = Val{true}(),
758767
linsolve=DEFAULT_LINSOLVE,nlsolve=NLNewton(),
759768
extrapolant=:constant,
760769
theta = 1,symplectic=false,
761770
new_jac_conv_bound = 1e-3,
762771
controller = :Predictive) =
763772
ISSEulerHeun{chunk_size,autodiff,
764773
typeof(linsolve),diff_type,
774+
OrdinaryDiffEq._unwrap_val(standardtag),
765775
typeof(nlsolve),typeof(new_jac_conv_bound),controller}(
766776
linsolve,nlsolve,
767777
symplectic ? 1/2 : theta,
@@ -772,7 +782,7 @@ SKenCarp: Stiff Method
772782
Adaptive L-stable drift-implicit strong order 1.5 for additive Ito and Stratonovich SDEs with weak order 2.
773783
Can handle diagonal, non-diagonal and scalar additive noise.
774784
"""
775-
struct SKenCarp{CS,AD,F,FDT,N,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,Controller}
785+
struct SKenCarp{CS,AD,F,FDT,ST,N,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,Controller}
776786
linsolve::F
777787
nlsolve::N
778788
smooth_est::Bool
@@ -782,11 +792,13 @@ struct SKenCarp{CS,AD,F,FDT,N,T2,Controller} <: StochasticDiffEqNewtonAdaptiveAl
782792
end
783793

784794
SKenCarp(;chunk_size=0,autodiff=true,diff_type=Val{:central},
795+
standardtag = Val{true}(),
785796
linsolve=DEFAULT_LINSOLVE,nlsolve=NLNewton(),
786797
smooth_est=true,extrapolant=:min_correct,
787798
new_jac_conv_bound = 1e-3,controller = :Predictive,
788799
ode_error_est = true) =
789800
SKenCarp{chunk_size,autodiff,typeof(linsolve),diff_type,
801+
OrdinaryDiffEq._unwrap_val(standardtag),
790802
typeof(nlsolve),typeof(new_jac_conv_bound),controller}(
791803
linsolve,nlsolve,smooth_est,extrapolant,new_jac_conv_bound,
792804
ode_error_est)

src/perform_step/low_order.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ end
179179
utilde = uprev + L*integrator.sqdt
180180
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
181181
mil_correction = ggprime.*(W.dW.^2)./2
182+
else
183+
error("Alg interpretation invalid. Use either :Ito or :Stratonovich")
182184
end
183185
u = K+L.*W.dW+mil_correction
184186
if integrator.opts.adaptive
@@ -207,6 +209,8 @@ end
207209
@.. tmp = (du2-L)/(2integrator.sqdt)*(W.dW^2 - dt)
208210
elseif alg_interpretation(integrator.alg) == :Stratonovich
209211
@.. tmp = (du2-L)/(2integrator.sqdt)*(W.dW^2)
212+
else
213+
error("Alg interpretation invalid. Use either :Ito or :Stratonovich")
210214
end
211215
@.. u = K+L*W.dW + tmp
212216
if integrator.opts.adaptive
@@ -232,6 +236,8 @@ end
232236
@.. tmp = uprev + integrator.sqdt * L
233237
integrator.g(du2,tmp,p,t)
234238
@.. tmp = (du2-L)/(2integrator.sqdt)*(W.dW.^2)
239+
else
240+
error("Alg interpretation invalid. Use either :Ito or :Stratonovich")
235241
end
236242
@.. u = K+L*W.dW + tmp
237243
if integrator.opts.adaptive

src/perform_step/sdirk.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
3838
mil_correction = ggprime.*(integrator.W.dW.^2)./2
3939
gtmp += mil_correction
40+
else
41+
error("Alg interpretation invalid. Use either :Ito or :Stratonovich")
4042
end
4143
end
4244

@@ -162,6 +164,8 @@ end
162164
integrator.g(gtmp3,z,p,t)
163165
@.. gtmp3 = (gtmp3-gtmp)/(integrator.sqdt) # ggprime approximation
164166
@.. gtmp2 += gtmp3*(dW.^2)/2
167+
else
168+
error("Alg interpretation invalid. Use either :Ito or :Stratonovich")
165169
end
166170
end
167171

0 commit comments

Comments
 (0)