Skip to content

Commit cb669ab

Browse files
committed
push in the right direction, still some issues with adaptivity
1 parent b4447c0 commit cb669ab

File tree

2 files changed

+60
-48
lines changed

2 files changed

+60
-48
lines changed

src/caches/implicit_split_step_caches.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function alg_cache(alg::ISSEulerHeun, prob, u, ΔW, ΔZ, p, rate_prototype, nois
6464
gtmp2 = zero(rate_prototype)
6565

6666
if is_diagonal_noise(prob)
67-
gtmp3 = gtmp2
67+
gtmp3 = copy(gtmp2)
6868
dW_cache = nothing
6969
else
7070
gtmp3 = zero(noise_rate_prototype)

src/perform_step/implicit_split_step.jl

Lines changed: 59 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
1-
@muladd function perform_step!(integrator,
2-
cache::Union{ISSEMConstantCache,
3-
ISSEulerHeunConstantCache})
4-
@unpack t,dt,uprev,u,p,f = integrator
1+
@muladd function perform_step!(integrator, cache::Union{ISSEMConstantCache, ISSEulerHeunConstantCache})
2+
3+
@unpack t, dt, uprev, u, p, f = integrator
54
@unpack nlsolver = cache
65
alg = unwrap_alg(integrator, true)
76
theta = alg.theta
8-
alg.symplectic ? a = dt/2 : a = theta*dt
7+
alg.symplectic ? a = dt / 2 : a = theta * dt
98
OrdinaryDiffEq.markfirststage!(nlsolver)
109

1110
# TODO: Stochastic extrapolants?
1211
u = uprev
1312

1413
repeat_step = false
1514

16-
L = integrator.g(uprev,p,t)
17-
ftmp = integrator.f(uprev,p,t)
15+
L = integrator.g(uprev, p, t)
16+
ftmp = integrator.f(uprev, p, t)
1817

1918
if alg.symplectic
2019
z = zero(u) # constant extrapolation, justified by ODE IM
2120
else
22-
z = dt*ftmp # linear extrapolation
21+
z = dt * ftmp # linear extrapolation
2322
end
2423
nlsolver.z = z
2524

@@ -29,7 +28,7 @@
2928
#u = uprev + z/2
3029
tmp = uprev
3130
else
32-
tmp = uprev + dt*(1-theta)*ftmp
31+
tmp = uprev + dt * (1 - theta) * ftmp
3332
end
3433
nlsolver.tmp = tmp
3534

@@ -39,56 +38,71 @@
3938
if alg.symplectic
4039
u = tmp + z
4140
else
42-
u = tmp + theta*z
41+
u = tmp + theta * z
4342
end
4443

45-
gtmp = L.*integrator.W.dW
44+
if !is_diagonal_noise(integrator.sol.prob)
45+
gtmp = L * integrator.W.dW
46+
else
47+
gtmp = L .* integrator.W.dW
48+
end
4649

4750
if typeof(cache) <: ISSEulerHeunConstantCache
4851
utilde = u + gtmp
49-
gtmp = ((integrator.g(utilde,p,t) + L)/2)*integrator.W.dW
52+
if !is_diagonal_noise(integrator.sol.prob)
53+
gtmp = ((integrator.g(utilde, p, t) + L) / 2) * integrator.W.dW
54+
else
55+
gtmp = ((integrator.g(utilde, p, t) + L) / 2) .* integrator.W.dW
56+
end
5057
end
5158

5259
u += gtmp
5360

5461
if integrator.opts.adaptive
5562
if has_Wfact(f)
5663
# This means the Jacobian was never computed!
57-
J = f.jac(uprev,p,t)
64+
J = f.jac(uprev, p, t)
5865
else
5966
J = OrdinaryDiffEq.calc_J(integrator, nlsolver.cache)
6067
end
61-
Ed = dt*(dt*J*ftmp)/2
68+
Ed = dt * (dt * J * ftmp) / 2
6269

6370
if typeof(cache) <: ISSEMConstantCache
6471
K = @.. uprev + dt * ftmp
65-
utilde = @.. K + integrator.sqdt * L
66-
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
67-
En = ggprime .* (integrator.W.dW.^2 .- dt)./2
72+
utilde = @.. K + integrator.sqdt * L
73+
ggprime = (integrator.g(utilde, p, t) .- L) ./ (integrator.sqdt)
74+
if !is_diagonal_noise(integrator.sol.prob)
75+
En = ggprime * (integrator.W.dW .^ 2 .- dt) ./ 2
76+
else
77+
En = ggprime .* (integrator.W.dW .^ 2 .- dt) ./ 2
78+
end
6879
elseif typeof(cache) <: ISSEulerHeunConstantCache
69-
utilde = uprev + L*integrator.sqdt
70-
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
71-
En = ggprime.*(integrator.W.dW.^2)./2
80+
utilde = @.. uprev + L * integrator.sqdt
81+
ggprime = (integrator.g(utilde, p, t) .- L) ./ (integrator.sqdt)
82+
if !is_diagonal_noise(integrator.sol.prob)
83+
En = ggprime * (integrator.W.dW .^ 2) ./ 2
84+
else
85+
En = ggprime .* (integrator.W.dW .^ 2) ./ 2
86+
end
7287
end
7388

7489
resids = calculate_residuals(Ed, En, uprev, u, integrator.opts.abstol,
75-
integrator.opts.reltol, integrator.opts.delta,
76-
integrator.opts.internalnorm,t)
77-
integrator.EEst = integrator.opts.internalnorm(resids,t)
90+
integrator.opts.reltol, integrator.opts.delta,
91+
integrator.opts.internalnorm, t)
92+
integrator.EEst = integrator.opts.internalnorm(resids, t)
7893
end
7994

