Skip to content

Commit 5220126

Browse files
in-place fixed
1 parent 5acab16 commit 5220126

File tree

5 files changed

+125
-186
lines changed

5 files changed

+125
-186
lines changed

lib/OrdinaryDiffEqFIRK/src/firk_addsteps.jl

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ function _ode_addsteps!(integrator, cache::RadauIIA3Cache, repeat_step = false)
108108
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst = integrator
109109
@unpack T11, T12, T21, T22, TI11, TI12, TI21, TI22 = cache.tab
110110
@unpack c1, c2, α, β, e1, e2 = cache.tab
111-
@unpack κ, cont1, cont2 = cache
111+
@unpack κ = cache
112112
@unpack z1, z2, w1, w2,
113113
dw12, cubuff,
114114
k, k2, fw1, fw2,
@@ -127,14 +127,24 @@ function _ode_addsteps!(integrator, cache::RadauIIA3Cache, repeat_step = false)
127127
integrator.stats.nw += 1
128128
end
129129

130-
#better initial guess
131-
uzero = zero(eltype(z1))
132-
@. z1 = uzero
133-
@. z2 = uzero
134-
@. w1 = uzero
135-
@. w2 = uzero
136-
@. cache.cont1 = uzero
137-
@. cache.cont2 = uzero
130+
c1m1 = c1 - 1
131+
if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
132+
cache.dtprev = one(cache.dtprev)
133+
uzero = zero(eltype(u))
134+
@.. broadcast=false z1=uzero
135+
@.. broadcast=false z2=uzero
136+
@.. broadcast=false w1=uzero
137+
@.. broadcast=false w2=uzero
138+
@.. broadcast=false integrator.k[3]=uzero
139+
@.. broadcast=false integrator.k[4]=uzero
140+
else
141+
c2′ = dt / cache.dtprev
142+
c1′ = c1 * c2′
143+
@.. broadcast=false z1=c1′ * (k[3] + (c1′ - c1m1) * k[4])
144+
@.. broadcast=false z2=c2′ * (k[3] + (c2′ - c1m1) * k[4])
145+
@.. broadcast=false w1=TI11 * z1 + TI12 * z2
146+
@.. broadcast=false w2=TI21 * z1 + TI22 * z2
147+
end
138148

139149
# Newton iteration
140150
local ndw
@@ -233,6 +243,14 @@ function _ode_addsteps!(integrator, cache::RadauIIA3Cache, repeat_step = false)
233243
@. u = uprev + z2
234244
step_limiter!(u, integrator, p, t + dt)
235245

246+
if integrator.EEst <= oneunit(integrator.EEst)
247+
cache.dtprev = dt
248+
if alg.extrapolant != :constant
249+
integrator.k[3] = (z1 - z2) / c1m1
250+
integrator.k[4] = integrator.k[3] - (z1 / c1)
251+
end
252+
end
253+
236254
f(fsallast, u, p, t + dt)
237255
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
238256
return
@@ -384,7 +402,7 @@ function _ode_addsteps!(integrator, cache::RadauIIA5Cache, repeat_step = false)
384402
@unpack t, dt, uprev, u, f, p, fsallast, fsalfirst, k = integrator
385403
@unpack T11, T12, T13, T21, T22, T23, T31, TI11, TI12, TI13, TI21, TI22, TI23, TI31, TI32, TI33 = cache.tab
386404
@unpack c1, c2, γ, α, β, e1, e2, e3 = cache.tab
387-
@unpack κ, cont1, cont2, cont3 = cache
405+
@unpack κ = cache
388406
@unpack z1, z2, z3, w1, w2, w3,
389407
dw1, ubuff, dw23, cubuff,
390408
k, k2, k3, fw1, fw2, fw3,
@@ -396,8 +414,8 @@ alg = unwrap_alg(integrator, true)
396414
mass_matrix = integrator.f.mass_matrix
397415

398416
# precalculations
399-
c1m1 = (c1 - 1)*dt
400-
c2m1 = (c2 - 1)*dt
417+
c1m1 = c1 - 1
418+
c2m1 = c2 - 1
401419
c1mc2 = c1 - c2
402420
γdt, αdt, βdt = γ / dt, α / dt, β / dt
403421
if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt))
@@ -422,7 +440,7 @@ if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
422440
@.. broadcast=false integrator.k[4] = uzero
423441
@.. broadcast=false integrator.k[5] = uzero
424442
else
425-
c3′ = dt
443+
c3′ = dt / cache.dtprev
426444
c1′ = c1 * c3′
427445
c2′ = c2 * c3′
428446
@.. broadcast=false z1=c1′ * (k[3] + (c1′ - c2m1) * (k[4] + (c1′ - c1m1) * k[5]))
@@ -565,8 +583,8 @@ step_limiter!(u, integrator, p, t + dt)
565583
if integrator.EEst <= oneunit(integrator.EEst)
566584
cache.dtprev = dt
567585
if alg.extrapolant != :constant
568-
integrator.k[3] = (z2 - z3) / (dt * c2m1)
569-
@.. tmp=(z1 - z2) / (dt * c1mc2)
586+
integrator.k[3] = (z2 - z3) / c2m1
587+
@.. tmp=(z1 - z2) / c1mc2
570588
integrator.k[4] = (tmp - integrator.k[3]) / c1m1
571589
integrator.k[5] = integrator.k[4] - (tmp - z1 / c1) / c2
572590
end
@@ -807,7 +825,7 @@ function _ode_addsteps!(integrator, cache::RadauIIA9Cache, repeat_step = false)
807825
@unpack T11, T12, T13, T14, T15, T21, T22, T23, T24, T25, T31, T32, T33, T34, T35, T41, T42, T43, T44, T45, T51 = cache.tab #= T52 = 1, T53 = 0, T54 = 1, T55 = 0=#
808826
@unpack TI11, TI12, TI13, TI14, TI15, TI21, TI22, TI23, TI24, TI25, TI31, TI32, TI33, TI34, TI35, TI41, TI42, TI43, TI44, TI45, TI51, TI52, TI53, TI54, TI55 = cache.tab
809827
@unpack c1, c2, c3, c4, γ, α1, β1, α2, β2, e1, e2, e3, e4, e5 = cache.tab
810-
@unpack κ= cache
828+
@unpack κ = cache
811829
@unpack z1, z2, z3, z4, z5, w1, w2, w3, w4, w5 = cache
812830
@unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
813831
@unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache

