Skip to content

Commit 2deed06

Browse files
Merge #143
143: Increase test coverage r=charleskawczynski a=charleskawczynski This PR: - Fixes some bugs, to get more code running - Un-comments some tests, to increase test coverage - Refactors, and fixes bugs in, the convergence tabulation Co-authored-by: Charles Kawczynski <[email protected]>
2 parents 8b3c6b2 + 0cf407b commit 2deed06

File tree

10 files changed

+112
-85
lines changed

10 files changed

+112
-85
lines changed

src/ClimaTimeSteppers.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ include("nl_solvers/newtons_method.jl")
8888

8989

9090
n_stages_ntuple(::Type{<:NTuple{Nstages}}) where {Nstages} = Nstages
91+
n_stages_ntuple(::Type{<:SVector{Nstages}}) where {Nstages} = Nstages
9192

9293
# Include concrete implementations
9394
include("solvers/imex_ark_tableaus.jl")

src/solvers/ark.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function step_u!(int, cache::AdditiveRungeKuttaFullCache, f::DiffEqBase.SplitFun
117117

118118
Nstages = n_stages(cache)
119119
(; C, Aimpl, Aexpl, B) = cache.tableau
120-
(; U, R, L, W, linsolve) = cache
120+
(; U, R, L, W, linsolve!) = cache
121121
(; u, p, t, dt) = int
122122

123123
U = U
@@ -177,7 +177,7 @@ end
177177
function step_u!(int, cache::AdditiveRungeKuttaFullCache{Nstages}, f::DiffEqBase.ODEFunction) where {Nstages}
178178
(; C, Aimpl, Aexpl, B) = cache.tableau
179179
(; u, p, t, dt) = int
180-
(; U, R, L, W, linsolve) = cache
180+
(; U, R, L, W, linsolve!) = cache
181181
Uhat = R[end] # can be used as work array, as only used in last stage
182182

183183
fL! = f.jvp # linear part
@@ -303,13 +303,15 @@ If the keyword `paperversion=true` is used, the coefficients from the paper are
303303
used. Otherwise it uses coefficients that make the scheme (much) more stable but less
304304
accurate
305305
"""
306-
Base.@kwdef struct ARK2GiraldoKellyConstantinescu{L} <: AdditiveRungeKutta
306+
Base.@kwdef struct ARK2GiraldoKellyConstantinescu{L, PV} <: AdditiveRungeKutta
307307
linsolve::L
308-
paperversion::Bool = false
309308
end
309+
ARK2GiraldoKellyConstantinescu(linsolve; paperversion::Bool = false) =
310+
ARK2GiraldoKellyConstantinescu{typeof(linsolve), paperversion}(linsolve)
311+
paperversion(::ARK2GiraldoKellyConstantinescu{L, PV}) where {L, PV} = PV
310312

311313
function tableau(ark::ARK2GiraldoKellyConstantinescu, RT)
312-
a32 = RT(ark.paperversion ? (3 + 2 * sqrt(2)) / 6 : 1 // 2)
314+
a32 = RT(paperversion(ark) ? (3 + 2 * sqrt(2)) / 6 : 1 // 2)
313315
RKA_explicit = @SArray [
314316
RT(0) RT(0) RT(0)
315317
RT(2 - sqrt(2)) RT(0) RT(0)

src/solvers/lsrk.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,15 @@ function init_inner(prob, outercache::LowStorageRungeKutta2NIncCache, dt)
6262
end
6363
function update_inner!(innerinteg, outercache::LowStorageRungeKutta2NIncCache, f_slow, u, p, t, dt, stage)
6464

65-
(; C, A, B) = cache.tableau
65+
(; C, A, B) = outercache.tableau
6666
f_offset = innerinteg.sol.prob.f
67-
tab = outercache.tableau
6867
N = n_stages(outercache)
6968

7069
τ0 = t + C[stage] * dt
7170
τ1 = stage == N ? t + dt : t + C[stage + 1] * dt
7271
f_offset.α = τ0
7372
innerinteg.t = zero(τ0)
74-
innerinteg.tstop = τ1 - τ0
73+
DiffEqBase.add_tstop!(innerinteg, τ1 - τ0) # TODO: verify correctness
7574

7675
# du .= f(u, p, t + C[stage]*dt) .+ A[stage] .* du
7776
f_slow(f_offset.x, u, p, τ0, 1, A[stage])

src/solvers/mis.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ function update_inner!(innerinteg, outercache::MultirateInfinitesimalStepCache,
102102
F,
103103
innerinteg.u,
104104
f_offset.x,
105-
tab,
105+
outercache.tableau, # TODO: verify correctness
106106
i,
107107
N,
108108
dt;
@@ -117,7 +117,7 @@ function update_inner!(innerinteg, outercache::MultirateInfinitesimalStepCache,
117117
f_offset.β = (c[i] - c̃[i]) / d[i]
118118

119119
innerinteg.t = zero(t)
120-
innerinteg.tstop = d[i] * dt
120+
DiffEqBase.add_tstop!(innerinteg, d[i] * dt) # TODO: verify correctness
121121
end
122122

123123
@kernel function mis_update!(u, ΔU, F, innerinteg_u, f_offset_x, tab, i, N, dt)

src/solvers/multirate.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,10 @@ function step_u!(int, cache::MultirateCache)
6161

6262
# TODO: make this more generic
6363
# there are 2 strategies we can use here:
64-
# a. use same fast_dt for all slow stages, use `adjustfinal=true`
64+
# a. use same fast_dt for all slow stages
6565
# - problems for ARK (e.g. requires expensive LU factorization)
6666
# b. use different fast_dt, cache expensive ops
6767

68-
innerinteg.adjustfinal = true
6968
DiffEqBase.solve!(innerinteg)
7069
innerinteg.dt = fast_dt # reset
7170
end

src/solvers/wickerskamarock.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ function update_inner!(innerinteg, outercache::WickerSkamarockRungeKuttaCache, f
5454
end
5555

5656
innerinteg.t = t
57-
innerinteg.tstop = i == N ? t + dt : t + c[i + 1] * dt
57+
t_star = i == N ? t + dt : t + c[i + 1] * dt
58+
DiffEqBase.add_tstop!(innerinteg, t_star) # TODO: verify correctness
5859
end
5960

6061

test/convergence.jl

Lines changed: 70 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ end
4545

4646
#=
4747
using Revise; include("test/convergence.jl")
48-
results = tabulate_convergence_orders();
48+
results = tabulate_convergence_orders_imex_ark();
4949
=#
50-
function tabulate_convergence_orders()
50+
function tabulate_convergence_orders_imex_ark()
5151
tabs = [
5252
ARS111,
5353
ARS121,
@@ -76,54 +76,78 @@ function tabulate_convergence_orders()
7676
tabs = map(t -> t(), tabs)
7777
test_cases = all_test_cases(Float64)
7878
results = convergence_order_results(tabs, test_cases)
79-
tabulate_convergence_orders(test_cases, tabs, results)
79+
algs = algorithm.(tabs)
80+
prob_names = map(t -> t.test_name, test_cases)
81+
expected_orders = ODE.alg_order.(tabs)
82+
tabulate_convergence_orders(prob_names, algs, results, expected_orders; tabs)
8083
return results
8184
end
82-
tabulate_convergence_orders()
85+
tabulate_convergence_orders_imex_ark()
8386

84-
#=
85-
if ArrayType == Array
86-
for (prob, sol) in [
87-
imex_autonomous_prob => imex_autonomous_sol,
88-
#imex_nonautonomous_prob => imex_nonautonomous_sol,
89-
]
90-
# IMEX
91-
@test convergence_order(prob, sol, ARK1ForwardBackwardEuler(linsolve=DirectSolver), dts) ≈ 1 atol=0.1
92-
@test convergence_order(prob, sol, ARK2ImplicitExplicitMidpoint(linsolve=DirectSolver), dts) ≈ 2 atol=0.05
93-
@test convergence_order(prob, sol, ARK2GiraldoKellyConstantinescu(linsolve=DirectSolver), dts) ≈ 2 atol=0.05
94-
@test convergence_order(prob, sol, ARK2GiraldoKellyConstantinescu(linsolve=DirectSolver; paperversion=true), dts) ≈ 2 atol=0.05
95-
@test convergence_order(prob, sol, ARK437L2SA1KennedyCarpenter(linsolve=DirectSolver), dts) ≈ 4 atol=0.05
96-
@test convergence_order(prob, sol, ARK548L2SA2KennedyCarpenter(linsolve=DirectSolver), dts) ≈ 5 atol=0.05
97-
end
87+
function tabulate_convergence_orders_ark()
88+
# IMEX
89+
co = Dict()
90+
names_probs_sols = [
91+
(:auto, imex_autonomous_prob(Array{Float64}), imex_autonomous_sol),
92+
(:nonauto, imex_nonautonomous_prob(Array{Float64}), imex_nonautonomous_sol),
93+
]
94+
algs_orders = [
95+
(ARK1ForwardBackwardEuler(DirectSolver), 1),
96+
(ARK2ImplicitExplicitMidpoint(DirectSolver), 2),
97+
(ARK2GiraldoKellyConstantinescu(DirectSolver), 2),
98+
(ARK2GiraldoKellyConstantinescu(DirectSolver; paperversion = true), 2),
99+
(ARK437L2SA1KennedyCarpenter(DirectSolver), 4),
100+
(ARK548L2SA2KennedyCarpenter(DirectSolver), 5),
101+
]
102+
dts = 0.5 .^ (4:7)
103+
for (name, prob, sol) in names_probs_sols
104+
for (alg, ord) in algs_orders
105+
co[name, typeof(alg)] = (ord, convergence_order(prob, sol, alg, dts))
106+
end
107+
end
108+
prob_names = first.(names_probs_sols)
109+
algs = first.(algs_orders)
110+
expected_orders = last.(algs_orders)
111+
tabulate_convergence_orders(prob_names, algs, co, expected_orders)
98112
end
99-
for (prob, sol) in [
100-
imex_autonomous_prob => imex_autonomous_sol,
101-
imex_nonautonomous_prob => imex_nonautonomous_sol,
102-
# kpr_multirate_prob => kpr_sol,
103-
]
104-
# Multirate
105-
@test convergence_order(prob, sol, Multirate(LSRK54CarpenterKennedy(),LSRK54CarpenterKennedy()), dts;
106-
fast_dt = 0.5^12, adjustfinal=true) ≈ 4 atol=0.05
107-
# MIS
108-
@test convergence_order(prob, sol, Multirate(LSRK54CarpenterKennedy(),MIS2()), dts;
109-
fast_dt = 0.5^12, adjustfinal=true) ≈ 2 atol=0.05
110-
@test convergence_order(prob, sol, Multirate(LSRK54CarpenterKennedy(),MIS3C()), dts;
111-
fast_dt = 0.5^12, adjustfinal=true) ≈ 2 atol=0.05
112-
@test convergence_order(prob, sol, Multirate(LSRK54CarpenterKennedy(),MIS4()), dts;
113-
fast_dt = 0.5^12, adjustfinal=true) ≈ 3 atol=0.05
114-
@test convergence_order(prob, sol, Multirate(LSRK54CarpenterKennedy(),MIS4a()), dts;
115-
fast_dt = 0.5^12, adjustfinal=true) ≈ 3 atol=0.05
116-
@test convergence_order(prob, sol, Multirate(LSRK54CarpenterKennedy(),TVDMISA()), dts;
117-
fast_dt = 0.5^12, adjustfinal=true) ≈ 2 atol=0.05
118-
@test convergence_order(prob, sol, Multirate(LSRK54CarpenterKennedy(),TVDMISB()), dts;
119-
fast_dt = 0.5^12, adjustfinal=true) ≈ 2 atol=0.05
120-
121-
# Wicker Skamarock
122-
@test convergence_order(prob, sol, Multirate(LSRK54CarpenterKennedy(),WSRK2()), dts;
123-
fast_dt = 0.5^12, adjustfinal=true) ≈ 2 atol=0.05
124-
@test convergence_order(prob, sol, Multirate(LSRK54CarpenterKennedy(),WSRK3()), dts;
125-
fast_dt = 0.5^12, adjustfinal=true) ≈ 2 atol=0.05
126113

114+
tabulate_convergence_orders_ark()
115+
116+
function tabulate_convergence_orders_multirate()
117+
118+
co = Dict()
119+
names_probs_sols = [
120+
(:imex_auto, imex_autonomous_prob(Array{Float64}), imex_autonomous_sol),
121+
(:imex_nonauto, imex_nonautonomous_prob(Array{Float64}), imex_nonautonomous_sol),
122+
# (:kpr_multirate, kpr_multirate_prob(), kpr_sol),
123+
]
124+
dts = 0.5 .^ (4:7)
125+
126+
algs_orders = [
127+
# Multirate
128+
(Multirate(LSRK54CarpenterKennedy(), LSRK54CarpenterKennedy()), 4),
129+
# MIS
130+
(Multirate(LSRK54CarpenterKennedy(), MIS2()), 2),
131+
(Multirate(LSRK54CarpenterKennedy(), MIS3C()), 2),
132+
(Multirate(LSRK54CarpenterKennedy(), MIS4()), 3),
133+
(Multirate(LSRK54CarpenterKennedy(), MIS4a()), 3),
134+
(Multirate(LSRK54CarpenterKennedy(), TVDMISA()), 2),
135+
(Multirate(LSRK54CarpenterKennedy(), TVDMISB()), 2),
136+
# Wicker Skamarock
137+
(Multirate(LSRK54CarpenterKennedy(), WSRK2()), 2),
138+
(Multirate(LSRK54CarpenterKennedy(), WSRK3()), 2),
139+
]
127140

141+
for (name, prob, sol) in names_probs_sols
142+
for (alg, ord) in algs_orders
143+
co[name, typeof(alg)] = (ord, convergence_order(prob, sol, alg, dts; fast_dt = 0.5^12))
144+
end
128145
end
129-
=#
146+
147+
prob_names = first.(names_probs_sols)
148+
algs = first.(algs_orders)
149+
expected_orders = last.(algs_orders)
150+
tabulate_convergence_orders(prob_names, algs, co, expected_orders)
151+
end
152+
153+
tabulate_convergence_orders_multirate()

test/convergence_orders.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
# TODO: is it better to use `first_order_tableau = Union{ARS111,ARS121}`? to
66
# reduce the number of methods?
7-
const first_order_tableau = [ARS111, ARS121]
7+
first_order_tableau() = [ARS111, ARS121]
88

99
#####
1010
##### 2nd order
1111
#####
1212

13-
const second_order_tableau = [
13+
second_order_tableau() = [
1414
ARS122,
1515
ARS232,
1616
ARS222,
@@ -31,18 +31,18 @@ const second_order_tableau = [
3131
#####
3232
##### 3rd order
3333
#####
34-
const third_order_tableau = [ARS233, ARS343, ARS443, IMKG342a, IMKG343a, DBM453]
34+
third_order_tableau() = [ARS233, ARS343, ARS443, IMKG342a, IMKG343a, DBM453]
3535

3636
import OrdinaryDiffEq as ODE
3737
import ClimaTimeSteppers as CTS
3838
ODE.alg_order(alg::CTS.IMEXARKAlgorithm) = ODE.alg_order(alg.tab)
3939

40-
for m in first_order_tableau
40+
for m in first_order_tableau()
4141
@eval ODE.alg_order(::$m) = 1
4242
end
43-
for m in second_order_tableau
43+
for m in second_order_tableau()
4444
@eval ODE.alg_order(::$m) = 2
4545
end
46-
for m in third_order_tableau
46+
for m in third_order_tableau()
4747
@eval ODE.alg_order(::$m) = 3
4848
end

test/convergence_utils.jl

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@ on the set of `dt` values in `dts`. Extra `kwargs` are passed to `solve`
2626
`solution` should be a function with a method `solution(u0, p, t)`.
2727
"""
2828
function convergence_errors(prob, sol, method, dts; kwargs...)
29+
hide_warning = (; kwargshandle = DiffEqBase.KeywordArgSilent)
2930
errs = map(dts) do dt
3031
# copy the problem so we don't mutate u0
3132
prob_copy = deepcopy(prob)
32-
u = solve(prob_copy, method; dt = dt, saveat = (prob.tspan[2],), kwargs...)
33+
u = solve(prob_copy, method; dt = dt, saveat = (prob.tspan[2],), kwargs..., hide_warning...)
3334
norm(u .- sol(prob.u0, prob.p, prob.tspan[end]))
3435
end
3536
return errs
@@ -66,13 +67,12 @@ function test_convergence_order!(test_case, tab, results = Dict(); refinement_ra
6667
refinement_range, # ::UnitRange, 2:4 is more fine than 1:3
6768
)
6869
computed_order = maximum(cr.computed_order)
69-
results[tab, test_case.test_name] = (; expected_order, computed_order)
70+
results[test_case.test_name, typeof(alg)] = (; expected_order, computed_order)
7071
return nothing
7172
end
7273

73-
distance(theoretic, computed) = abs(computed - theoretic) / theoretic
74-
pass_conv(theoretic, computed) = distance(theoretic, computed) * 100 < 10
75-
fail_conv(theoretic, computed) = !pass_conv(computed, theoretic) && !super_conv(theoretic, computed)
74+
pass_conv(theoretic, computed) = abs(computed - theoretic) / theoretic * 100 < 10 && computed > 0
75+
fail_conv(theoretic, computed) = !pass_conv(theoretic, computed) && !super_conv(theoretic, computed)
7676
super_conv(theoretic, computed) = (computed - theoretic) / theoretic * 100 > 10
7777

7878
#= Calls `test_convergence_order!` for each combination of test case
@@ -89,31 +89,32 @@ function convergence_order_results(tabs, test_cases)
8989
return results
9090
end
9191

92-
function tabulate_convergence_orders(test_cases, tabs, results)
93-
columns = map(test_cases) do test_case
94-
map(tab -> results[tab, test_case.test_name], tabs)
92+
function tabulate_convergence_orders(prob_names, algs, results, expected_orders; tabs = nothing)
93+
data = hcat(map(prob_names) do name
94+
map(alg -> results[name, typeof(alg)], algs)
95+
end...)
96+
alg_names = if tabs nothing
97+
@. string(nameof(typeof(tabs)))
98+
else
99+
@. string(typeof(algs))
95100
end
96-
expected_order = map(tab -> default_expected_order(nothing, tab), tabs)
97-
tab_names = map(tab -> "$tab ($(default_expected_order(nothing, tab)))", tabs)
98-
data = hcat(columns...)
99-
summary(result) = result.computed_order
101+
summary(result) = last(result)
100102
data_summary = map(d -> summary(d), data)
101103

102-
table_data = hcat(tab_names, data_summary)
104+
table_data = hcat(alg_names, data_summary)
103105
precentage_fail = sum(fail_conv.(getindex.(data, 1), getindex.(data, 2))) / length(data) * 100
104106
@info "Percentage of failed convergence order tests: $precentage_fail"
105107
fail_conv_hl = PrettyTables.Highlighter(
106-
(data, i, j) -> j 1 && fail_conv(expected_order[i], data[i, j]),
108+
(data, i, j) -> j 1 && fail_conv(expected_orders[i], data[i, j]),
107109
PrettyTables.crayon"red bold",
108110
)
109111
super_conv_hl = PrettyTables.Highlighter(
110-
(data, i, j) -> j 1 && super_conv(expected_order[i], data[i, j]),
112+
(data, i, j) -> j 1 && super_conv(expected_orders[i], data[i, j]),
111113
PrettyTables.crayon"yellow bold",
112114
)
113115
tab_column_hl = PrettyTables.Highlighter((data, i, j) -> j == 1, PrettyTables.crayon"green bold")
114-
test_case_names = map(test_case -> test_case.test_name, test_cases)
115116

116-
header = (["Tableau (theoretic)", test_case_names...],
117+
header = (["Tableau (theoretic)", prob_names...],
117118
# ["", ["" for tc in test_case_names]...],
118119
)
119120

test/ode_tests_basic.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ Qexact = exactsolution(finaltime, q0, t0)
109109
for (n, dt) in enumerate(dts)
110110
Q .= Qinit
111111
prob = IncrementingODEProblem(rhs!, Q, (t0, finaltime))
112-
solve(prob, method; dt=dt, adjustfinal=true)
112+
solve(prob, method; dt=dt)
113113
errors[n] = norm(Q - Qexact)
114114
end
115115
rates = log2.(errors[1:(end - 1)] ./ errors[2:end])
@@ -129,7 +129,7 @@ Qexact = exactsolution(finaltime, q0, t0)
129129
# rhs_arg! =
130130
# split_explicit_implicit ? rhs_nonlinear! : rhs!
131131
prob = SplitODEProblem(rhs_linear!, rhs_nonlinear!, Q, (t0, finaltime))
132-
solve(prob, method(DirectSolver); dt = dt, adjustfinal = true)
132+
solve(prob, method(DirectSolver); dt = dt)
133133
errors[n] = norm(Q - Qexact)
134134
@show (log2(dt), norm(Q - Qexact))
135135
end
@@ -160,7 +160,7 @@ Qexact = exactsolution(finaltime, q0, t0)
160160
# rhs_arg! =
161161
# split_explicit_implicit ? rhs_nonlinear! : rhs!
162162
prob = ODEProblem(ODEFunction(rhs!, jvp = rhs_linear!), Q, (t0, finaltime))
163-
solve(prob, method(DirectSolver); dt = dt, adjustfinal = true)
163+
solve(prob, method(DirectSolver); dt = dt)
164164
errors[n] = norm(Q - Qexact)
165165
@show (log2(dt), norm(Q - Qexact))
166166
end

0 commit comments

Comments
 (0)