8095
integrator.u = u
8196
end
8297

83-
@muladd function perform_step!(integrator, cache::Union{ISSEMCache,
84-
ISSEulerHeunCache})
85-
@unpack t,dt,uprev,u,p,f = integrator
86-
@unpack gtmp,gtmp2,dW_cache,nlsolver,k,dz = cache
87-
@unpack z,tmp = nlsolver
98+
@muladd function perform_step!(integrator, cache::Union{ISSEMCache, ISSEulerHeunCache})
99+
@unpack t, dt, uprev, u, p, f = integrator
100+
@unpack gtmp, gtmp2, dW_cache, nlsolver, k, dz = cache
101+
@unpack z, tmp = nlsolver
88102

89103
J = (OrdinaryDiffEq.isnewton(nlsolver) ? nlsolver.cache.J : nothing)
90104
alg = unwrap_alg(integrator, true)
91-
alg.symplectic ? a = dt/2 : a = alg.theta*dt
105+
alg.symplectic ? a = dt / 2 : a = alg.theta * dt
92106
dW = integrator.W.dW
93107
mass_matrix = integrator.f.mass_matrix
94108
theta = alg.theta
@@ -97,20 +111,28 @@ end
97111
repeat_step = false
98112

99113
if integrator.success_iter > 0 && !integrator.u_modified && alg.extrapolant == :interpolant
100-
current_extrapolant!(u,t+dt,integrator)
114+
current_extrapolant!(u, t + dt, integrator)
101115
elseif alg.extrapolant == :linear
102-
@.. u = uprev + integrator.fsalfirst*dt
116+
@.. u = uprev + integrator.fsalfirst * dt
103117
else # :constant
104-
copyto!(u,uprev)
118+
copyto!(u, uprev)
105119
end
106120

107-
integrator.f(tmp,uprev,p,t)
121+
integrator.f(tmp, uprev, p, t)
108122
integrator.g(gtmp, uprev, p, t)
109123

110124
if alg.symplectic
111125
@.. z = zero(eltype(u)) # Justified by ODE solvers, constrant extrapolation when IM
112126
else
113-
@.. z = dt*tmp # linear extrapolation
127+
@.. z = dt * tmp # linear extrapolation
128+
end
129+
130+
# Handle noise computations
131+
132+
if is_diagonal_noise(integrator.sol.prob)
133+
@.. gtmp2 = gtmp * dW
134+
else
135+
mul!(gtmp2, gtmp, dW)
114136
end
115137

116138
###
@@ -144,8 +166,8 @@ end
144166
if typeof(cache) <: ISSEMCache
145167

146168
if !is_diagonal_noise(integrator.sol.prob)
147-
integrator.g(gtmp2, z, p, t)
148-
g_sized2 = norm(gtmp2, 2)
169+
integrator.g(gtmp, z, p, t)
170+
g_sized2 = norm(gtmp, 2)
149171
@.. dW_cache = dW .^ 2 - dt
150172
diff_tmp = integrator.opts.internalnorm(dW_cache, t)
151173
En = (g_sized2 - g_sized) / (2integrator.sqdt) * diff_tmp
@@ -191,18 +213,9 @@ end
191213
end
192214

193215
##############################################################################
194-
195-
# Handle noise computations
196-
197-
if is_diagonal_noise(integrator.sol.prob)
198-
@.. gtmp2 = gtmp * dW
199-
else
200-
mul!(gtmp2, gtmp, dW)
201-
end
202-
203216
if typeof(cache) <: ISSEulerHeunCache
204217
gtmp3 = cache.gtmp3
205-
@.. z = uprev + gtmp2
218+
@.. z = u + gtmp2
206219
integrator.g(gtmp3, z, p, t)
207220
@.. gtmp = (gtmp3 + gtmp) / 2
208221
if is_diagonal_noise(integrator.sol.prob)
@@ -211,7 +224,6 @@ end
211224
mul!(gtmp2, gtmp, dW)
212225
end
213226
end
214-
215227
@.. u += gtmp2
216228

217229
##############################################################################

0 commit comments

Comments
 (0)