lib/OrdinaryDiffEqFIRK/src/firk_caches.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ mutable struct RadauIIA3Cache{uType, cuType, uNoUnitsType, rateType, JType, W1Ty
4141
w2::uType
4242
dw12::cuType
4343
cubuff::cuType
44-
cont1::uType
45-
cont2::uType
4644
du1::rateType
4745
fsalfirst::rateType
4846
k::rateType
@@ -86,8 +84,6 @@ function alg_cache(alg::RadauIIA3, u, rate_prototype, ::Type{uEltypeNoUnits},
8684
recursivefill!(dw12, false)
8785
cubuff = similar(u, Complex{eltype(u)})
8886
recursivefill!(cubuff, false)
89-
cont1 = zero(u)
90-
cont2 = zero(u)
9187

9288
fsalfirst = zero(rate_prototype)
9389
k = zero(rate_prototype)
@@ -118,7 +114,7 @@ function alg_cache(alg::RadauIIA3, u, rate_prototype, ::Type{uEltypeNoUnits},
118114

119115
RadauIIA3Cache(u, uprev,
120116
z1, z2, w1, w2,
121-
dw12, cubuff, cont1, cont2,
117+
dw12, cubuff,
122118
du1, fsalfirst, k, k2, fw1, fw2,
123119
J, W1,
124120
uf, tab, κ, one(uToltype), 10000,
@@ -462,7 +458,7 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
462458

463459
RadauIIA9Cache(u, uprev,
464460
z1, z2, z3, z4, z5, w1, w2, w3, w4, w5,
465-
dw1, ubuff, dw23, dw45, cubuff1, cubuff2,
461+
dw1, ubuff, dw23, dw45, cubuff1, cubuff2,
466462
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
467463
J, W1, W2, W3,
468464
uf, tab, κ, one(uToltype), 10000,

lib/OrdinaryDiffEqFIRK/src/firk_interpolants.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@ FIRK_WITH_INTERPOLATIONS = Union{RadauIIA3ConstantCache, RadauIIA3Cache, RadauII
1010
@.. y₀ + Θ * (cont1 +- c1m1) * cont2)
1111
end
1212

13+
@muladd function _ode_interpolant!(
14+
out, Θ, dt, y₀, y₁, k, cache::Union{RadauIIA3ConstantCache, RadauIIA3Cache},
15+
idxs::Nothing, T::Type{Val{0}}, differential_vars)
16+
@unpack c1 = cache.tab
17+
c1m1 = c1 - 1
18+
Θdt = 1 - Θ
19+
@.. out = y₁ - Θdt * (k[3] - (Θdt + c1m1) * k[4])
20+
end
21+
1322
@muladd function _ode_interpolant(
1423
Θ, dt, y₀, y₁, k, cache::Union{RadauIIA5ConstantCache, RadauIIA5Cache},
1524
idxs::Nothing, T::Type{Val{0}}, differential_vars)
@@ -26,9 +35,9 @@ end
2635
idxs::Nothing, T::Type{Val{0}}, differential_vars)
2736
@unpack c1, c2 = cache.tab
2837
@unpack dtprev = cache
29-
c1m1 = (c1 - 1) * dt
30-
c2m1 = (c2 - 1) * dt
31-
Θdt = (1 - Θ)*dt
38+
c1m1 = c1 - 1
39+
c2m1 = c2 - 1
40+
Θdt = 1 - Θ
3241
@.. out = y₁ - Θdt * (k[3] - (Θdt + c2m1) * (k[4] - (Θdt + c1m1) * k[5]))
3342
end
3443

@@ -41,7 +50,6 @@ end
4150
c3m1 = c3 - 1
4251
c4m1 = c4 - 1
4352
Θdt = 1 - Θ
44-
print("yes")
4553
@.. y₁ - Θdt * (k[3] - (Θdt + c4m1) * (k[4] - (Θdt + c3m1) * (k[5] - (Θdt + c2m1) * (k[6] - (Θdt + c1m1) * k[7]))))
4654
end
4755

@@ -63,7 +71,6 @@ end
6371
@unpack num_stages, index = cache
6472
@unpack c = cache.tabs[index]
6573
Θdt = 1 - Θ
66-
@show k
6774
tmp = k[num_stages + 1] - k[num_stages + 2] * (Θdt + c[1] - 1)
6875
j = num_stages - 2
6976
while j > 0

0 commit comments

Comments
 (0)