Skip to content

Commit 32ceae8

Browse files
committed
first steps new rode solver
1 parent 2fa86b9 commit 32ceae8

File tree

7 files changed

+173
-1
lines changed

7 files changed

+173
-1
lines changed

src/StochasticDiffEq.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ using DocStringExtensions
166166
StochasticDiffEqRODECompositeAlgorithm
167167

168168
export RandomEM
169+
export RandomHeun
169170

170171
export IteratedIntegralApprox, IICommutative, IILevyArea
171172

src/algorithms.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,8 @@ end
848848

849849
struct RandomEM <: StochasticDiffEqRODEAlgorithm end
850850

851+
struct RandomHeun <: StochasticDiffEqRODEAlgorithm end
852+
851853
const SplitSDEAlgorithms = Union{IIF1M,IIF2M,IIF1Mil,SKenCarp,SplitEM}
852854

853855
@doc raw"""

src/caches/basic_method_caches.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,22 @@ function alg_cache(alg::RandomEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prot
7272
RandomEMCache(u,uprev,tmp,rtmp)
7373
end
7474

75+
struct RandomHeunConstantCache <: StochasticDiffEqConstantCache end
76+
@cache struct RandomHeunCache{uType,rateType} <: StochasticDiffEqMutableCache
77+
u::uType
78+
uprev::uType
79+
tmp::uType
80+
rtmp1::rateType
81+
rtmp2::rateType
82+
end
83+
84+
alg_cache(alg::RandomHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} = RandomHeunConstantCache()
85+
86+
function alg_cache(alg::RandomHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
87+
tmp = zero(u); rtmp1 = zero(rate_prototype); rtmp2 = zero(rate_prototype)
88+
RandomHeunCache(u,uprev,tmp,rtmp1,rtmp2)
89+
end
90+
7591
struct SimplifiedEMConstantCache <: StochasticDiffEqConstantCache end
7692
@cache struct SimplifiedEMCache{randType,uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
7793
u::uType

src/perform_step/low_order.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,21 @@ end
121121
@.. u = uprev + dt * rtmp
122122
end
123123

124+
@muladd function perform_step!(integrator,cache::RandomHeunConstantCache)
125+
@unpack t,dt,uprev,u,W,p,f = integrator
126+
ftmp = integrator.f(uprev,p,t,W.curW)
127+
tmp = @.. uprev + dt * ftmp
128+
u = uprev .+ (dt / 2) .* (ftmp .+ integrator.f(tmp,p,t+dt, W.curW .+ W.dW)) # Need to check the last argument, if it should also be in terms of an intermediate step
129+
integrator.u = u
130+
end
131+
132+
@muladd function perform_step!(integrator,cache::RandomHeunCache)
133+
@unpack rtmp1, rtmp2 = cache
134+
@unpack t,dt,uprev,u,W,p,f = integrator
135+
integrator.f(rtmp1,uprev,p,t,W.curW)
136+
@.. u = uprev + dt * rtmp1
137+
end
138+
124139
# weak approximation EM
125140
@muladd function perform_step!(integrator,cache::SimplifiedEMConstantCache)
126141
@unpack t,dt,uprev,u,W,p,f = integrator

test/rode_linear_tests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@ u0 = 1.00
55
tspan = (0.0,1.0)
66
prob = RODEProblem(f,u0,tspan)
77
sol = solve(prob,RandomEM(),dt=1/100)
8+
sol2 = solve(prob,RandomHeun(),dt=1/100)
89

910
f(u,p,t,W,du) = (du.=1.01u.+0.87u.*W)
1011
u0 = ones(4)
1112
prob = RODEProblem(f,u0,tspan)
1213
sol = solve(prob,RandomEM(),dt=1/100)
14+
sol2 = solve(prob,RandomHeun(),dt=1/100)
1315

1416
f(u,p,t,W) = 2u*sin(W)
1517
u0 = 1.00
1618
tspan = (0.0,5.0)
1719
prob = RODEProblem{false}(f,u0,tspan)
1820
sol = solve(prob,RandomEM(),dt=1/100)
21+
sol2 = solve(prob,RandomHeun(),dt=1/100)
1922

2023
function f(du,u,p,t,W)
2124
du[1] = 2u[1]*sin(W[1] - W[2])
@@ -25,6 +28,7 @@ u0 = [1.00;1.00]
2528
tspan = (0.0,5.0)
2629
prob = RODEProblem(f,u0,tspan)
2730
sol = solve(prob,RandomEM(),dt=1/100)
31+
sol2 = solve(prob,RandomHeun(),dt=1/100)
2832

2933
function f(du,u,p,t,W)
3034
du[1] = -2W[3]*u[1]*sin(W[1] - W[2])
@@ -34,3 +38,4 @@ u0 = [1.00;1.00]
3438
tspan = (0.0,5.0)
3539
prob = RODEProblem(f,u0,tspan,rand_prototype=zeros(3))
3640
sol = solve(prob,RandomEM(),dt=1/100)
41+
sol2 = solve(prob,RandomHeun(),dt=1/100)

test/runtests copy.jl

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
using SafeTestsets
2+
using Pkg
3+
4+
function activate_gpu_env()
5+
Pkg.activate("gpu")
6+
Pkg.develop(PackageSpec(path=dirname(@__DIR__)))
7+
Pkg.instantiate()
8+
end
9+
10+
const LONGER_TESTS = false
11+
12+
const GROUP = get(ENV, "GROUP", "All")
13+
14+
const is_APPVEYOR = Sys.iswindows() && haskey(ENV,"APPVEYOR")
15+
16+
@time begin
17+
if GROUP == "All" || GROUP == "Interface"
18+
@time @safetestset "First Rand Tests" begin include("first_rand_test.jl") end
19+
@time @safetestset "Inference Tests" begin include("inference_test.jl") end
20+
@time @safetestset "Linear RODE Tests" begin include("rode_linear_tests.jl") end
21+
@time @safetestset "Complex Number Tests" begin include("complex_tests.jl") end
22+
@time @safetestset "Static Array Tests" begin include("static_array_tests.jl") end
23+
@time @safetestset "Noise Type Tests" begin include("noise_type_test.jl") end
24+
@time @safetestset "Mass matrix tests" begin include("mass_matrix_tests.jl") end
25+
#@time @safetestset "Sparse Jacobian tests" begin include("sparsediff_tests.jl") end
26+
@time @safetestset "Outofplace Arrays Tests" begin include("outofplace_arrays.jl") end
27+
@time @safetestset "tdir Tests" begin include("tdir_tests.jl") end
28+
@time @safetestset "tstops Tests" begin include("tstops_tests.jl") end
29+
@time @safetestset "saveat Tests" begin include("saveat_tests.jl") end
30+
@time @safetestset "Oval2" begin include("oval2_test.jl") end
31+
end
32+
33+
if GROUP == "All" || GROUP == "Interface2"
34+
@time @safetestset "Basic Tau Leaping Tests" begin include("tau_leaping.jl") end
35+
@time @safetestset "Linear SDE Tests" begin include("sde/sde_linear_tests.jl") end
36+
@time @safetestset "Two-dimensional Linear SDE Tests" begin include("sde/sde_twodimlinear_tests.jl") end
37+
@time @safetestset "Element-wise Tolerances Tests" begin include("tolerances_tests.jl") end
38+
@time @safetestset "Zero'd Noise Tests" begin include("zerod_noise_test.jl") end
39+
@time @safetestset "Scalar Tests" begin include("scalar_noise.jl") end
40+
@time @safetestset "Stiffness Detection Test" begin include("stiffness_detection_test.jl") end
41+
@time @safetestset "Adaptive SDE Linear Tests" begin include("adaptive/sde_linearadaptive_tests.jl") end
42+
end
43+
44+
if GROUP == "All" || GROUP == "Interface3"
45+
@time @safetestset "Composite Tests" begin include("composite_algorithm_test.jl") end
46+
@time @safetestset "Events Tests" begin include("events_test.jl") end
47+
@time @safetestset "Cache Tests" begin include("cache_test.jl") end
48+
@time @safetestset "Adaptive Complex Mean Test" begin include("adaptive/sde_complex_adaptive_mean_test.jl") end
49+
@time @safetestset "Utility Tests" begin include("utility_tests.jl") end
50+
@time @safetestset "Non-diagonal SDE Tests" begin include("nondiagonal_tests.jl") end
51+
@time @safetestset "No Index Tests" begin include("noindex_tests.jl") end
52+
@time @safetestset "Multiple Dimension Linear Adaptive Test" begin include("adaptive/sde_twodimlinearadaptive_tests.jl") end
53+
@time @safetestset "Autostepsize Test" begin include("adaptive/sde_autostepsize_test.jl") end
54+
@time @safetestset "Additive Lorenz Attractor Test" begin include("adaptive/sde_lorenzattractor_tests.jl") end
55+
@time @safetestset "Stochastic iterated integrals" begin include("levy_areas.jl") end
56+
end
57+
58+
if !is_APPVEYOR && (GROUP == "All" || GROUP == "AlgConvergence")
59+
@time @safetestset "Convergence Tests" begin include("sde/sde_convergence_tests.jl") end
60+
@time @safetestset "Dynamical SDE Tests" begin include("sde/sde_dynamical.jl") end
61+
end
62+
63+
if !is_APPVEYOR && GROUP == "AlgConvergence2"
64+
@time @safetestset "IIF Convergence Tests" begin include("iif_methods.jl") end
65+
@time @safetestset "Commutative Noise Methods Tests" begin include("commutative_tests.jl") end
66+
@time @safetestset "Multivariate Geometric Tests" begin include("multivariate_geometric.jl") end
67+
end
68+
69+
if !is_APPVEYOR && GROUP == "AlgConvergence3"
70+
@time @safetestset "Rossler Order Tests" begin include("sde/sde_rosslerorder_tests.jl") end
71+
@time @safetestset "ODE Convergence Regression Tests" begin include("ode_convergence_regression.jl") end
72+
@time @safetestset "Additive SDE Tests" begin include("sde/sde_additive_tests.jl") end
73+
@time @safetestset "Split Tests" begin include("split_tests.jl") end
74+
@time @safetestset "Stratonovich Convergence Tests" begin include("stratonovich_convergence_tests.jl") end
75+
end
76+
77+
if !is_APPVEYOR && GROUP == "WeakConvergence1"
78+
@time @safetestset "Multidimensional IIP Weak Convergence Tests" begin include("weak_convergence/multidim_iip_weak.jl") end
79+
@time @safetestset "Platen's PL1WM weak second order" begin include("weak_convergence/PL1WM.jl") end
80+
end
81+
82+
if !is_APPVEYOR && GROUP == "WeakConvergence2"
83+
@time @safetestset "Roessler weak SRK Tests" begin include("weak_convergence/srk_weak_final.jl") end
84+
end
85+
86+
if !is_APPVEYOR && GROUP == "WeakConvergence3"
87+
@time @safetestset "Roessler weak SRK (non-diagonal) Tests" begin include("weak_convergence/srk_weak_final_non_diagonal.jl") end
88+
end
89+
90+
if !is_APPVEYOR && GROUP == "WeakConvergence4"
91+
@time @safetestset "Weak Stratonovich (non-diagonal) Tests" begin include("weak_convergence/weak_strat_non_diagonal.jl") end
92+
@time @safetestset "SIE SME weak Tests" begin include("weak_convergence/SIE_SME.jl") end
93+
end
94+
95+
if !is_APPVEYOR && GROUP == "WeakConvergence5"
96+
@time @safetestset "Weak Stratonovich Tests" begin include("weak_convergence/weak_strat.jl") end
97+
end
98+
99+
if !is_APPVEYOR && GROUP == "WeakConvergence6"
100+
@time @safetestset "Roessler weak SRK diagonal Tests" begin include("weak_convergence/srk_weak_diagonal_final.jl") end
101+
end
102+
103+
if !is_APPVEYOR && GROUP == "OOPWeakConvergence"
104+
@time @safetestset "OOP Weak Convergence Tests" begin include("weak_convergence/oop_weak.jl") end
105+
@time @safetestset "Additive Weak Convergence Tests" begin include("weak_convergence/additive_weak.jl") end
106+
end
107+
108+
if !is_APPVEYOR && GROUP == "IIPWeakConvergence"
109+
#activate_gpu_env()
110+
@time @safetestset "IIP Weak Convergence Tests" begin include("weak_convergence/iip_weak.jl") end
111+
end
112+
113+
if !is_APPVEYOR && GROUP == "SROCKC2WeakConvergence"
114+
#activate_gpu_env()
115+
@time @safetestset "SROCKC2 Weak Convergence Tests" begin include("weak_convergence/weak_srockc2.jl") end
116+
end
117+
118+
if !is_APPVEYOR && GROUP == "WeakAdaptiveCPU"
119+
@time @safetestset "CPU Weak adaptive step size Brusselator " begin include("adaptive/sde_weak_brusselator_adaptive.jl") end
120+
@time @safetestset "CPU Weak adaptive" begin include("adaptive/sde_weak_adaptive.jl") end
121+
end
122+
123+
if !is_APPVEYOR && GROUP == "WeakAdaptiveGPU"
124+
activate_gpu_env()
125+
@time @safetestset "GPU Weak adaptive step size scalar noise SDE" begin include("gpu/sde_weak_scalar_adaptive_gpu.jl") end
126+
@time @safetestset "GPU Weak adaptive" begin include("gpu/sde_weak_adaptive_gpu.jl") end
127+
end
128+
129+
if !is_APPVEYOR && GROUP == "Multithreaded"
130+
@time @safetestset "Mulithreaded Jump Thread Safety Tests" begin include("multithreaded_jump_test.jl") end
131+
end
132+
133+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ end
99

1010
const LONGER_TESTS = false
1111

12-
const GROUP = get(ENV, "GROUP", "All")
12+
const GROUP = "Interface" # get(ENV, "GROUP", "All")
1313

1414
const is_APPVEYOR = Sys.iswindows() && haskey(ENV,"APPVEYOR")
1515

0 commit comments

Comments
 (0)