Skip to content

Commit 618f182

Browse files
Merge #97
97: Modularize remaining unit tests r=charleskawczynski a=charleskawczynski This PR: - Changes the `test/problems.jl` script to instead define functions - Qualifies some methods - Moves more tests into safetestsets - Includes the utility functions only in the portion of the unit tests where needed. Co-authored-by: Charles Kawczynski <[email protected]>
2 parents b356b75 + 5baa8cd commit 618f182

File tree

10 files changed

+201
-172
lines changed

10 files changed

+201
-172
lines changed

docs/src/algorithms.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ The convergence orders of the provided methods are verified using test cases fro
2525
using Pkg
2626
Pkg.activate("../../test")
2727
Pkg.instantiate()
28-
include("../../test/problems.jl")
29-
include("../../test/utils.jl")
3028
include("../../test/convergence.jl")
3129
Pkg.activate(".")
3230
```

perf/flame.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ function do_work!(integrator, not_generated_cache)
2424
end
2525
problem_str = parsed_args["problem"]
2626
prob = if problem_str=="ode_fun"
27-
split_linear_prob_wfact_split
27+
split_linear_prob_wfact_split()
2828
elseif problem_str=="fe"
29-
split_linear_prob_wfact_split_fe
29+
split_linear_prob_wfact_split_fe()
3030
else
3131
error("Bad option")
3232
end

perf/jet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ function config_integrators(problem)
2424
return (; integrator_generated=integrator, not_generated_integrator)
2525
end
2626
prob = if parsed_args["problem"]=="ode_fun"
27-
split_linear_prob_wfact_split
27+
split_linear_prob_wfact_split()
2828
elseif parsed_args["problem"]=="fe"
29-
split_linear_prob_wfact_split_fe
29+
split_linear_prob_wfact_split_fe()
3030
else
3131
error("Bad option")
3232
end

test/compare_generated.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ include("problems.jl")
66
algorithm = ARS343(NewtonsMethod(; max_iters = 2))
77
dt = 0.01
88
for problem in (
9-
split_linear_prob_wfact_split,
10-
split_linear_prob_wfact_split_fe,
9+
split_linear_prob_wfact_split(),
10+
split_linear_prob_wfact_split_fe(),
1111
)
1212
integrator = DiffEqBase.init(deepcopy(problem), algorithm; dt)
1313
not_generated_integrator = deepcopy(integrator)

test/convergence.jl

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,46 @@
11
using ClimaTimeSteppers, LinearAlgebra, Test
22

3-
dts = 0.5 .^ (4:7)
4-
5-
for (prob, sol, tscale) in [
6-
(linear_prob, linear_sol, 1)
7-
(sincos_prob, sincos_sol, 1)
8-
]
9-
10-
@test convergence_order(prob, sol, LSRKEulerMethod(), dts.*tscale) 1 atol=0.1
11-
@test convergence_order(prob, sol, LSRK54CarpenterKennedy(), dts.*tscale) 4 atol=0.05
12-
@test convergence_order(prob, sol, LSRK144NiegemannDiehlBusch(), dts.*tscale) 4 atol=0.05
3+
include(joinpath(@__DIR__, "convergence_utils.jl"))
4+
include(joinpath(@__DIR__, "utils.jl"))
5+
include(joinpath(@__DIR__, "problems.jl"))
6+
7+
@testset "LSRK and SSP convergence" begin
8+
dts = 0.5 .^ (4:7)
9+
10+
for (prob, sol, tscale) in [
11+
(linear_prob(), linear_sol, 1)
12+
(sincos_prob(), sincos_sol, 1)
13+
]
14+
15+
@test convergence_order(prob, sol, LSRKEulerMethod(), dts.*tscale) 1 atol=0.1
16+
@test convergence_order(prob, sol, LSRK54CarpenterKennedy(), dts.*tscale) 4 atol=0.05
17+
@test convergence_order(prob, sol, LSRK144NiegemannDiehlBusch(), dts.*tscale) 4 atol=0.05
18+
19+
@test convergence_order(prob, sol, SSPRK22Heuns(), dts.*tscale) 2 atol=0.05
20+
@test convergence_order(prob, sol, SSPRK22Ralstons(), dts.*tscale) 2 atol=0.05
21+
@test convergence_order(prob, sol, SSPRK33ShuOsher(), dts.*tscale) 3 atol=0.05
22+
@test convergence_order(prob, sol, SSPRK34SpiteriRuuth(), dts.*tscale) 3 atol=0.05
23+
end
1324

14-
@test convergence_order(prob, sol, SSPRK22Heuns(), dts.*tscale) 2 atol=0.05
15-
@test convergence_order(prob, sol, SSPRK22Ralstons(), dts.*tscale) 2 atol=0.05
16-
@test convergence_order(prob, sol, SSPRK33ShuOsher(), dts.*tscale) 3 atol=0.05
17-
@test convergence_order(prob, sol, SSPRK34SpiteriRuuth(), dts.*tscale) 3 atol=0.05
18-
end
25+
for (prob, sol, tscale) in [
26+
(linear_prob_wfactt(), linear_sol, 1)
27+
]
28+
@test convergence_order(prob, sol, SSPKnoth(linsolve=linsolve_direct), dts.*tscale) 2 atol=0.05
1929

20-
for (prob, sol, tscale) in [
21-
(linear_prob_wfactt, linear_sol, 1)
22-
]
23-
@test convergence_order(prob, sol, SSPKnoth(linsolve=linsolve_direct), dts.*tscale) 2 atol=0.05
30+
end
2431

25-
end
2632

33+
# ForwardEulerODEFunction
34+
for (prob, sol, tscale) in [
35+
(linear_prob_fe(), linear_sol, 1)
36+
(sincos_prob_fe(), sincos_sol, 1)
37+
]
38+
@test convergence_order(prob, sol, SSPRK22Heuns(), dts.*tscale) 2 atol=0.05
39+
@test convergence_order(prob, sol, SSPRK22Ralstons(), dts.*tscale) 2 atol=0.05
40+
@test convergence_order(prob, sol, SSPRK33ShuOsher(), dts.*tscale) 3 atol=0.05
41+
@test convergence_order(prob, sol, SSPRK34SpiteriRuuth(), dts.*tscale) 3 atol=0.05
2742

28-
# ForwardEulerODEFunction
29-
for (prob, sol, tscale) in [
30-
(linear_prob_fe, linear_sol, 1)
31-
(sincos_prob_fe, sincos_sol, 1)
32-
]
33-
@test convergence_order(prob, sol, SSPRK22Heuns(), dts.*tscale) 2 atol=0.05
34-
@test convergence_order(prob, sol, SSPRK22Ralstons(), dts.*tscale) 2 atol=0.05
35-
@test convergence_order(prob, sol, SSPRK33ShuOsher(), dts.*tscale) 3 atol=0.05
36-
@test convergence_order(prob, sol, SSPRK34SpiteriRuuth(), dts.*tscale) 3 atol=0.05
37-
43+
end
3844
end
3945

4046
ENV["GKSwstype"] = "nul" # avoid displaying plots
@@ -46,14 +52,14 @@ ENV["GKSwstype"] = "nul" # avoid displaying plots
4652
algs2 = (algs2..., IMKG254b, IMKG254c, HOMMEM1)
4753
algs3 = (ARS233, ARS343, ARS443, IMKG342a, IMKG343a, DBM453)
4854
dict = Dict(((algs1 .=> 1)..., (algs2 .=> 2)..., (algs3 .=> 3)...))
49-
test_algs("IMEX ARK", dict, ark_analytic_nonlin_test, 400)
50-
test_algs("IMEX ARK", dict, ark_analytic_sys_test, 60)
55+
test_algs("IMEX ARK", dict, ark_analytic_nonlin_test(Float64), 400)
56+
test_algs("IMEX ARK", dict, ark_analytic_sys_test(Float64), 60)
5157

5258
# For some bizarre reason, ARS121 converges with order 2 for ark_analytic,
5359
# even though it is only a 1st order method.
5460
dict′ = copy(dict)
5561
dict′[ARS121] = 2
56-
test_algs("IMEX ARK", dict′, ark_analytic_test, 16000)
62+
test_algs("IMEX ARK", dict′, ark_analytic_test(Float64), 16000)
5763
end
5864

5965
#=

test/convergence_utils.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
2+
"""
3+
DirectSolver
4+
5+
A linear solver which forms the full matrix of a linear operator and its LU factorization.
6+
"""
7+
struct DirectSolver end
8+
9+
DirectSolver(args...) = DirectSolver()
10+
11+
function (::DirectSolver)(x,A,b,matrix_updated; kwargs...)
12+
n = length(x)
13+
M = mapslices(y -> mul!(similar(y), A, y), Matrix{eltype(x)}(I,n,n), dims=1)
14+
x .= M \ b
15+
end
16+
17+
"""
18+
convergence_rates(problem, solution, method, dts; kwargs...)
19+
20+
Compute the errors rates of `method` on `problem` by comparing to `solution`
21+
on the set of `dt` values in `dts`. Extra `kwargs` are passed to `solve`
22+
23+
`solution` should be a function with a method `solution(u0, p, t)`.
24+
"""
25+
function convergence_errors(prob, sol, method, dts; kwargs...)
26+
errs = map(dts) do dt
27+
# copy the problem so we don't mutate u0
28+
prob_copy = deepcopy(prob)
29+
u = solve(prob_copy, method; dt=dt, saveat=(prob.tspan[2],), kwargs...)
30+
norm(u .- sol(prob.u0, prob.p, prob.tspan[end]))
31+
end
32+
return errs
33+
end
34+
35+
36+
"""
37+
convergence_order(problem, solution, method, dts; kwargs...)
38+
39+
Estimate the order of the rate of convergence of `method` on `problem` by comparing to
40+
`solution` the set of `dt` values in `dts`.
41+
42+
`solution` should be a function with a method `solution(u0, p, t)`.
43+
"""
44+
function convergence_order(prob, sol, method, dts; kwargs...)
45+
errs = convergence_errors(prob, sol, method, dts; kwargs...)
46+
# find slope coefficient in log scale
47+
_,order_est = hcat(ones(length(dts)), log2.(dts)) \ log2.(errs)
48+
return order_est
49+
end

test/integrator.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import OrdinaryDiffEq
44
include("problems.jl")
55

66
@testset "integrator save times" begin
7-
test_case = constant_tendency_test
7+
test_case = constant_tendency_test(Float64)
88
(; prob, analytic_sol) = test_case
99
for alg in (SSPRK33ShuOsher(), OrdinaryDiffEq.SSPRK33()),
1010
reverse_prob in (false, true),
@@ -121,7 +121,7 @@ end
121121
# OrdinaryDiffEq does not save at t0′ after reinit! unless erase_sol is
122122
# true, so this test does not include a comparison with OrdinaryDiffEq.
123123
alg = SSPRK33ShuOsher()
124-
test_case = constant_tendency_test
124+
test_case = constant_tendency_test(Float64)
125125
(; prob, analytic_sol) = test_case
126126
for reverse_prob in (false, true)
127127
if reverse_prob
@@ -161,7 +161,7 @@ end
161161

162162
@testset "integrator step past end time" begin
163163
alg = SSPRK33ShuOsher()
164-
test_case = constant_tendency_test
164+
test_case = constant_tendency_test(Float64)
165165
(; prob, analytic_sol) = test_case
166166
t0, tf = prob.tspan
167167
dt = tf - t0

0 commit comments

Comments
 (0)