Skip to content

Commit ec86815

Browse files
update in-place method
1 parent f6dbee0 commit ec86815

File tree

2 files changed: 113 additions and 102 deletions

2 files changed: 113 additions and 102 deletions

src/caches/firk_caches.jl

Lines changed: 100 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -353,6 +353,11 @@ tab::Tab
353353
ηold::Tol
354354
iter::Int
355355
tmp::uType
356+
tmp2::uType
357+
tmp3::uType
358+
tmp4::uType
359+
tmp5::uType
360+
tmp6::uType
356361
atmp::uNoUnitsType
357362
jac_config::JC
358363
linsolve1::F1
@@ -368,94 +373,99 @@ end
368373
TruncatedStacktraces.@truncate_stacktrace RadauIIA7Cache 1
369374

370375
function alg_cache(alg::RadauIIA7, u, rate_prototype, ::Type{uEltypeNoUnits},
371-
::Type{uBottomEltypeNoUnits},
372-
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
373-
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
374-
uf = UJacobianWrapper(f, t, p)
375-
uToltype = constvalue(uBottomEltypeNoUnits)
376-
tab = RadauIIA7Tableau(uToltype, constvalue(tTypeNoUnits))
377-
378-
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
379-
380-
z1 = zero(u)
381-
z2 = zero(u)
382-
z3 = zero(u)
383-
z4 = zero(u)
384-
z5 = zero(u)
385-
w1 = zero(u)
386-
w2 = zero(u)
387-
w3 = zero(u)
388-
w4 = zero(u)
389-
w5 = zero(u)
390-
dw1 = zero(u)
391-
ubuff = zero(u)
392-
dw23 = similar(u, Complex{eltype(u)})
393-
dw45 = similar(u, Complex{eltype(u)})
394-
recursivefill!(dw23, false)
395-
recursivefill!(dw45, false)
396-
cubuff1 = similar(u, Complex{eltype(u)})
397-
cubuff2 = similar(u, Complex{eltype(u)})
398-
recursivefill!(cubuff1, false)
399-
recursivefill!(cubuff2, false)
400-
cont1 = zero(u)
401-
cont2 = zero(u)
402-
cont3 = zero(u)
403-
cont4 = zero(u)
404-
405-
fsalfirst = zero(rate_prototype)
406-
k = zero(rate_prototype)
407-
k2 = zero(rate_prototype)
408-
k3 = zero(rate_prototype)
409-
k4 = zero(rate_prototype)
410-
k5 = zero(rate_prototype)
411-
fw1 = zero(rate_prototype)
412-
fw2 = zero(rate_prototype)
413-
fw3 = zero(rate_prototype)
414-
fw4 = zero(rate_prototype)
415-
fw5 = zero(rate_prototype)
416-
417-
J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
418-
if J isa AbstractSciMLOperator
419-
error("Non-concrete Jacobian not yet supported by RadauIIA5.")
420-
end
421-
W2 = similar(J, Complex{eltype(W1)})
422-
W3 = similar(J, Complex{eltype(W1)})
423-
recursivefill!(W2, false)
424-
recursivefill!(W3, false)
425-
426-
du1 = zero(rate_prototype)
427-
428-
tmp = zero(u)
429-
atmp = similar(u, uEltypeNoUnits)
430-
recursivefill!(atmp, false)
431-
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)
432-
433-
linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
434-
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
435-
assumptions = LinearSolve.OperatorAssumptions(true))
436-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
437-
#Pr = Diagonal(_vec(weight)))
438-
linprob = LinearProblem(W2, _vec(cubuff1); u0 = _vec(dw23))
439-
linsolve2 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
440-
assumptions = LinearSolve.OperatorAssumptions(true))
441-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
442-
#Pr = Diagonal(_vec(weight)))
443-
linprob = LinearProblem(W3, _vec(cubuff2); u0 = _vec(dw45))
444-
linsolve3 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
445-
assumptions = LinearSolve.OperatorAssumptions(true))
446-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
447-
#Pr = Diagonal(_vec(weight)))
448-
449-
450-
rtol = reltol isa Number ? reltol : zero(reltol)
451-
atol = reltol isa Number ? reltol : zero(reltol)
452-
453-
RadauIIA7Cache(u, uprev,
454-
z1, z2, z3, z4, z5, w1, w2, w3, w4, w5,
455-
dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4,
456-
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
457-
J, W1, W2, W3,
458-
uf, tab, κ, one(uToltype), 10000,
459-
tmp, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
460-
Convergence, alg.step_limiter!)
376+
::Type{uBottomEltypeNoUnits},
377+
::Type{tTypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck,
378+
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
379+
uf = UJacobianWrapper(f, t, p)
380+
uToltype = constvalue(uBottomEltypeNoUnits)
381+
tab = RadauIIA7Tableau(uToltype, constvalue(tTypeNoUnits))
382+
383+
κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
384+
385+
z1 = zero(u)
386+
z2 = zero(u)
387+
z3 = zero(u)
388+
z4 = zero(u)
389+
z5 = zero(u)
390+
w1 = zero(u)
391+
w2 = zero(u)
392+
w3 = zero(u)
393+
w4 = zero(u)
394+
w5 = zero(u)
395+
dw1 = zero(u)
396+
ubuff = zero(u)
397+
dw23 = similar(u, Complex{eltype(u)})
398+
dw45 = similar(u, Complex{eltype(u)})
399+
recursivefill!(dw23, false)
400+
recursivefill!(dw45, false)
401+
cubuff1 = similar(u, Complex{eltype(u)})
402+
cubuff2 = similar(u, Complex{eltype(u)})
403+
recursivefill!(cubuff1, false)
404+
recursivefill!(cubuff2, false)
405+
cont1 = zero(u)
406+
cont2 = zero(u)
407+
cont3 = zero(u)
408+
cont4 = zero(u)
409+
410+
fsalfirst = zero(rate_prototype)
411+
k = zero(rate_prototype)
412+
k2 = zero(rate_prototype)
413+
k3 = zero(rate_prototype)
414+
k4 = zero(rate_prototype)
415+
k5 = zero(rate_prototype)
416+
fw1 = zero(rate_prototype)
417+
fw2 = zero(rate_prototype)
418+
fw3 = zero(rate_prototype)
419+
fw4 = zero(rate_prototype)
420+
fw5 = zero(rate_prototype)
421+
422+
J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
423+
if J isa AbstractSciMLOperator
424+
error("Non-concrete Jacobian not yet supported by RadauIIA5.")
425+
end
426+
W2 = similar(J, Complex{eltype(W1)})
427+
W3 = similar(J, Complex{eltype(W1)})
428+
recursivefill!(W2, false)
429+
recursivefill!(W3, false)
430+
431+
du1 = zero(rate_prototype)
432+
433+
tmp = zero(u)
434+
tmp2 = zero(u)
435+
tmp3 = zero(u)
436+
tmp4 = zero(u)
437+
tmp5 = zero(u)
438+
tmp6 = zero(u)
439+
atmp = similar(u, uEltypeNoUnits)
440+
recursivefill!(atmp, false)
441+
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)
442+
443+
linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
444+
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
445+
assumptions = LinearSolve.OperatorAssumptions(true))
446+
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
447+
#Pr = Diagonal(_vec(weight)))
448+
linprob = LinearProblem(W2, _vec(cubuff1); u0 = _vec(dw23))
449+
linsolve2 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
450+
assumptions = LinearSolve.OperatorAssumptions(true))
451+
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
452+
#Pr = Diagonal(_vec(weight)))
453+
linprob = LinearProblem(W3, _vec(cubuff2); u0 = _vec(dw45))
454+
linsolve3 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
455+
assumptions = LinearSolve.OperatorAssumptions(true))
456+
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
457+
#Pr = Diagonal(_vec(weight)))
458+
459+
460+
rtol = reltol isa Number ? reltol : zero(reltol)
461+
atol = reltol isa Number ? reltol : zero(reltol)
462+
463+
RadauIIA7Cache(u, uprev,
464+
z1, z2, z3, z4, z5, w1, w2, w3, w4, w5,
465+
dw1, ubuff, dw23, dw45, cubuff1, cubuff2, cont1, cont2, cont3, cont4,
466+
du1, fsalfirst, k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5,
467+
J, W1, W2, W3,
468+
uf, tab, κ, one(uToltype), 10000,
469+
tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, linsolve1, linsolve2, linsolve3, rtol, atol, dt, dt,
470+
Convergence, alg.step_limiter!)
461471
end

