Skip to content

Commit 1f6e259

Browse files
Merge pull request #358 from CliMA/ck/plotting
Fix convergence plots
2 parents 92478ea + 0d171e4 commit 1f6e259

File tree

2 files changed

+27
-37
lines changed

2 files changed

+27
-37
lines changed

docs/src/dev/compute_convergence.jl

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ function compute_convergence!(
163163
out_dict[string(test_name)]["numerical_ref_sol"] = sol
164164
end
165165

166-
cur_avg_err(u, t, integrator) = average_function(abs.(u .- ref_sol(t)))
167-
cur_avg_sol_and_err(u, t, integrator) = (average_function(u), average_function(abs.(u .- ref_sol(t))))
166+
cur_avg_err(u, t) = average_function(abs.(u .- ref_sol(t)))
167+
cur_avg_sol_and_err(u, t) = (average_function(u), average_function(abs.(u .- ref_sol(t))))
168168

169169
float_str(x) = @sprintf "%.4f" x
170170
pow_str(x) = "10^{$(@sprintf "%.1f" log10(x))}"
@@ -221,9 +221,6 @@ function compute_convergence!(
221221
plot_kwargs...,
222222
)
223223

224-
scb_cur_avg_err = make_saving_callback(cur_avg_err, prob.u0, t_end, nothing)
225-
scb_cur_avg_sol_and_err = make_saving_callback(cur_avg_sol_and_err, prob.u0, t_end, nothing)
226-
227224
cur_avg_errs_dict = Dict()
228225
# for algorithm_name in algorithm_names
229226
algorithm_name = alg_name
@@ -235,8 +232,9 @@ function compute_convergence!(
235232
verbose && @info "Running $test_name with $alg_str..."
236233
@info "Using plot1_dts=$plot1_dts"
237234
plot1_net_avg_errs = map(plot1_dts) do plot1_dt
238-
cur_avg_errs =
239-
solve(deepcopy(prob), alg; dt = plot1_dt, save_everystep = !only_endpoints, callback = scb_cur_avg_err).u
235+
plot1_sol = solve(deepcopy(prob), alg; dt = plot1_dt, save_everystep = !only_endpoints)
236+
(; u, t) = plot1_sol
237+
cur_avg_errs = cur_avg_err.(u, t)
240238
cur_avg_errs_dict[plot1_dt] = cur_avg_errs
241239
verbose && @info "RMS_error(dt = $plot1_dt) = $(average_function(cur_avg_errs))"
242240
return average_function(cur_avg_errs)
@@ -260,29 +258,21 @@ function compute_convergence!(
260258

261259
# Remove all 0s from plot2_cur_avg_errs because they cannot be plotted on a
262260
# logarithmic scale. Record the extrema of plot2_cur_avg_errs to set ylim.
263-
plot2_values = solve(
264-
deepcopy(prob),
265-
alg;
266-
dt = default_dt,
267-
save_everystep = !only_endpoints,
268-
callback = scb_cur_avg_sol_and_err,
269-
)
270-
if any(isnan, plot2_values)
271-
error("NaN found in plot2_values in problem $(test_name)")
261+
plot2_data = solve(deepcopy(prob), alg; dt = default_dt, save_everystep = true)
262+
if any(isnan, plot2_data)
263+
error("NaN found in plot2_data in problem $(test_name)")
272264
end
273-
out_dict[key1][key2]["plot2_values"] = plot2_values
274-
265+
(; u, t) = plot2_data
266+
cur_sols_and_errs = cur_avg_sol_and_err.(u, t)
267+
out_dict[key1][key2]["plot2_data"] = (; u = cur_sols_and_errs, t)
275268

276269
if !isnothing(full_history_algorithm_name)
277270
history_alg = algorithm(full_history_algorithm_name)
278271
history_alg_name = string(nameof(typeof(full_history_algorithm_name)))
279-
history_solve_results = solve(
280-
deepcopy(prob),
281-
history_alg;
282-
dt = default_dt,
283-
save_everystep = !only_endpoints,
284-
callback = make_saving_callback((u, t, integrator) -> u .- ref_sol(t), prob.u0, t_end, nothing),
285-
)
272+
history_solve_sol = solve(deepcopy(prob), history_alg; dt = default_dt, save_everystep = true)
273+
(; u, t) = history_solve_sol
274+
history_solve_results = map(X -> X[1] .- ref_sol(X[2]), zip(u, t))
275+
history_solve_results = (; u = history_solve_results, t)
286276
out_dict[key1][key2]["history_solve_results"] = history_solve_results
287277
end
288278
return out_dict

docs/src/dev/summarize_convergence.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ function algorithm_names_by_availability(out_dict, test_name, algorithm_names_al
7777
end
7878

7979
function summarize_convergence(
80-
out_dict,
80+
out_dict_test,
8181
alg_name,
8282
test_case,
8383
num_steps;
@@ -100,11 +100,11 @@ function summarize_convergence(
100100

101101
keep_alg = true
102102
plot1_dts = t_end ./ round.(Int, num_steps .* num_steps_scaling_factor .^ (-1:0.5:1))
103-
algorithm_names = algorithm_names_by_availability(out_dict, test_name, algorithm_names_all, plot1_dts)
103+
algorithm_names = algorithm_names_by_availability(out_dict_test, test_name, algorithm_names_all, plot1_dts)
104104
@show algorithm_names
105105

106-
# out_dict = Dict()
107-
# out_dict[key2] = Dict()
106+
# out_dict_test = Dict()
107+
# out_dict_test[key2] = Dict()
108108

109109
prob = test_case.split_prob
110110
FT = typeof(t_end)
@@ -124,7 +124,7 @@ function summarize_convergence(
124124
ref_dt = t_end / numerical_reference_num_steps
125125
verbose &&
126126
@info "Generating numerical reference solution for $test_name with $ref_alg_str (dt = $ref_dt)..."
127-
out_dict["numerical_ref_sol"] # solve(deepcopy(prob), ref_alg; dt = ref_dt, save_everystep = !only_endpoints)
127+
out_dict_test["numerical_ref_sol"] # solve(deepcopy(prob), ref_alg; dt = ref_dt, save_everystep = !only_endpoints)
128128
end
129129

130130
cur_avg_err(u, t, integrator) = average_function(abs.(u .- ref_sol(t)))
@@ -188,7 +188,7 @@ function summarize_convergence(
188188
scb_cur_avg_sol_and_err = make_saving_callback(cur_avg_sol_and_err, prob.u0, t_end, nothing)
189189

190190
for algorithm_name in algorithm_names
191-
cur_avg_errs_dict = out_dict[string(algorithm_name)]["cur_avg_errs_dict"]
191+
cur_avg_errs_dict = out_dict_test[string(algorithm_name)]["cur_avg_errs_dict"]
192192
alg = algorithm(algorithm_name)
193193
alg_str = string(nameof(typeof(algorithm_name)))
194194
predicted_order = predicted_convergence_order(algorithm_name, prob.f)
@@ -226,17 +226,17 @@ function summarize_convergence(
226226

227227
# Remove all 0s from plot2_cur_avg_errs because they cannot be plotted on a
228228
# logarithmic scale. Record the extrema of plot2_cur_avg_errs to set ylim.
229-
plot2_values = out_dict[string(algorithm_name)]["plot2_values"]
230-
# plot2_values = solve(
229+
plot2_data = out_dict_test[string(algorithm_name)]["plot2_data"]
230+
# plot2_data = solve(
231231
# deepcopy(prob),
232232
# alg;
233233
# dt = default_dt,
234234
# save_everystep = !only_endpoints,
235235
# callback = scb_cur_avg_sol_and_err,
236236
# )
237-
plot2_ts = plot2_values.t
238-
plot2_cur_avg_sols = first.(plot2_values.u)
239-
plot2_cur_avg_errs = last.(plot2_values.u)
237+
plot2_ts = plot2_data.t
238+
plot2_cur_avg_sols = first.(plot2_data.u)
239+
plot2_cur_avg_errs = last.(plot2_data.u)
240240
plot2b_min = min(plot2b_min, minimum(x -> x == 0 ? typemax(FT) : x, plot2_cur_avg_errs))
241241
plot2b_max = max(plot2b_max, maximum(plot2_cur_avg_errs))
242242
plot2_cur_avg_errs .= max.(plot2_cur_avg_errs, eps(FT(0)))
@@ -267,7 +267,7 @@ function summarize_convergence(
267267
if !isnothing(full_history_algorithm_name)
268268
history_alg = algorithm(full_history_algorithm_name)
269269
history_alg_name = string(nameof(typeof(full_history_algorithm_name)))
270-
history_solve_results = out_dict[history_alg_name]["history_solve_results"]
270+
history_solve_results = out_dict_test[history_alg_name]["history_solve_results"]
271271
# history_solve_results = solve(
272272
# deepcopy(prob),
273273
# history_alg;

0 commit comments

Comments
 (0)