Skip to content

Commit a931473

Browse files
Refactor tests, prep for tab-aware alg
1 parent d6bc696 commit a931473

File tree

6 files changed

+29
-23
lines changed

6 files changed

+29
-23
lines changed

docs/src/algo_comparisons.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ include(joinpath(cts_dir, "test", "problems.jl"))
1515
tab2 = (tab2..., IMKG254b, IMKG254c, HOMMEM1)
1616
tab3 = (ARS233, ARS343, ARS443, IMKG342a, IMKG343a, DBM453)
1717
tabs = [tab1..., tab2..., tab3...]
18+
tabs = map(t -> t(), tabs)
1819
test_algs("IMEX ARK", tabs, ark_analytic_nonlin_test_cts(Float64), 400)
1920
test_algs("IMEX ARK", tabs, ark_analytic_sys_test_cts(Float64), 60)
20-
test_algs("IMEX ARK", tabs, ark_analytic_test_cts(Float64), 16000; super_convergence = ARS121)
21+
test_algs("IMEX ARK", tabs, ark_analytic_test_cts(Float64), 16000; super_convergence = ARS121())
2122
end

docs/src/plotting_utils.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,15 @@ function test_algs(
8686
analytic_end_sol = [analytic_sols[end]]
8787

8888
for tab in tableaus
89-
(prob, alg) = problem_algo(test_case, tab)
89+
prob = problem(test_case, tab)
90+
alg = algorithm(test_case, tab)
9091
predicted_order = if super_convergence == tab
91-
CTS.theoretical_convergence_order(tab()) + 1
92+
CTS.theoretical_convergence_order(tab) + 1
9293
else
93-
CTS.theoretical_convergence_order(tab())
94+
CTS.theoretical_convergence_order(tab)
9495
end
9596
linestyle = linestyles[(predicted_order - 1) % length(linestyles) + 1]
96-
alg_name = string(nameof(tab))
97+
alg_name = string(nameof(typeof(tab)))
9798

9899
# Use tstops to fix saving issues due to machine precision (e.g. if the
99100
# integrator needs to save at t but it stops at t - eps(), it will skip

test/convergence.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ function tabulate_convergence_orders()
7272
IMKG343a,
7373
DBM453,
7474
]
75+
tabs = map(t -> t(), tabs)
7576
test_cases = all_test_cases(Float64)
7677
results = convergence_order_results(tabs, test_cases)
7778
tabulate_convergence_orders(test_cases, tabs, results)

test/convergence_utils.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ function default_expected_order(alg, tab)
5959
end
6060

6161
function test_convergence_order!(test_case, tab, results = Dict(); refinement_range)
62-
prob, alg = problem_algo(test_case, tab)
63-
expected_order = default_expected_order(alg, tab())
62+
prob = problem(test_case, tab)
63+
alg = algorithm(test_case, tab)
64+
expected_order = default_expected_order(alg, tab)
6465
cr = OCT.refinement_study(
6566
prob,
6667
alg;
@@ -96,8 +97,8 @@ function tabulate_convergence_orders(test_cases, tabs, results)
9697
columns = map(test_cases) do test_case
9798
map(tab -> results[tab, test_case.test_name], tabs)
9899
end
99-
expected_order = map(tab -> default_expected_order(nothing, tab()), tabs)
100-
tab_names = map(tab -> "$tab ($(default_expected_order(nothing, tab())))", tabs)
100+
expected_order = map(tab -> default_expected_order(nothing, tab), tabs)
101+
tab_names = map(tab -> "$tab ($(default_expected_order(nothing, tab)))", tabs)
101102
data = hcat(columns...)
102103
summary(result) = result.computed_order
103104
data_summary = map(d -> summary(d), data)

test/problems.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -274,11 +274,11 @@ function IntegratorTestCase(;
274274
jac_prototype = Matrix{FT}(undef, length(Y₀), length(Y₀))
275275
func_args = (; jac_prototype, Wfact = Wfact!, tgrad = tgrad!)
276276
tendency_func = ODEFunction(tendency!; func_args...)
277-
if isnothing(implicit_tendency!) # assume that related args are also nothing
277+
split_tendency_func = if isnothing(implicit_tendency!) # assume that related args are also nothing
278278
no_tendency!(Yₜ, Y, _, t) = Yₜ .= 0
279-
split_tendency_func = SplitFunction(tendency_func, no_tendency!)
279+
SplitFunction(tendency_func, no_tendency!)
280280
else
281-
split_tendency_func = SplitFunction(ODEFunction(implicit_tendency!; func_args...), explicit_tendency!)
281+
SplitFunction(ODEFunction(implicit_tendency!; func_args...), explicit_tendency!)
282282
end
283283
make_prob(func) = ODEProblem(func, Y₀, (FT(0), t_end), nothing)
284284
IntegratorTestCase(
@@ -307,12 +307,14 @@ function ClimaIntegratorTestCase(;
307307
jac_prototype = Matrix{FT}(undef, length(Y₀), length(Y₀))
308308
func_args = (; jac_prototype, Wfact = Wfact!, tgrad = tgrad!)
309309
tendency_func = ClimaODEFunction(; T_imp! = ODEFunction(tendency!; func_args...))
310-
if isnothing(implicit_tendency!) # assume that related args are also nothing
311-
split_tendency_func = ClimaODEFunction(; T_imp! = ODEFunction(tendency!; func_args...))
310+
311+
T_imp! = if isnothing(implicit_tendency!)
312+
# assume that related args are also nothing
313+
ODEFunction(tendency!; func_args...)
312314
else
313-
split_tendency_func =
314-
ClimaODEFunction(; T_exp! = explicit_tendency!, T_imp! = ODEFunction(implicit_tendency!; func_args...))
315+
ODEFunction(implicit_tendency!; func_args...)
315316
end
317+
split_tendency_func = ClimaODEFunction(; T_exp! = explicit_tendency!, T_imp!)
316318
make_prob(func) = ODEProblem(func, Y₀, (FT(0), t_end), nothing)
317319
IntegratorTestCase(
318320
test_name,

test/utils.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import ClimaTimeSteppers as CTS
22
using Test
33

4-
function problem_algo(test_case, tab)
5-
if tab() isa CTS.AbstractIMEXARKTableau
4+
problem(test_case, tab::CTS.AbstractIMEXARKTableau) = test_case.split_prob
5+
problem(test_case, tab) = test_case.prob
6+
7+
function algorithm(test_case, tab)
8+
return if tab isa CTS.AbstractIMEXARKTableau
69
max_iters = test_case.linear_implicit ? 1 : 2 # TODO: is 2 enough?
7-
alg = CTS.IMEXARKAlgorithm(tab(), NewtonsMethod(; max_iters))
8-
prob = test_case.split_prob
10+
CTS.IMEXARKAlgorithm(tab, NewtonsMethod(; max_iters))
911
else
10-
alg = tab()
11-
prob = test_case.prob
12+
tab
1213
end
13-
return (prob, alg)
1414
end

0 commit comments

Comments
 (0)