src/perform_step/firk_perform_step.jl

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -789,7 +789,7 @@ end
789789
repeat_step = false)
790790
@unpack t, dt, uprev, u, f, p = integrator
791791
@unpack T11, T12, T13, T14, T15, T21, T22, T23, T24, T25, T31, T32, T33, T34, T35, T41, T42, T43, T44, T45, T51 = cache.tab #= T52 = 1, T53 = 0, T54 = 1, T55 = 0=#
792-
@unpack TI11, TI12, TI13, TI14, TI15, TI21, TI22, TI23, TI24, TI25, TI31, TI32, TI33,TI34, TI35, TI41, TI42, TI43, TI44, TI45, TI51, TI52, TI53, TI54, TI55 = cache.tab
792+
@unpack TI11, TI12, TI13, TI14, TI15, TI21, TI22, TI23, TI24, TI25, TI31, TI32, TI33, TI34, TI35, TI41, TI42, TI43, TI44, TI45, TI51, TI52, TI53, TI54, TI55 = cache.tab
793793
@unpack c1, c2, c3, c4, γ, α1, β1, α2, β2, e1, e2, e3, e4, e5 = cache.tab
794794
@unpack κ, cont1, cont2, cont3, cont4 = cache
795795
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
@@ -836,7 +836,7 @@ end
836836
cache.cont2 = map(zero, u)
837837
cache.cont3 = map(zero, u)
838838
cache.cont4 = map(zero, u)
839-
else # add cont4
839+
else
840840
c5′ = dt / cache.dtprev
841841
c1′ = c1 * c5′
842842
c2′ = c2 * c5′
@@ -850,7 +850,7 @@ end
850850
w1 = @.. broadcast=false TI11*z1 + TI12*z2 + TI13*z3 + TI14*z4 + TI15*z5
851851
w2 = @.. broadcast=false TI21*z1 + TI22*z2 + TI23*z3 + TI24*z4 + TI25*z5
852852
w3 = @.. broadcast=false TI31*z1 + TI32*z2 + TI33*z3 + TI34*z4 + TI35*z5
853-
w4 = @.. broadcast=false TI41*z1 + TI42*z2 + TI43*z3 + TI34*z4 + TI45*z5
853+
w4 = @.. broadcast=false TI41*z1 + TI42*z2 + TI43*z3 + TI44*z4 + TI45*z5
854854
w5 = @.. broadcast=false TI51*z1 + TI52*z2 + TI53*z3 + TI54*z4 + TI55*z5
855855
end
856856

