Skip to content

Commit fe0e952

Browse files
committed
adaptive stepping ISSEM fixes
1 parent b9806da commit fe0e952

File tree

2 files changed

+113
-99
lines changed

2 files changed

+113
-99
lines changed

src/caches/implicit_split_step_caches.jl

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,38 @@
66
gtmp2::rateType
77
nlsolver::N
88
dW_cache::randType
9+
k::uType
10+
dz::uType
911
end
1012

11-
function alg_cache(alg::ISSEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,
12-
::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
13-
γ, c = alg.theta,zero(t)
14-
nlsolver = OrdinaryDiffEq.build_nlsolver(alg,u,uprev,p,t,dt,f,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,γ,c,Val(true))
13+
function alg_cache(alg::ISSEM, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype,
14+
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
15+
γ, c = alg.theta, zero(t)
16+
nlsolver = OrdinaryDiffEq.build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true))
1517
fsalfirst = zero(rate_prototype)
1618
gtmp = zero(noise_rate_prototype)
1719
if is_diagonal_noise(prob)
18-
gtmp2 = gtmp
20+
gtmp2 = copy(gtmp)
1921
dW_cache = nothing
2022
else
2123
gtmp2 = zero(rate_prototype)
2224
dW_cache = zero(ΔW)
2325
end
2426

25-
ISSEMCache(u,uprev,fsalfirst,gtmp,gtmp2,nlsolver,dW_cache)
27+
k = zero(u)
28+
dz = zero(u)
29+
30+
ISSEMCache(u, uprev, fsalfirst, gtmp, gtmp2, nlsolver, dW_cache, k, dz)
2631
end
2732

2833
mutable struct ISSEMConstantCache{N} <: StochasticDiffEqConstantCache
2934
nlsolver::N
3035
end
3136

32-
function alg_cache(alg::ISSEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,
33-
::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
34-
γ, c = alg.theta,zero(t)
35-
nlsolver = OrdinaryDiffEq.build_nlsolver(alg,u,uprev,p,t,dt,f,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,γ,c,Val(false))
37+
function alg_cache(alg::ISSEM, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype,
38+
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
39+
γ, c = alg.theta, zero(t)
40+
nlsolver = OrdinaryDiffEq.build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false))
3641
ISSEMConstantCache(nlsolver)
3742
end
3843

@@ -45,35 +50,40 @@ end
4550
gtmp3::noiseRateType
4651
nlsolver::N
4752
dW_cache::randType
53+
k::uType
54+
dz::uType
4855
end
4956

50-
function alg_cache(alg::ISSEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,
51-
::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
52-
γ, c = alg.theta,zero(t)
53-
nlsolver = OrdinaryDiffEq.build_nlsolver(alg,u,uprev,p,t,dt,f,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,γ,c,Val(true))
57+
function alg_cache(alg::ISSEulerHeun, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype,
58+
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
59+
γ, c = alg.theta, zero(t)
60+
nlsolver = OrdinaryDiffEq.build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true))
5461
fsalfirst = zero(rate_prototype)
5562

5663
gtmp = zero(noise_rate_prototype)
5764
gtmp2 = zero(rate_prototype)
5865

5966
if is_diagonal_noise(prob)
60-
gtmp3 = gtmp2
61-
dW_cache = nothing
67+
gtmp3 = gtmp2
68+
dW_cache = nothing
6269
else
63-
gtmp3 = zero(noise_rate_prototype)
64-
dW_cache = zero(ΔW)
70+
gtmp3 = zero(noise_rate_prototype)
71+
dW_cache = zero(ΔW)
6572
end
6673

67-
ISSEulerHeunCache(u,uprev,fsalfirst,gtmp,gtmp2,gtmp3,nlsolver,dW_cache)
74+
k = zero(u)
75+
dz = zero(u)
76+
77+
ISSEulerHeunCache(u, uprev, fsalfirst, gtmp, gtmp2, gtmp3, nlsolver, dW_cache, k, dz)
6878
end
6979

7080
mutable struct ISSEulerHeunConstantCache{N} <: StochasticDiffEqConstantCache
7181
nlsolver::N
7282
end
7383

74-
function alg_cache(alg::ISSEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,
75-
::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
76-
γ, c = alg.theta,zero(t)
77-
nlsolver = OrdinaryDiffEq.build_nlsolver(alg,u,uprev,p,t,dt,f,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,γ,c,Val(false))
84+
function alg_cache(alg::ISSEulerHeun, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype,
85+
::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
86+
γ, c = alg.theta, zero(t)
87+
nlsolver = OrdinaryDiffEq.build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false))
7888
ISSEulerHeunConstantCache(nlsolver)
7989
end

