Skip to content

Commit f6dbee0

Browse files
fix some errors
1 parent d11faed commit f6dbee0

File tree

3 files changed

+54
-41
lines changed

3 files changed

+54
-41
lines changed

src/caches/firk_caches.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ dw1::uType
325325
ubuff::uType
326326
dw23::cuType
327327
dw45::cuType
328-
cubuff::cuType
328+
cubuff1::cuType
329+
cubuff2::cuType
329330
cont1::uType
330331
cont2::uType
331332
cont3::uType
@@ -392,8 +393,10 @@ dw23 = similar(u, Complex{eltype(u)})
392393
dw45 = similar(u, Complex{eltype(u)})
393394
recursivefill!(dw23, false)
394395
recursivefill!(dw45, false)
395-
cubuff = similar(u, Complex{eltype(u)})
396-
recursivefill!(cubuff, false)
396+
cubuff1 = similar(u, Complex{eltype(u)})
397+
cubuff2 = similar(u, Complex{eltype(u)})
398+
recursivefill!(cubuff1, false)
399+
recursivefill!(cubuff2, false)
397400
cont1 = zero(u)
398401
cont2 = zero(u)
399402
cont3 = zero(u)
@@ -432,12 +435,12 @@ linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
432435
assumptions = LinearSolve.OperatorAssumptions(true))
433436
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
434437
#Pr = Diagonal(_vec(weight)))
435-
linprob = LinearProblem(W2, _vec(cubuff); u0 = _vec(dw23))
438+
linprob = LinearProblem(W2, _vec(cubuff1); u0 = _vec(dw23))
436439
linsolve2 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
437440
assumptions = LinearSolve.OperatorAssumptions(true))
438441
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
439442
#Pr = Diagonal(_vec(weight)))
440-
linprob = LinearProblem(W3, _vec(cubuff); u0 = _vec(dw45))
443+
linprob = LinearProblem(W3, _vec(cubuff2); u0 = _vec(dw45))
441444
linsolve3 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
442445
assumptions = LinearSolve.OperatorAssumptions(true))
443446
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
@@ -449,7 +452,7 @@ atol = reltol isa Number ? reltol : zero(reltol)
449452

