Skip to content

Commit 0a1ffe8

Browse files
Merge pull request #483 from SciML/ISSEM
adaptive stepping `ISSEM` fixes
2 parents b9806da + 1a45b42 commit 0a1ffe8

File tree

2 files changed

+178
-139
lines changed

2 files changed

+178
-139
lines changed

src/caches/implicit_split_step_caches.jl

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,38 @@
66
gtmp2::rateType
77
nlsolver::N
88
dW_cache::randType
9+
k::uType
10+
dz::uType
911
end
1012

11-
function alg_cache(alg::ISSEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,
12-
::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
13-
γ, c = alg.theta,zero(t)
14-
nlsolver = OrdinaryDiffEq.build_nlsolver(alg,u,uprev,p,t,dt,f,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,γ,c,Val(true))
13+
function alg_cache(alg::ISSEM, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype,
14+
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
15+
γ, c = alg.theta, zero(t)
16+
nlsolver = OrdinaryDiffEq.build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true))
1517
fsalfirst = zero(rate_prototype)
1618
gtmp = zero(noise_rate_prototype)
1719
if is_diagonal_noise(prob)
18-
gtmp2 = gtmp
20+
gtmp2 = copy(gtmp)
1921
dW_cache = nothing
2022
else
2123
gtmp2 = zero(rate_prototype)
2224
dW_cache = zero(ΔW)
2325
end
2426

25-
ISSEMCache(u,uprev,fsalfirst,gtmp,gtmp2,nlsolver,dW_cache)
27+
k = zero(u)
28+
dz = zero(u)
29+
30+
ISSEMCache(u, uprev, fsalfirst, gtmp, gtmp2, nlsolver, dW_cache, k, dz)
2631
end
2732

2833
mutable struct ISSEMConstantCache{N} <: StochasticDiffEqConstantCache
2934
nlsolver::N
3035
end
3136

32-
function alg_cache(alg::ISSEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,
33-
::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
34-
γ, c = alg.theta,zero(t)
35-
nlsolver = OrdinaryDiffEq.build_nlsolver(alg,u,uprev,p,t,dt,f,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,γ,c,Val(false))
37+
function alg_cache(alg::ISSEM, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype,
38+
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
39+
γ, c = alg.theta, zero(t)
40+
nlsolver = OrdinaryDiffEq.build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false))
3641
ISSEMConstantCache(nlsolver)
3742
end
3843

@@ -45,35 +50,40 @@ end
4550
gtmp3::noiseRateType
4651
nlsolver::N
4752
dW_cache::randType
53+
k::uType
54+
dz::uType
4855
end
4956

50-
function alg_cache(alg::ISSEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,
51-
::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
52-
γ, c = alg.theta,zero(t)
53-
nlsolver = OrdinaryDiffEq.build_nlsolver(alg,u,uprev,p,t,dt,f,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,γ,c,Val(true))
57+
function alg_cache(alg::ISSEulerHeun, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype,
58+
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
59+
γ, c = alg.theta, zero(t)
60+
nlsolver = OrdinaryDiffEq.build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true))
5461
fsalfirst = zero(rate_prototype)
5562

5663
gtmp = zero(noise_rate_prototype)
5764
gtmp2 = zero(rate_prototype)
5865

5966
if is_diagonal_noise(prob)
60-
gtmp3 = gtmp2
61-
dW_cache = nothing
67+
gtmp3 = copy(gtmp2)
68+
dW_cache = nothing
6269
else
63-
gtmp3 = zero(noise_rate_prototype)
64-
dW_cache = zero(ΔW)
70+
gtmp3 = zero(noise_rate_prototype)
71+
dW_cache = zero(ΔW)
6572
end
6673

67-
ISSEulerHeunCache(u,uprev,fsalfirst,gtmp,gtmp2,gtmp3,nlsolver,dW_cache)
74+
k = zero(u)
75+
dz = zero(u)
76+
77+
ISSEulerHeunCache(u, uprev, fsalfirst, gtmp, gtmp2, gtmp3, nlsolver, dW_cache, k, dz)
6878
end
6979

7080
mutable struct ISSEulerHeunConstantCache{N} <: StochasticDiffEqConstantCache
7181
nlsolver::N
7282
end
7383

74-
function alg_cache(alg::ISSEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,
75-
::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
76-
γ, c = alg.theta,zero(t)
77-
nlsolver = OrdinaryDiffEq.build_nlsolver(alg,u,uprev,p,t,dt,f,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,γ,c,Val(false))
84+
function alg_cache(alg::ISSEulerHeun, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype,
85+
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
86+
γ, c = alg.theta, zero(t)
87+
nlsolver = OrdinaryDiffEq.build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false))
7888
ISSEulerHeunConstantCache(nlsolver)
7989
end

0 commit comments

Comments
 (0)