Skip to content

Commit 4a31e6f

Browse files
Add helper func, rm unused increment test
1 parent 38a54a0 commit 4a31e6f

File tree

1 file changed

+14
-20
lines changed

1 file changed

+14
-20
lines changed

test/utils.jl

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,17 @@ import Plots, Printf
22
import ClimaTimeSteppers as CTS
33
using Test
44

5-
# Should we remove this?
6-
has_increment_formulation(::CTS.AbstractIMEXARKTableau) = false
5+
function problem_algo(test_case, tab)
6+
if tab() isa CTS.AbstractIMEXARKTableau
7+
max_iters = test_case.linear_implicit ? 1 : 2 # TODO: is 2 enough?
8+
alg = CTS.IMEXARKAlgorithm(tab(), NewtonsMethod(; max_iters))
9+
prob = test_case.split_prob
10+
else
11+
alg = tab()
12+
prob = test_case.prob
13+
end
14+
return (prob, alg)
15+
end
716

817
"""
918
test_algs(
@@ -89,16 +98,7 @@ function test_algs(
8998
analytic_end_sol = [analytic_sols[end]]
9099

91100
for tab in tableaus
92-
if tab() isa CTS.AbstractIMEXARKTableau
93-
max_iters = linear_implicit ? 1 : 2 # TODO: is 2 enough?
94-
alg = CTS.IMEXARKAlgorithm(tab(), NewtonsMethod(; max_iters))
95-
tendency_prob = test_case.split_prob
96-
increment_prob = test_case.split_increment_prob
97-
else
98-
alg = tab()
99-
tendency_prob = test_case.prob
100-
increment_prob = test_case.increment_prob
101-
end
101+
(prob, alg) = problem_algo(test_case, tab)
102102
predicted_order = if super_convergence == tab
103103
CTS.theoretical_convergence_order(tab()) + 1
104104
else
@@ -111,7 +111,7 @@ function test_algs(
111111
# integrator needs to save at t but it stops at t - eps(), it will skip
112112
# over saving at t, unless tstops forces it to round t - eps() to t).
113113
solve_args = (; dt = plot1_dt, saveat = plot1_saveat, tstops = plot1_saveat)
114-
tendency_sols = solve(deepcopy(tendency_prob), alg; solve_args...).u
114+
tendency_sols = solve(deepcopy(prob), alg; solve_args...).u
115115
tendency_norms = @. norm(tendency_sols)
116116
tendency_errs = @. norm(tendency_sols - analytic_sols)
117117
min_err = minimum(x -> x == 0 ? typemax(FT) : x, tendency_errs)
@@ -121,13 +121,7 @@ function test_algs(
121121
Plots.plot!(plot1a, plot1_saveat, tendency_norms; label = alg_name, linestyle)
122122
Plots.plot!(plot1b, plot1_saveat, tendency_errs; label = alg_name, linestyle)
123123

124-
if has_increment_formulation(tab())
125-
increment_sols = solve(deepcopy(increment_prob), alg; solve_args...).u
126-
increment_errs = @. norm(increment_sols - tendency_sols)
127-
@test maximum(increment_errs) < 1000 * eps(FT) broken = alg_name == "HOMMEM1" # TODO: why is this one broken?
128-
end
129-
130-
tendency_end_sols = map(dt -> solve(deepcopy(tendency_prob), alg; dt).u[end], plot2_dts)
124+
tendency_end_sols = map(dt -> solve(deepcopy(prob), alg; dt).u[end], plot2_dts)
131125
tendency_end_errs = @. norm(tendency_end_sols - analytic_end_sol)
132126
_, computed_order = hcat(ones(length(plot2_dts)), log10.(plot2_dts)) \ log10.(tendency_end_errs)
133127
@test computed_order predicted_order rtol = 0.1

0 commit comments

Comments
 (0)