Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ OrdinaryDiffEqQPRK = "1"
OrdinaryDiffEqRKN = "1"
OrdinaryDiffEqRosenbrock = "1"
OrdinaryDiffEqSDIRK = "1"
OrdinaryDiffEqStabilizedIRK = "1"
OrdinaryDiffEqSSPRK = "1"
OrdinaryDiffEqStabilizedIRK = "1"
OrdinaryDiffEqStabilizedRK = "1"
OrdinaryDiffEqSymplecticRK = "1"
OrdinaryDiffEqTsit5 = "1"
Expand Down
7 changes: 5 additions & 2 deletions lib/OrdinaryDiffEqFIRK/src/alg_utils.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
qmax_default(alg::Union{RadauIIA3, RadauIIA5, RadauIIA9}) = 8
qmax_default(alg::Union{RadauIIA3, RadauIIA5, RadauIIA9, AdaptiveRadau}) = 8

alg_order(alg::RadauIIA3) = 3
alg_order(alg::RadauIIA5) = 5
alg_order(alg::RadauIIA9) = 9
alg_order(alg::AdaptiveRadau) = 5
alg_order(alg::AdaptiveRadau) = 5 #dummy value

isfirk(alg::RadauIIA3) = true
isfirk(alg::RadauIIA5) = true
Expand All @@ -13,3 +13,6 @@ isfirk(alg::AdaptiveRadau) = true
alg_adaptive_order(alg::RadauIIA3) = 1
alg_adaptive_order(alg::RadauIIA5) = 3
alg_adaptive_order(alg::RadauIIA9) = 5

get_current_alg_order(alg::AdaptiveRadau, cache) = cache.num_stages * 2 - 1
get_current_adaptive_order(alg::AdaptiveRadau, cache) = cache.num_stages
7 changes: 4 additions & 3 deletions lib/OrdinaryDiffEqFIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,13 @@ struct AdaptiveRadau{CS, AD, F, P, FDT, ST, CJ, Tol, C1, C2, StepLimiter} <:
new_W_γdt_cutoff::C2
controller::Symbol
step_limiter!::StepLimiter
num_stages::Int
min_stages::Int
max_stages::Int
end

function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, num_stages = 3,
diff_type = Val{:forward}, min_stages = 3, max_stages = 7,
linsolve = nothing, precs = DEFAULT_PRECS,
extrapolant = :dense, fast_convergence_cutoff = 1 // 5,
new_W_γdt_cutoff = 1 // 5,
Expand All @@ -186,6 +187,6 @@ function AdaptiveRadau(; chunk_size = Val{0}(), autodiff = Val{true}(),
fast_convergence_cutoff,
new_W_γdt_cutoff,
controller,
step_limiter!, num_stages)
step_limiter!, min_stages, max_stages)
end

60 changes: 59 additions & 1 deletion lib/OrdinaryDiffEqFIRK/src/controllers.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
@inline function stepsize_controller!(integrator, controller::PredictiveController, alg)
@unpack qmin, qmax, gamma = integrator.opts
EEst = DiffEqBase.value(integrator.EEst)

if iszero(EEst)
q = inv(qmax)
else
Expand All @@ -26,6 +25,7 @@ end

function step_accept_controller!(integrator, controller::PredictiveController, alg, q)
@unpack qmin, qmax, gamma, qsteady_min, qsteady_max = integrator.opts

EEst = DiffEqBase.value(integrator.EEst)