@@ -913,7 +913,9 @@ end
913913
atmp4 = calculate_residuals(dw4, uprev, u, atol, rtol, internalnorm, t)
914914
atmp5 = calculate_residuals(dw5, uprev, u, atol, rtol, internalnorm, t)
915915
ndw = internalnorm(atmp1, t) + internalnorm(atmp2, t) + internalnorm(atmp3, t) + internalnorm(atmp4, t) + internalnorm(atmp5, t)
916+
916917
# check divergence (not in initial step)
918+
917919
if iter > 1
918920
θ = ndw / ndwprev
919921
(diverge = θ > 1) && (cache.status = Divergence)
@@ -993,7 +995,7 @@ end
993995
cache.cont4 = @.. (tmp6 - cache.cont3) / c1m1 #fourth derivative on [c1, 1]
994996
end
995997
end
996-
998+
997999
integrator.fsallast = f(u, p, t + dt)
9981000
integrator.stats.nf += 1
9991001
integrator.k[1] = integrator.fsalfirst
@@ -1012,7 +1014,7 @@ end
10121014
@unpack dw1, ubuff, dw23, dw45, cubuff1, cubuff2 = cache
10131015
@unpack k, k2, k3, k4, k5, fw1, fw2, fw3, fw4, fw5 = cache
10141016
@unpack J, W1, W2, W3 = cache
1015-
@unpack tmp, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
1017+
@unpack tmp, tmp2, tmp3, tmp4, tmp5, tmp6, atmp, jac_config, linsolve1, linsolve2, rtol, atol, step_limiter! = cache
10161018
@unpack internalnorm, abstol, reltol, adaptive = integrator.opts
10171019
alg = unwrap_alg(integrator, true)
10181020
@unpack maxiters = alg
@@ -1074,7 +1076,7 @@ end
10741076
w1 = @.. broadcast=false TI11*z1 + TI12*z2 + TI13*z3 + TI14*z4 + TI15*z5
10751077
w2 = @.. broadcast=false TI21*z1 + TI22*z2 + TI23*z3 + TI24*z4 + TI25*z5
10761078
w3 = @.. broadcast=false TI31*z1 + TI32*z2 + TI33*z3 + TI34*z4 + TI35*z5
1077-
w4 = @.. broadcast=false TI41*z1 + TI42*z2 + TI43*z3 + TI34*z4 + TI45*z5
1079+
w4 = @.. broadcast=false TI41*z1 + TI42*z2 + TI43*z3 + TI44*z4 + TI45*z5
10781080
w5 = @.. broadcast=false TI51*z1 + TI52*z2 + TI53*z3 + TI54*z4 + TI55*z5
10791081
end
10801082

@@ -1249,7 +1251,8 @@ end
12491251
@.. broadcast=false u=uprev + z5
12501252

12511253
step_limiter!(u, integrator, p, t + dt)
1252-
#=
1254+
1255+
12531256
if adaptive
12541257
utilde = w2
12551258
e1dt, e2dt, e3dt, e4dt, e5dt = e1 / dt, e2 / dt, e3 / dt, e4 / dt, e5 / dt
@@ -1287,14 +1290,13 @@ end
12871290
integrator.EEst = internalnorm(atmp, t)
12881291
end
12891292
end
1290-
=#
1291-
1293+
12921294
if integrator.EEst <= oneunit(integrator.EEst)
12931295
cache.dtprev = dt
12941296
if alg.extrapolant != :constant
12951297
@.. cache.cont1 = (z4 - z5) / c4m1
1296-
@.. tmp1 = (z3 - z4) / c3mc4
1297-
@.. cache.cont2 = (tmp1 - cache.cont1) / c3m1
1298+
@.. tmp = (z3 - z4) / c3mc4
1299+
@.. cache.cont2 = (tmp - cache.cont1) / c3m1
12981300
@.. tmp2 = (z2 - z3) / c2mc3
12991301
@.. tmp3 = (tmp2 - tmp) / c2mc4
13001302
@.. cache.cont3 = (tmp3 - cache.cont2) / c2m1
@@ -1305,7 +1307,6 @@ end
13051307
end
13061308
end
13071309

1308-
13091310
f(fsallast, u, p, t + dt)
13101311
integrator.stats.nf += 1
13111312
return

0 commit comments

Comments (0)