src/perform_step/implicit_split_step.jl

Lines changed: 80 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,16 @@
5252
u += gtmp
5353

5454
if integrator.opts.adaptive
55-
56-
if !OrdinaryDiffEq.isnewton(nlsolver)
57-
is_compos = isa(integrator.alg, StochasticDiffEqCompositeAlgorithm)
55+
if has_Wfact(f)
56+
# This means the Jacobian was never computed!
57+
J = f.jac(uprev,p,t)
58+
else
5859
J = OrdinaryDiffEq.calc_J(integrator, nlsolver.cache)
5960
end
60-
Ed = dt*(J*ftmp)/2
61+
du2 = integrator.f(uprev + dt * ftmp, p, t + dt)
62+
Ed = dt*(dt*J*ftmp)/2
6163

62-
if typeof(cache) <: SplitStepEulerConstantCache
64+
if typeof(cache) <: ISSEMConstantCache
6365
K = @.. uprev + dt * ftmp
6466
utilde = @.. K + integrator.sqdt * L
6567
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
@@ -82,9 +84,9 @@ end
8284
@muladd function perform_step!(integrator, cache::Union{ISSEMCache,
8385
ISSEulerHeunCache})
8486
@unpack t,dt,uprev,u,p,f = integrator
85-
@unpack gtmp,gtmp2,dW_cache,nlsolver = cache
87+
@unpack gtmp,gtmp2,dW_cache,nlsolver,k,dz = cache
8688
@unpack z,tmp = nlsolver
87-
@unpack k,dz = nlsolver.cache # alias to reduce memory
89+
8890
J = (OrdinaryDiffEq.isnewton(nlsolver) ? nlsolver.cache.J : nothing)
8991
alg = unwrap_alg(integrator, true)
9092
alg.symplectic ? a = dt/2 : a = alg.theta*dt
@@ -104,19 +106,79 @@ end
104106
end
105107

106108
integrator.f(tmp,uprev,p,t)
109+
integrator.g(gtmp, uprev, p, t)
107110

108111
if alg.symplectic
109112
@.. z = zero(eltype(u)) # Justified by ODE solvers, constrant extrapolation when IM
110113
else
111114
@.. z = dt*tmp # linear extrapolation
112115
end
113116

117+
###
118+
# adaptivity part
119+
if integrator.opts.adaptive
120+
121+
if has_Wfact(f)
122+
# This means the Jacobian was never computed!
123+
f.jac(J, uprev, p, t)
124+
else
125+
OrdinaryDiffEq.calc_J!(J, integrator, nlsolver.cache)
126+
end
127+
128+
mul!(vec(z), J, vec(tmp))
129+
@.. k = dt * dt * z / 2
130+
# k is Ed
131+
# dz is En
132+
133+
if !is_diagonal_noise(integrator.sol.prob)
134+
g_sized = norm(gtmp, 2)
135+
else
136+
g_sized = gtmp
137+
end
138+
# z is utilde above
139+
if typeof(cache) <: ISSEMCache
140+
@.. z = uprev + dt * tmp + integrator.sqdt * g_sized
141+
elseif typeof(cache) <: ISSEulerHeunCache
142+
@.. z = uprev + integrator.sqdt * g_sized
143+
end
144+
145+
if typeof(cache) <: ISSEMCache
146+
147+
if !is_diagonal_noise(integrator.sol.prob)
148+
integrator.g(gtmp2, z, p, t)
149+
g_sized2 = norm(gtmp2, 2)
150+
@.. dW_cache = dW .^ 2 - dt
151+
diff_tmp = integrator.opts.internalnorm(dW_cache, t)
152+
En = (g_sized2 - g_sized) / (2integrator.sqdt) * diff_tmp
153+
@.. dz = En
154+
else
155+
integrator.g(gtmp2, z, p, t)
156+
g_sized2 = gtmp2
157+
@.. dz = (g_sized2 - g_sized) / (2integrator.sqdt) * (dW .^ 2 - dt)
158+
end
159+
elseif typeof(cache) <: ISSEulerHeunCache
160+
if !is_diagonal_noise(integrator.sol.prob)
161+
integrator.g(gtmp, z, p, t)
162+
g_sized2 = norm(gtmp, 2)
163+
@.. dW_cache = dW .^ 2
164+
diff_tmp = integrator.opts.internalnorm(dW_cache, t)
165+
En = (g_sized2 - g_sized) / (2integrator.sqdt) * diff_tmp
166+
@.. dz = En
167+
else
168+
integrator.g(gtmp2, z, p, t)
169+
g_sized2 = gtmp2
170+
@.. dz = (g_sized2 - g_sized) / (2integrator.sqdt) * (dW .^ 2)
171+
end
172+
end
173+
end
174+
###
175+
114176
if alg.symplectic
115177
#@.. u = uprev + z/2
116178
@.. tmp = uprev
117179
else
118180
#@.. u = uprev + dt*(1-theta)*tmp + theta*z
119-
@.. tmp = uprev + dt*(1-theta)*tmp
181+
@.. tmp = uprev + dt * (1 - theta) * tmp
120182
end
121183
nlsolver.c = a
122184
z = OrdinaryDiffEq.nlsolve!(nlsolver, integrator, cache, repeat_step)
@@ -126,96 +188,38 @@ end
126188
@.. u = uprev + z
127189
else
128190
#@.. u = uprev + dt*(1-theta)*tmp + theta*z
129-
@.. u = tmp + theta*z
191+
@.. u = tmp + theta * z
130192
end
131193