if integrator.success_iter > 0
Expand All @@ -42,10 +42,68 @@ function step_accept_controller!(integrator, controller::PredictiveController, a
end
integrator.dtacc = integrator.dt
integrator.erracc = max(1e-2, EEst)

return integrator.dt / qacc
end


function step_accept_controller!(integrator, controller::PredictiveController, alg::AdaptiveRadau, q)
@unpack qmin, qmax, gamma, qsteady_min, qsteady_max = integrator.opts
@unpack cache = integrator
@unpack num_stages, step, iter, hist_iter = cache

EEst = DiffEqBase.value(integrator.EEst)

if integrator.success_iter > 0
expo = 1 / (get_current_adaptive_order(alg, integrator.cache) + 1)
qgus = (integrator.dtacc / integrator.dt) *
DiffEqBase.fastpow((EEst^2) / integrator.erracc, expo)
qgus = max(inv(qmax), min(inv(qmin), qgus / gamma))
qacc = max(q, qgus)
else
qacc = q
end
if qsteady_min <= qacc <= qsteady_max
qacc = one(qacc)
end
integrator.dtacc = integrator.dt
integrator.erracc = max(1e-2, EEst)
cache.step = step + 1
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
if (step > 10)
if (hist_iter < 2.6 && num_stages < alg.max_stages)
cache.num_stages += 2
cache.step = 1
cache.hist_iter = iter
elseif ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
end
end
return integrator.dt / qacc
end

function step_reject_controller!(integrator, controller::PredictiveController, alg)
@unpack dt, success_iter, qold = integrator
integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold
end

function step_reject_controller!(integrator, controller::PredictiveController, alg::AdaptiveRadau)
@unpack dt, success_iter, qold = integrator
@unpack cache = integrator
@unpack num_stages, step, iter, hist_iter = cache
integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold
cache.step = step + 1
hist_iter = hist_iter * 0.8 + iter * 0.2
cache.hist_iter = hist_iter
if (step > 10)
if ((hist_iter > 8 || cache.status == VerySlowConvergence || cache.status == Divergence) && num_stages > alg.min_stages)
cache.num_stages -= 2
cache.step = 1
cache.hist_iter = iter
end
end
end

116 changes: 59 additions & 57 deletions lib/OrdinaryDiffEqFIRK/src/firk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ end
mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} <:
OrdinaryDiffEqConstantCache
uf::F
tab::Tab
tabs::Vector{Tab}
κ::Tol
ηold::Tol
iter::Int
Expand All @@ -486,6 +486,9 @@ mutable struct AdaptiveRadauConstantCache{F, Tab, Tol, Dt, U, JType} <:
W_γdt::Dt
status::NLStatus
J::JType
num_stages::Int
step::Int
hist_iter::Float64
end

function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -494,34 +497,28 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UDerivativeWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.num_stages

if (num_stages == 3)
tab = BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits))
elseif (num_stages == 5)
tab = BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits))
elseif (num_stages == 7)
tab = BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))
elseif iseven(num_stages) || num_stages <3
error("num_stages must be odd and 3 or greater")
else
tab = adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), num_stages)
num_stages = alg.min_stages
max = alg.max_stages
tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]

i = 9
while i <= alg.max_stages
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
i += 2
end

cont = Vector{typeof(u)}(undef, num_stages)
for i in 1: num_stages
cont = Vector{typeof(u)}(undef, max)
for i in 1: max
cont[i] = zero(u)
end

κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)
J = false .* _vec(rate_prototype) .* _vec(rate_prototype)'

AdaptiveRadauConstantCache(uf, tab, κ, one(uToltype), 10000, cont, dt, dt,
Convergence, J)
AdaptiveRadauConstantCache(uf, tabs, κ, one(uToltype), 10000, cont, dt, dt,
Convergence, J, num_stages, 1, 0.0)
end

mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType, JType, W1Type, W2Type,
UF, JC, F1, F2, Tab, Tol, Dt, rTol, aTol, StepLimiter} <:
UF, JC, F1, F2, #=F3,=# Tab, Tol, Dt, rTol, aTol, StepLimiter} <:
FIRKMutableCache
u::uType
uprev::uType
Expand All @@ -544,7 +541,7 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
W1::W1Type #real
W2::Vector{W2Type} #complex
uf::UF
tab::Tab
tabs::Vector{Tab}
κ::Tol
ηold::Tol
iter::Int
Expand All @@ -553,12 +550,16 @@ mutable struct AdaptiveRadauCache{uType, cuType, tType, uNoUnitsType, rateType,
jac_config::JC
linsolve1::F1 #real
linsolve2::Vector{F2} #complex
#linres2::Vector{F3}
rtol::rTol
atol::aTol
dtprev::Dt
W_γdt::Dt
status::NLStatus
step_limiter!::StepLimiter
num_stages::Int
step::Int
hist_iter::Float64
end

function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits},
Expand All @@ -567,62 +568,56 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
uf = UJacobianWrapper(f, t, p)
uToltype = constvalue(uBottomEltypeNoUnits)
num_stages = alg.num_stages

if (num_stages == 3)
tab = BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits))
elseif (num_stages == 5)
tab = BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits))
elseif (num_stages == 7)
tab = BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))
elseif iseven(num_stages) || num_stages < 3
error("num_stages must be odd and 3 or greater")
else
tab = adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), num_stages)

min = alg.min_stages
max = alg.max_stages

num_stages = min

tabs = [BigRadauIIA5Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA9Tableau(uToltype, constvalue(tTypeNoUnits)), BigRadauIIA13Tableau(uToltype, constvalue(tTypeNoUnits))]
i = 9
while i <= max
push!(tabs, adaptiveRadauTableau(uToltype, constvalue(tTypeNoUnits), i))
i += 2
end