450453
RadauIIA7Cache(u, uprev,
451454
z1, z2, z3, z4, z5, w1, w2, w3, w4, w5,
452-
dw1, ubuff, dw23, dw45, cubuff, cont1, cont2, cont3, cont4,
455+
dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4,
453456
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
454457
J, W1, W2, W3,
455458
uf, tab, κ, one(uToltype), 10000,

src/integrators/integrator_interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ end
111111
get_tmp_cache(integrator::ODEIntegrator, integrator.alg, integrator.cache)
112112
end
113113
# avoid method ambiguity
114-
for typ in (OrdinaryDiffEqAlgorithm, Union{RadauIIA3, RadauIIA5},
114+
for typ in (OrdinaryDiffEqAlgorithm, Union{RadauIIA3, RadauIIA5, RadauIIA7},
115115
OrdinaryDiffEqNewtonAdaptiveAlgorithm,
116116
OrdinaryDiffEqRosenbrockAdaptiveAlgorithm,
117117
Union{SSPRK22, SSPRK33, SSPRK53_2N1, SSPRK53_2N2, SSPRK43, SSPRK432, SSPRK932})
@@ -126,7 +126,7 @@ end
126126
cache::OrdinaryDiffEqMutableCache)
127127
(cache.tmp,)
128128
end
129-
@inline function DiffEqBase.get_tmp_cache(integrator, alg::Union{RadauIIA3, RadauIIA5},
129+
@inline function DiffEqBase.get_tmp_cache(integrator, alg::Union{RadauIIA3, RadauIIA5, RadauIIA7},
130130
cache::OrdinaryDiffEqMutableCache)
131131
(cache.tmp, cache.atmp)
132132
end

src/perform_step/firk_perform_step.jl

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,16 @@ function initialize!(integrator, cache::RadauIIA7Cache)
108108
integrator.k[2] = integrator.fsallast
109109
integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t)
110110
integrator.stats.nf += 1
111+
if integrator.opts.adaptive
112+
@unpack abstol, reltol = integrator.opts
113+
if reltol isa Number
114+
cache.rtol = reltol^(2 / 3) / 10
115+
cache.atol = cache.rtol * (abstol / reltol)
116+
else
117+
@.. broadcast=false cache.rtol=reltol^(2 / 3) / 10
118+
@.. broadcast=false cache.atol=cache.rtol * (abstol / reltol)
119+
end
120+
end
111121
nothing
112122
end
113123

@@ -459,7 +469,6 @@ end
459469
atmp2 = calculate_residuals(dw2, uprev, u, atol, rtol, internalnorm, t)
460470
atmp3 = calculate_residuals(dw3, uprev, u, atol, rtol, internalnorm, t)
461471
ndw = internalnorm(atmp1, t) + internalnorm(atmp2, t) + internalnorm(atmp3, t)
462-
463472
# check divergence (not in initial step)
464473
if iter > 1
465474
θ = ndw / ndwprev
@@ -789,7 +798,7 @@ end
789798
mass_matrix = integrator.f.mass_matrix
790799

791800
# precalculations rtol pow is (num stages + 1)/(2*num stages)
792-
rtol = @.. broadcast=false reltol#^(5/8)/10
801+
rtol = @.. broadcast=false reltol^(5/8)/10
793802
atol = @.. broadcast=false rtol*(abstol / reltol)
794803
c1m1 = c1 - 1
795804
c2m1 = c2 - 1
@@ -859,7 +868,7 @@ end
859868
ff2 = f(uprev + z2, p, t + c2 * dt)
860869
ff3 = f(uprev + z3, p, t + c3 * dt)
861870
ff4 = f(uprev + z4, p, t + c4 * dt)
862-
ff5 = f(uprev + z4, p, t + dt) # c5 = 1
871+
ff5 = f(uprev + z5, p, t + dt) # c5 = 1
863872
integrator.stats.nf += 5
864873

865874
fw1 = @.. broadcast=false TI11*ff1 + TI12*ff2 + TI13*ff3 + TI14*ff4 + TI15*ff5
@@ -904,7 +913,6 @@ end
904913
atmp4 = calculate_residuals(dw4, uprev, u, atol, rtol, internalnorm, t)
905914
atmp5 = calculate_residuals(dw5, uprev, u, atol, rtol, internalnorm, t)
906915
ndw = internalnorm(atmp1, t) + internalnorm(atmp2, t) + internalnorm(atmp3, t) + internalnorm(atmp4, t) + internalnorm(atmp5, t)
907-
908916
# check divergence (not in initial step)
909917
if iter > 1
910918
θ = ndw / ndwprev
@@ -939,6 +947,7 @@ end
939947
break
940948
end
941949
end
950+
942951
if fail_convergence
943952
integrator.force_stepfail = true
944953
integrator.stats.nnonlinconvfail += 1
@@ -948,15 +957,13 @@ end
948957
cache.iter = iter
949958

950959
u = @.. broadcast=false uprev + z5
951-
#=
960+
952961
if adaptive
953-
e1dt, e2dt, e3dt = e1 / dt, e2 / dt, e3 / dt
954-
tmp = @.. broadcast=false e1dt*z1+e2dt*z2+e3dt*z3
962+
e1dt, e2dt, e3dt, e4dt, e5dt = e1 / dt, e2 / dt, e3 / dt, e4/dt, e5/dt
963+
tmp = @.. broadcast=false e1dt*z1+e2dt*z2+e3dt*z3+e4dt*z4 + e5dt*z5
955964
mass_matrix != I && (tmp = mass_matrix * tmp)
956965
utilde = @.. broadcast=false integrator.fsalfirst+tmp
957966
alg.smooth_est && (utilde = LU1 \ utilde; integrator.stats.nsolve += 1)
958-
# RadauIIA5 needs a transformed rtol and atol see
959-
# https://github.com/luchr/ODEInterface.jl/blob/0bd134a5a358c4bc13e0fb6a90e27e4ee79e0115/src/radau5.f#L399-L421
960967
atmp = calculate_residuals(utilde, uprev, u, atol, rtol, internalnorm, t)
961968
integrator.EEst = internalnorm(atmp, t)
962969

@@ -970,7 +977,7 @@ end
970977
integrator.EEst = internalnorm(atmp, t)
971978
end
972979
end
973-
=#
980+
974981
if integrator.EEst <= oneunit(integrator.EEst)
975982
cache.dtprev = dt
976983
if alg.extrapolant != :constant
@@ -1002,10 +1009,10 @@ end
10021009
@unpack c1, c2, c3, c4, γ, α1, β1, α2, β2, e1, e2, e3, e4, e5 = cache.tab
10031010
@unpack κ, cont1, cont2, cont3, cont4 = cache
10041011
@unpack z1, z2, z3, z4, z5, w1, w2, w3, w4, w5 = cache
1005-
@unpack dw1, ubuff, dw23, dw45, cubuff = cache
1006-
@unpack k, k2, k3, k4, fw1, fw2, fw3, fw4, fw5 = cache
1012+
@unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
1013+
@unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache
10071014
@unpack J, W1, W2, W3 = cache
1008-
tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
1015+
@unpack tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
10091016
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
10101017
alg = unwrap_alg(integrator, true)
10111018
@unpack maxiters = alg
@@ -1023,7 +1030,7 @@ end
10231030
c2mc4 = c2 - c4
10241031
c3mc4 = c3 - c4
10251032

1026-
γdt, αdt, βdt = γ / dt, α / dt, β / dt
1033+
γdt, α1dt, β1dt, α2dt, β2dt= γ / dt, α1 / dt, β1 / dt, α2 / dt, β2 / dt
10271034
(new_jac = do_newJ(integrator, alg, cache, repeat_step)) &&
10281035
(calc_J!(J, integrator, cache); cache.W_γdt = dt)
10291036
if (new_W = do_newW(integrator, alg, new_jac, cache.W_γdt))
@@ -1130,7 +1137,7 @@ end
11301137
Mw5 = z5
11311138
end
11321139

1133-
@.. broadcast=false ubuff=fw1 - γdt * Mw1
1140+
@.. broadcast=false ubuff = fw1 - γdt * Mw1
11341141
needfactor = iter == 1 && new_W
11351142

11361143
linsolve1 = cache.linsolve1
@@ -1145,26 +1152,29 @@ end
11451152

11461153
cache.linsolve1 = linres1.cache
11471154

1148-
@.. broadcast=false cubuff=complex(fw2 - αdt * Mw2 + βdt * Mw3,
1149-
fw3 - βdt * Mw2 - αdt * Mw3)
1155+
@.. broadcast=false cubuff1 = complex(fw2 - α1dt * Mw2 + β1dt * Mw3, fw3 - β1dt * Mw2 - α1dt * Mw3)
11501156

11511157
linsolve2 = cache.linsolve2
11521158

11531159
if needfactor
1154-
linres2 = dolinsolve(integrator, linsolve2; A = W2, b = _vec(cubuff),
1160+
linres2 = dolinsolve(integrator, linsolve2; A = W2, b = _vec(cubuff1),
11551161
linu = _vec(dw23))
11561162
else
1157-
linres2 = dolinsolve(integrator, linsolve2; A = nothing, b = _vec(cubuff),
1163+
linres2 = dolinsolve(integrator, linsolve2; A = nothing, b = _vec(cubuff1),
11581164
linu = _vec(dw23))
11591165
end
11601166

11611167
cache.linsolve2 = linres2.cache
11621168

1169+
@.. broadcast=false cubuff2 = complex(fw4 - α2dt * Mw4 + β2dt * Mw5, fw5 - β2dt * Mw4 - α2dt * Mw5)
1170+
1171+
linsolve3 = cache.linsolve3
1172+
11631173
if needfactor
1164-
linres3 = dolinsolve(integrator, linsolve2; A = W3, b = _vec(cubuff),
1174+
linres3 = dolinsolve(integrator, linsolve3; A = W3, b = _vec(cubuff2),
11651175
linu = _vec(dw45))
11661176
else
1167-
linres3 = dolinsolve(integrator, linsolve2; A = nothing, b = _vec(cubuff),
1177+
linres3 = dolinsolve(integrator, linsolve3; A = nothing, b = _vec(cubuff2),
11681178
linu = _vec(dw45))
11691179
end
11701180

@@ -1242,8 +1252,8 @@ end
12421252
#=
12431253
if adaptive
12441254
utilde = w2
1245-
e1dt, e2dt, e3dt = e1 / dt, e2 / dt, e3 / dt
1246-
@.. broadcast=false tmp=e1dt * z1 + e2dt * z2 + e3dt * z3
1255+
e1dt, e2dt, e3dt, e4dt, e5dt = e1 / dt, e2 / dt, e3 / dt, e4 / dt, e5 / dt
1256+
@.. broadcast=false tmp=e1dt * z1 + e2dt * z2 + e3dt * z3 + e4dt * z4 + e5dt * z5
12471257
mass_matrix != I && (mul!(w1, mass_matrix, tmp); copyto!(tmp, w1))
12481258
@.. broadcast=false ubuff=integrator.fsalfirst + tmp
12491259
@@ -1282,16 +1292,16 @@ end
12821292
if integrator.EEst <= oneunit(integrator.EEst)
12831293
cache.dtprev = dt
12841294
if alg.extrapolant != :constant
1285-
@.. broadcast=false cache.cont1 = (z4 - z5) / c4m1
1286-
@.. broadcast=false tmp1 = (z3 - z4) / c3mc4
1287-
@.. broadcast=false cache.cont2 = (tmp1 - cache.cont1) / c3m1
1288-
@.. broadcast=false tmp2 = (z2 - z3) / c2mc3
1289-
@.. broadcast=false tmp3 = (tmp2 - tmp) / c2mc4
1290-
@.. broadcast=false cache.cont3 = (tmp3 - cache.cont2) / c2m1
1291-
@.. broadcast=false tmp4 = (z1 - z2) / c1mc2
1292-
@.. broadcast=false tmp5 = (tmp4 - tmp2) / c1mc3
1293-
@.. broadcast=false tmp6 = (tmp5 - tmp3) / c1mc4
1294-
@.. broadcast=false cache.cont4 = (tmp6 - cache.cont3) / c1m1
1295+
@.. cache.cont1 = (z4 - z5) / c4m1
1296+
@.. tmp1 = (z3 - z4) / c3mc4
1297+
@.. cache.cont2 = (tmp1 - cache.cont1) / c3m1
1298+
@.. tmp2 = (z2 - z3) / c2mc3
1299+
@.. tmp3 = (tmp2 - tmp) / c2mc4
1300+
@.. cache.cont3 = (tmp3 - cache.cont2) / c2m1
1301+
@.. tmp4 = (z1 - z2) / c1mc2
1302+
@.. tmp5 = (tmp4 - tmp2) / c1mc3
1303+
@.. tmp6 = (tmp5 - tmp3) / c1mc4
1304+
@.. cache.cont4 = (tmp6 - cache.cont3) / c1m1
12951305
end
12961306
end
12971307

0 commit comments

Comments
 (0)