132194
##############################################################################
133195

134196
# Handle noise computations
135197

136-
integrator.g(gtmp,uprev,p,t)
137-
138-
139198
if is_diagonal_noise(integrator.sol.prob)
140-
@.. gtmp2 = gtmp*dW
199+
@.. gtmp2 = gtmp * dW
141200
else
142-
mul!(gtmp2,gtmp,dW)
201+
mul!(gtmp2, gtmp, dW)
143202
end
144203

145204
if typeof(cache) <: ISSEulerHeunCache
146205
gtmp3 = cache.gtmp3
147206
@.. z = uprev + gtmp2
148-
integrator.g(gtmp3,z,p,t)
149-
@.. gtmp = (gtmp3 + gtmp)/2
207+
integrator.g(gtmp3, z, p, t)
208+
@.. gtmp = (gtmp3 + gtmp) / 2
150209
if is_diagonal_noise(integrator.sol.prob)
151-
@.. gtmp2 = gtmp*dW
210+
@.. gtmp2 = gtmp * dW
152211
else
153-
mul!(gtmp2,gtmp,dW)
212+
mul!(gtmp2, gtmp, dW)
154213
end
155214
end
156215

157216
@.. u += gtmp2
158217

159218
##############################################################################
160-
161219
if integrator.opts.adaptive
162-
163-
if has_Wfact(f)
164-
# This means the Jacobian was never computed!
165-
f.jac(J,uprev,p,t)
166-
end
167-
168-
mul!(vec(z),J,vec(tmp))
169-
@.. k = dt*dt*z/2
170-
171-
# k is Ed
172-
# dz is En
173-
174-
if !is_diagonal_noise(integrator.sol.prob)
175-
g_sized = norm(gtmp,2)
176-
else
177-
g_sized = gtmp
178-
end
179-
180-
if typeof(cache) <: ISSEMCache
181-
@.. z = uprev + dt*tmp + integrator.sqdt * g_sized
182-
183-
if !is_diagonal_noise(integrator.sol.prob)
184-
integrator.g(gtmp,z,p,t)
185-
g_sized2 = norm(gtmp,2)
186-
@.. dW_cache = dW.^2 - dt
187-
diff_tmp = integrator.opts.internalnorm(dW_cache,t)
188-
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
189-
@.. dz = En
190-
else
191-
integrator.g(gtmp2,z,p,t)
192-
g_sized2 = gtmp2
193-
@.. dz = (g_sized2-g_sized)/(2integrator.sqdt)*(dW.^2 - dt)
194-
end
195-
196-
elseif typeof(cache) <: ISSEulerHeunCache
197-
@.. z = uprev + integrator.sqdt * g_sized
198-
199-
if !is_diagonal_noise(integrator.sol.prob)
200-
integrator.g(gtmp,z,p,t)
201-
g_sized2 = norm(gtmp,2)
202-
@.. dW_cache = dW.^2
203-
diff_tmp = integrator.opts.internalnorm(dW_cache,t)
204-
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
205-
@.. dz = En
206-
else
207-
integrator.g(gtmp2,z,p,t)
208-
g_sized2 = gtmp2
209-
@.. dz = (g_sized2-g_sized)/(2integrator.sqdt)*(dW.^2)
210-
end
211-
212-
213-
end
214-
215220
calculate_residuals!(tmp, k, dz, uprev, u, integrator.opts.abstol,
216-
integrator.opts.reltol, integrator.opts.delta,
217-
integrator.opts.internalnorm,t)
218-
integrator.EEst = integrator.opts.internalnorm(tmp,t)
219-
221+
integrator.opts.reltol, integrator.opts.delta,
222+
integrator.opts.internalnorm, t)
223+
integrator.EEst = integrator.opts.internalnorm(tmp, t)
220224
end
221225
end

0 commit comments

Comments
 (0)