κ = alg.κ !== nothing ? convert(uToltype, alg.κ) : convert(uToltype, 1 // 100)

z = Vector{typeof(u)}(undef, num_stages)
w = Vector{typeof(u)}(undef, num_stages)
for i in 1 : num_stages
z = Vector{typeof(u)}(undef, max)
w = Vector{typeof(u)}(undef, max)
for i in 1 : max
z[i] = w[i] = zero(u)
end

c_prime = Vector{typeof(t)}(undef, num_stages) #time stepping
c_prime = Vector{typeof(t)}(undef, max) #time stepping

dw1 = zero(u)
ubuff = zero(u)
dw2 = [similar(u, Complex{eltype(u)}) for _ in 1 : (num_stages - 1) ÷ 2]
dw2 = [similar(u, Complex{eltype(u)}) for _ in 1 : (max - 1) ÷ 2]
recursivefill!.(dw2, false)
cubuff = [similar(u, Complex{eltype(u)}) for _ in 1 : (num_stages - 1) ÷ 2]
cubuff = [similar(u, Complex{eltype(u)}) for _ in 1 : (max - 1) ÷ 2]
recursivefill!.(cubuff, false)
dw = Vector{typeof(u)}(undef, num_stages - 1)
dw = [zero(u) for i in 1 : max]

cont = Vector{typeof(u)}(undef, num_stages)
for i in 1 : num_stages
cont[i] = zero(u)
end
cont = [zero(u) for i in 1:max]

derivatives = Matrix{typeof(u)}(undef, num_stages, num_stages)
for i in 1 : num_stages, j in 1 : num_stages
derivatives = Matrix{typeof(u)}(undef, max, max)
for i in 1 : max, j in 1 : max
derivatives[i, j] = zero(u)
end

fsalfirst = zero(rate_prototype)
fw = Vector{typeof(rate_prototype)}(undef, num_stages)
ks = Vector{typeof(rate_prototype)}(undef, num_stages)
for i in 1: num_stages
ks[i] = fw[i] = zero(rate_prototype)
end
fw = [zero(rate_prototype) for i in 1 : max]
ks = [zero(rate_prototype) for i in 1 : max]

k = ks[1]

J, W1 = build_J_W(alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val(true))
if J isa AbstractSciMLOperator
error("Non-concrete Jacobian not yet supported by AdaptiveRadau.")
end

W2 = [similar(J, Complex{eltype(W1)}) for _ in 1 : (num_stages - 1) ÷ 2]
W2 = [similar(J, Complex{eltype(W1)}) for _ in 1 : (max - 1) ÷ 2]
recursivefill!.(W2, false)

du1 = zero(rate_prototype)
Expand All @@ -640,18 +635,25 @@ function alg_cache(alg::AdaptiveRadau, u, rate_prototype, ::Type{uEltypeNoUnits}

linsolve2 = [
init(LinearProblem(W2[i], _vec(cubuff[i]); u0 = _vec(dw2[i])), alg.linsolve, alias_A = true, alias_b = true,
assumptions = LinearSolve.OperatorAssumptions(true)) for i in 1 : (num_stages - 1) ÷ 2]

assumptions = LinearSolve.OperatorAssumptions(true)) for i in 1 : (max - 1) ÷ 2]
#=
linres_tmp = dolinsolve(nothing, linsolve2[1]; A = W2[1], b = _vec(cubuff[1]), linu = _vec(dw2[1]))
linres2 = Vector{typeof(linres_tmp)}(undef , (max - 1) ÷ 2)
linres2[1] = linres_tmp
for i in 2 : (num_stages - 1) ÷ 2
linres2[i] = dolinsolve(nothing, linsolve2[1]; A = W2[1], b = _vec(cubuff[i]), linu = _vec(dw2[i]))
end
=#
rtol = reltol isa Number ? reltol : zero(reltol)
atol = reltol isa Number ? reltol : zero(reltol)

AdaptiveRadauCache(u, uprev,
z, w, c_prime, dw1, ubuff, dw2, cubuff, dw, cont, derivatives,
du1, fsalfirst, ks, k, fw,
J, W1, W2,
uf, tab, κ, one(uToltype), 10000, tmp,
uf, tabs, κ, one(uToltype), 10000, tmp,
atmp, jac_config,
linsolve1, linsolve2, rtol, atol, dt, dt,
Convergence, alg.step_limiter!)
linsolve1, linsolve2, #=linres2,=# rtol, atol, dt, dt,
Convergence, alg.step_limiter!, num_stages, 1, 0.0)
end

Loading
Loading