|
1 | | -import JLD2 |
2 | | -import Plots |
| 1 | +using ClimaTimeSteppers |
| 2 | +using InteractiveUtils: subtypes |
3 | 3 | using Distributions: quantile, TDist |
4 | | -using Printf: @sprintf |
5 | | -using LaTeXStrings: latexstring |
6 | | -import DiffEqCallbacks |
7 | | -import ClimaTimeSteppers as CTS |
| 4 | +using LinearAlgebra: norm |
8 | 5 |
|
9 | | -function get_algorithm_names() |
10 | | - all_subtypes(::Type{T}) where {T} = isabstracttype(T) ? vcat(all_subtypes.(subtypes(T))...) : [T] |
11 | | - algorithm_names = map(T -> T(), all_subtypes(ClimaTimeSteppers.AbstractAlgorithmName)) |
12 | | - return filter(name -> !(name isa ARK437L2SA1 || name isa ARK548L2SA2), algorithm_names) # too high order |
13 | | -end |
14 | | - |
15 | | -function get_imex_ssprk_algorithm_names() |
16 | | - all_subtypes(::Type{T}) where {T} = isabstracttype(T) ? vcat(all_subtypes.(subtypes(T))...) : [T] |
17 | | - algorithm_names = map(T -> T(), all_subtypes(ClimaTimeSteppers.IMEXSSPRKAlgorithmName)) |
18 | | - return algorithm_names |
19 | | -end |
20 | | - |
21 | | -function make_saving_callback(cb, u, t, integrator) |
22 | | - savevalType = typeof(cb(u, t, integrator)) |
23 | | - return DiffEqCallbacks.SavingCallback(cb, DiffEqCallbacks.SavedValues(typeof(t), savevalType)) |
24 | | -end |
| 6 | +all_subtypes(::Type{T}) where {T} = isabstracttype(T) ? vcat(all_subtypes.(subtypes(T))...) : [T] |
25 | 7 |
|
26 | 8 | """ |
27 | 9 | imex_convergence_orders(algorithm_name) |
@@ -60,8 +42,7 @@ imex_convergence_orders(::ARK548L2SA2) = (5, 5, 5) |
60 | 42 | imex_convergence_orders(::SSP22Heuns) = (2, 2, 2) |
61 | 43 | imex_convergence_orders(::SSP33ShuOsher) = (3, 3, 3) |
62 | 44 | imex_convergence_orders(::RK4) = (4, 4, 4) |
63 | | -# SSPKnoth is not really an IMEX method |
64 | | -imex_convergence_orders(::SSPKnoth) = (2, 2, 2) |
| 45 | +imex_convergence_orders(::SSPKnoth) = (2, 3, 2) |
65 | 46 |
|
66 | 47 | # Compute a confidence interval for the convergence order, returning the |
67 | 48 | # estimated convergence order and its uncertainty. |
@@ -103,192 +84,136 @@ function (assuming that the algorithm converges). |
103 | 84 | function predicted_convergence_order(algorithm_name::AbstractAlgorithmName, ode_function::AbstractClimaODEFunction) |
104 | 85 | (imp_order, exp_order, combined_order) = imex_convergence_orders(algorithm_name) |
105 | 86 | has_imp = !isnothing(ode_function.T_imp!) |
106 | | - has_exp = CTS.has_T_exp(ode_function) |
| 87 | + has_exp = ClimaTimeSteppers.has_T_exp(ode_function) |
107 | 88 | has_imp && !has_exp && return imp_order |
108 | 89 | !has_imp && has_exp && return exp_order |
109 | 90 | has_imp && has_exp && return combined_order |
110 | 91 | return 0 |
111 | 92 | end |
112 | 93 |
|
113 | | -function export_convergence_results(alg_name, test_problem, num_steps; kwargs...) |
114 | | - out_dict = Dict() |
115 | | - (; test_name) = test_problem |
116 | | - out_dict[string(test_name)] = Dict() |
117 | | - out_dict[string(test_name)][string(alg_name)] = Dict() |
118 | | - out_dict[string(test_name)]["args"] = (alg_name, test_problem, num_steps) |
119 | | - out_dict[string(test_name)]["kwargs"] = kwargs |
120 | | - compute_convergence!(out_dict, alg_name, test_problem, num_steps; kwargs...) |
121 | | - JLD2.save_object("convergence_$(alg_name)_$(test_problem.test_name).jld2", out_dict) |
122 | | -end |
123 | | - |
| 94 | +""" |
| 95 | + algorithm(algorithm_name, [linear_implicit]) |
124 | 96 |
|
125 | | -function compute_convergence!( |
126 | | - out_dict, |
127 | | - alg_name, |
| 97 | +Generates an appropriate `DistributedODEAlgorithm` from an `AbstractAlgorithmName`. |
| 98 | +For `IMEXAlgorithmNames`, `linear_implicit` must also be specified. One Newton |
| 99 | +iteration is used for linear implicit problems, and two iterations are used for |
| 100 | +nonlinear implicit problems. |
| 101 | +""" |
| 102 | +algorithm(algorithm_name, _) = algorithm(algorithm_name) |
| 103 | +algorithm(algorithm_name::ClimaTimeSteppers.ERKAlgorithmName) = ExplicitAlgorithm(algorithm_name) |
| 104 | +algorithm(algorithm_name::ClimaTimeSteppers.SSPKnoth) = |
| 105 | + ClimaTimeSteppers.RosenbrockAlgorithm(ClimaTimeSteppers.tableau(ClimaTimeSteppers.SSPKnoth())) |
| 106 | +algorithm(algorithm_name::ClimaTimeSteppers.IMEXARKAlgorithmName, linear_implicit) = |
| 107 | + IMEXAlgorithm(algorithm_name, NewtonsMethod(; max_iters = linear_implicit ? 1 : 2)) |
| 108 | + |
| 109 | +rms(array) = norm(array) / sqrt(length(array)) |
| 110 | +rms_error(u, t, sol) = rms(abs.(u .- sol(t))) |
| 111 | + |
| 112 | +function test_convergence!( |
| 113 | + convergence_results, |
| 114 | + algorithm_name, |
128 | 115 | test_case, |
129 | | - num_steps; |
130 | | - num_steps_scaling_factor = 10, |
131 | | - order_confidence_percent = 99, |
132 | | - super_convergence = (), |
| 116 | + default_num_steps; |
| 117 | + super_convergence_algorithm_names = (), |
| 118 | + num_steps_scaling_factor = 4, |
| 119 | + high_order_sample_shifts = 1, |
133 | 120 | numerical_reference_algorithm_name = nothing, |
134 | | - numerical_reference_num_steps = num_steps_scaling_factor^3 * num_steps, |
135 | | - full_history_algorithm_name = nothing, |
136 | | - average_function = array -> norm(array) / sqrt(length(array)), |
137 | | - average_function_str = "RMS", |
138 | | - only_endpoints = false, |
| 121 | + numerical_reference_num_steps = num_steps_scaling_factor^3 * default_num_steps, |
| 122 | + broken_tests = (), |
| 123 | + error_on_failure = true, |
139 | 124 | verbose = false, |
140 | 125 | ) |
141 | 126 | (; test_name, t_end, linear_implicit, analytic_sol) = test_case |
142 | 127 | prob = test_case.split_prob |
143 | | - FT = typeof(t_end) |
144 | | - default_dt = t_end / num_steps |
145 | | - key1 = string(test_name) |
146 | | - key2 = string(alg_name) |
147 | | - |
148 | | - algorithm(algorithm_name::ClimaTimeSteppers.ERKAlgorithmName) = ExplicitAlgorithm(algorithm_name) |
149 | | - algorithm(algorithm_name::ClimaTimeSteppers.SSPKnoth) = |
150 | | - ClimaTimeSteppers.RosenbrockAlgorithm(ClimaTimeSteppers.tableau(ClimaTimeSteppers.SSPKnoth())) |
151 | | - algorithm(algorithm_name::ClimaTimeSteppers.IMEXARKAlgorithmName) = |
152 | | - IMEXAlgorithm(algorithm_name, NewtonsMethod(; max_iters = linear_implicit ? 1 : 2)) |
153 | 128 |
|
| 129 | + default_dt = t_end / default_num_steps |
154 | 130 | ref_sol = if isnothing(numerical_reference_algorithm_name) |
155 | 131 | analytic_sol |
156 | 132 | else |
157 | | - ref_alg = algorithm(numerical_reference_algorithm_name) |
| 133 | + # TODO: Do not regenerate the reference solution for every algorithm!! |
158 | 134 | ref_alg_str = string(nameof(typeof(numerical_reference_algorithm_name))) |
| 135 | + ref_alg = algorithm(numerical_reference_algorithm_name, linear_implicit) |
159 | 136 | ref_dt = t_end / numerical_reference_num_steps |
160 | | - verbose && |
161 | | - @info "Generating numerical reference solution for $test_name with $ref_alg_str (dt = $ref_dt)..." |
162 | | - sol = solve(deepcopy(prob), ref_alg; dt = ref_dt, save_everystep = !only_endpoints) |
163 | | - out_dict[string(test_name)]["numerical_ref_sol"] = sol |
| 137 | + verbose && @info "Generating reference solution for $test_name with $ref_alg_str and dt = $ref_dt" |
| 138 | + solve(deepcopy(prob), ref_alg; dt = ref_dt, save_everystep = true) |
164 | 139 | end |
165 | | - |
166 | | - cur_avg_err(u, t) = average_function(abs.(u .- ref_sol(t))) |
167 | | - cur_avg_sol_and_err(u, t) = (average_function(u), average_function(abs.(u .- ref_sol(t)))) |
168 | | - |
169 | | - float_str(x) = @sprintf "%.4f" x |
170 | | - pow_str(x) = "10^{$(@sprintf "%.1f" log10(x))}" |
171 | | - function si_str(x) |
172 | | - if isnan(x) || x in (0, Inf, -Inf) |
173 | | - return string(x) |
174 | | - end |
175 | | - exponent = floor(Int, log10(x)) |
176 | | - mantissa = x / 10.0^exponent |
177 | | - return "$(float_str(mantissa)) \\times 10^{$exponent}" |
| 140 | + numerical_reference_info = if isnothing(numerical_reference_algorithm_name) |
| 141 | + nothing |
| 142 | + else |
| 143 | + ref_average_rms_error = rms(rms_error.(ref_sol.u, ref_sol.t, (analytic_sol,))) |
| 144 | + (ref_alg_str, ref_dt, ref_average_rms_error) |
178 | 145 | end |
179 | 146 |
|
180 | | - net_avg_sol_str = "\\textrm{$average_function_str}\\_\\textrm{solution}" |
181 | | - net_avg_err_str = "\\textrm{$average_function_str}\\_\\textrm{error}" |
182 | | - cur_avg_sol_str = "\\textrm{current}\\_$net_avg_sol_str" |
183 | | - cur_avg_err_str = "\\textrm{current}\\_$net_avg_err_str" |
184 | | - |
185 | | - linestyles = (:solid, :dash, :dot, :dashdot, :dashdotdot) |
186 | | - marker_kwargs = (; markershape = :circle, markeralpha = 0.5, markerstrokewidth = 0) |
187 | | - plot_kwargs = (; |
188 | | - legendposition = :outerright, |
189 | | - legendtitlefontpointsize = 8, |
190 | | - palette = :glasbey_bw_minc_20_maxl_70_n256, |
191 | | - size = (1000, 2000), # size in px |
192 | | - leftmargin = 60Plots.px, |
193 | | - rightmargin = 0Plots.px, |
194 | | - topmargin = 0Plots.px, |
195 | | - bottommargin = 30Plots.px, |
196 | | - ) |
197 | | - |
198 | | - plot1_dts = t_end ./ round.(Int, num_steps .* num_steps_scaling_factor .^ (-1:0.5:1)) |
199 | | - plot1 = Plots.plot(; |
200 | | - title = "Convergence Orders", |
201 | | - xaxis = (latexstring("dt"), :log10), |
202 | | - yaxis = (latexstring(net_avg_err_str), :log10), |
203 | | - legendtitle = "Convergence Order ($order_confidence_percent% CI)", |
204 | | - plot_kwargs..., |
205 | | - ) |
206 | | - |
207 | | - plot2b_min = typemax(FT) |
208 | | - plot2b_max = typemin(FT) |
209 | | - plot2a = Plots.plot(; |
210 | | - title = latexstring("Solutions with \$dt = $(pow_str(default_dt))\$"), |
211 | | - xaxis = (latexstring("t"),), |
212 | | - yaxis = (latexstring(cur_avg_sol_str),), |
213 | | - legendtitle = latexstring(net_avg_sol_str), |
214 | | - plot_kwargs..., |
215 | | - ) |
216 | | - plot2b = Plots.plot(; |
217 | | - title = latexstring("Errors with \$dt = $(pow_str(default_dt))\$"), |
218 | | - xaxis = (latexstring("t"),), |
219 | | - yaxis = (latexstring(cur_avg_err_str), :log10), |
220 | | - legendtitle = latexstring(net_avg_err_str), |
221 | | - plot_kwargs..., |
222 | | - ) |
223 | | - |
224 | | - cur_avg_errs_dict = Dict() |
225 | | - # for algorithm_name in algorithm_names |
226 | | - algorithm_name = alg_name |
227 | | - alg = algorithm(algorithm_name) |
228 | 147 | alg_str = string(nameof(typeof(algorithm_name))) |
229 | | - predicted_order = predicted_convergence_order(algorithm_name, prob.f) |
230 | | - linestyle = linestyles[(predicted_order - 1) % length(linestyles) + 1] |
| 148 | + alg = algorithm(algorithm_name, linear_implicit) |
| 149 | + verbose && @info "Testing convergence of $alg_str for $test_name" |
231 | 150 |
|
232 | | - verbose && @info "Running $test_name with $alg_str..." |
233 | | - @info "Using plot1_dts=$plot1_dts" |
234 | | - plot1_net_avg_errs = map(plot1_dts) do plot1_dt |
235 | | - plot1_sol = solve(deepcopy(prob), alg; dt = plot1_dt, save_everystep = !only_endpoints) |
236 | | - (; u, t) = plot1_sol |
237 | | - cur_avg_errs = cur_avg_err.(u, t) |
238 | | - cur_avg_errs_dict[plot1_dt] = cur_avg_errs |
239 | | - verbose && @info "RMS_error(dt = $plot1_dt) = $(average_function(cur_avg_errs))" |
240 | | - return average_function(cur_avg_errs) |
| 151 | + predicted_order = predicted_convergence_order(algorithm_name, prob.f) |
| 152 | + predicted_super_convergence = algorithm_name in super_convergence_algorithm_names |
| 153 | + num_steps_powers = (-1:0.5:1) .- high_order_sample_shifts * max(0, predicted_order - 3) / 2 |
| 154 | + sampled_num_steps = default_num_steps .* num_steps_scaling_factor .^ num_steps_powers |
| 155 | + sampled_dts = t_end ./ round.(Int, sampled_num_steps) |
| 156 | + average_rms_errors = map(sampled_dts) do dt |
| 157 | + sol = solve(deepcopy(prob), alg; dt = dt, save_everystep = true) |
| 158 | + rms(rms_error.(sol.u, sol.t, (ref_sol,))) |
241 | 159 | end |
242 | | - out_dict[key1][key2]["cur_avg_errs_dict"] = cur_avg_errs_dict |
243 | | - order, order_uncertainty = convergence_order(plot1_dts, plot1_net_avg_errs, order_confidence_percent / 100) |
244 | | - order_str = "$(float_str(order)) \\pm $(float_str(order_uncertainty))" |
245 | | - if algorithm_name in super_convergence |
246 | | - predicted_order += 1 |
247 | | - plot1_label = "$alg_str: \$$order_str\\ \\ \\ \\textbf{\\textit{SC}}\$" |
| 160 | + verbose && @info "Sampled timesteps = $sampled_dts" |
| 161 | + verbose && @info "Average RMS errors = $average_rms_errors" |
| 162 | + |
| 163 | + # Compute a 99% confidence interval for the convergence order |
| 164 | + order, order_uncertainty = convergence_order(sampled_dts, average_rms_errors, 0.99) |
| 165 | + verbose && @info "Convergence order = $order ± $order_uncertainty" |
| 166 | + actual_predicted_order = predicted_order + Bool(predicted_super_convergence) |
| 167 | + convergence_test_error = if isnan(order) |
| 168 | + "Timestepper does not converge for $alg_str ($test_name)" |
| 169 | + elseif abs(order - actual_predicted_order) > order_uncertainty |
| 170 | + "Predicted order outside error bars for $alg_str ($test_name)" |
| 171 | + elseif order_uncertainty > actual_predicted_order / 10 |
| 172 | + "Order uncertainty too large for $alg_str ($test_name)" |
248 | 173 | else |
249 | | - plot1_label = "$alg_str: \$$order_str\$" |
250 | | - end |
251 | | - verbose && @info "Order = $order ± $order_uncertainty" |
252 | | - if abs(order - predicted_order) > order_uncertainty |
253 | | - @warn "Predicted order outside error bars for $alg_str ($test_name)" |
254 | | - end |
255 | | - if order_uncertainty > predicted_order / 10 |
256 | | - @warn "Order uncertainty too large for $alg_str ($test_name)" |
| 174 | + nothing |
257 | 175 | end |
258 | | - |
259 | | - # Remove all 0s from plot2_cur_avg_errs because they cannot be plotted on a |
260 | | - # logarithmic scale. Record the extrema of plot2_cur_avg_errs to set ylim. |
261 | | - plot2_data = solve(deepcopy(prob), alg; dt = default_dt, save_everystep = true) |
262 | | - if any(isnan, plot2_data) |
263 | | - error("NaN found in plot2_data in problem $(test_name)") |
| 176 | + if isnothing(convergence_test_error) |
| 177 | + @assert !(algorithm_name in broken_tests) |
| 178 | + elseif error_on_failure && !(algorithm_name in broken_tests) |
| 179 | + error(convergence_test_error) |
| 180 | + else |
| 181 | + @warn convergence_test_error |
264 | 182 | end |
265 | | - (; u, t) = plot2_data |
266 | | - cur_sols_and_errs = cur_avg_sol_and_err.(u, t) |
267 | | - out_dict[key1][key2]["plot2_data"] = (; u = cur_sols_and_errs, t) |
268 | 183 |
|
269 | | - if !isnothing(full_history_algorithm_name) |
270 | | - history_alg = algorithm(full_history_algorithm_name) |
271 | | - history_alg_name = string(nameof(typeof(full_history_algorithm_name))) |
272 | | - history_solve_sol = solve(deepcopy(prob), history_alg; dt = default_dt, save_everystep = true) |
273 | | - (; u, t) = history_solve_sol |
274 | | - history_solve_results = map(X -> X[1] .- ref_sol(X[2]), zip(u, t)) |
275 | | - history_solve_results = (; u = history_solve_results, t) |
276 | | - out_dict[key1][key2]["history_solve_results"] = history_solve_results |
277 | | - end |
278 | | - return out_dict |
| 184 | + default_dt_sol = solve(deepcopy(prob), alg; dt = default_dt, save_everystep = true) |
| 185 | + default_dt_times = default_dt_sol.t |
| 186 | + default_dt_solutions = rms.(default_dt_sol.u) |
| 187 | + default_dt_errors = rms_error.(default_dt_sol.u, default_dt_sol.t, (ref_sol,)) |
| 188 | + |
| 189 | + convergence_results[test_name] = Dict() |
| 190 | + convergence_results[test_name]["default_dt"] = default_dt |
| 191 | + convergence_results[test_name]["numerical_reference_info"] = numerical_reference_info |
| 192 | + convergence_results[test_name]["all_alg_results"] = Dict() |
| 193 | + convergence_results[test_name]["all_alg_results"][alg_str] = Dict() |
| 194 | + alg_results = convergence_results[test_name]["all_alg_results"][alg_str] |
| 195 | + alg_results["predicted_order"] = predicted_order |
| 196 | + alg_results["predicted_super_convergence"] = predicted_super_convergence |
| 197 | + alg_results["sampled_dts"] = sampled_dts |
| 198 | + alg_results["average_rms_errors"] = average_rms_errors |
| 199 | + alg_results["default_dt_times"] = default_dt_times |
| 200 | + alg_results["default_dt_solutions"] = default_dt_solutions |
| 201 | + alg_results["default_dt_errors"] = default_dt_errors |
| 202 | + return convergence_results |
279 | 203 | end |
280 | 204 |
|
281 | | -function test_unconstrained_vs_ssp_without_limiters(alg_name, test_case, num_steps) |
| 205 | +function test_unconstrained_vs_ssp_without_limiters(algorithm_name, test_case, num_steps) |
282 | 206 | prob = test_case.split_prob |
283 | 207 | dt = test_case.t_end / num_steps |
284 | 208 | newtons_method = NewtonsMethod(; max_iters = test_case.linear_implicit ? 1 : 2) |
285 | | - algorithm = IMEXAlgorithm(alg_name, newtons_method) |
286 | | - reference_algorithm = IMEXAlgorithm(alg_name, newtons_method, Unconstrained()) |
| 209 | + algorithm = IMEXAlgorithm(algorithm_name, newtons_method) |
| 210 | + reference_algorithm = IMEXAlgorithm(algorithm_name, newtons_method, Unconstrained()) |
287 | 211 | solution = solve(deepcopy(prob), algorithm; dt).u[end] |
288 | 212 | reference_solution = solve(deepcopy(prob), reference_algorithm; dt).u[end] |
289 | | - if norm(solution .- reference_solution) / norm(reference_solution) > 30 * eps(Float64) |
290 | | - alg_str = string(nameof(typeof(alg_name))) |
291 | | - @warn "Unconstrained and SSP versions of $alg_str \ |
292 | | - give different results for $(test_case.test_name)" |
| 213 | + relative_error = norm(solution .- reference_solution) / norm(reference_solution) |
| 214 | + if relative_error > 100 * eps(Float64) |
| 215 | + error("Unconstrained and SSP versions of $algorithm_name give \ |
| 216 | + different results for $(test_case.test_name): relative \ |
| 217 | + error = $(round(Int, relative_error / eps(Float64))) * eps") |
293 | 218 | end |
294 | 219 | end |
0 commit comments