@@ -163,8 +163,8 @@ function compute_convergence!(
163163 out_dict[string (test_name)][" numerical_ref_sol" ] = sol
164164 end
165165
166- cur_avg_err (u, t, integrator ) = average_function (abs .(u .- ref_sol (t)))
167- cur_avg_sol_and_err (u, t, integrator ) = (average_function (u), average_function (abs .(u .- ref_sol (t))))
166+ cur_avg_err (u, t) = average_function (abs .(u .- ref_sol (t)))
167+ cur_avg_sol_and_err (u, t) = (average_function (u), average_function (abs .(u .- ref_sol (t))))
168168
169169 float_str (x) = @sprintf " %.4f" x
170170 pow_str (x) = " 10^{$(@sprintf " %.1f" log10 (x)) }"
@@ -221,9 +221,6 @@ function compute_convergence!(
221221 plot_kwargs... ,
222222 )
223223
224- scb_cur_avg_err = make_saving_callback (cur_avg_err, prob. u0, t_end, nothing )
225- scb_cur_avg_sol_and_err = make_saving_callback (cur_avg_sol_and_err, prob. u0, t_end, nothing )
226-
227224 cur_avg_errs_dict = Dict ()
228225 # for algorithm_name in algorithm_names
229226 algorithm_name = alg_name
@@ -235,8 +232,9 @@ function compute_convergence!(
235232 verbose && @info " Running $test_name with $alg_str ..."
236233 @info " Using plot1_dts=$plot1_dts "
237234 plot1_net_avg_errs = map (plot1_dts) do plot1_dt
238- cur_avg_errs =
239- solve (deepcopy (prob), alg; dt = plot1_dt, save_everystep = ! only_endpoints, callback = scb_cur_avg_err). u
235+ plot1_sol = solve (deepcopy (prob), alg; dt = plot1_dt, save_everystep = ! only_endpoints)
236+ (; u, t) = plot1_sol
237+ cur_avg_errs = cur_avg_err .(u, t)
240238 cur_avg_errs_dict[plot1_dt] = cur_avg_errs
241239 verbose && @info " RMS_error(dt = $plot1_dt ) = $(average_function (cur_avg_errs)) "
242240 return average_function (cur_avg_errs)
@@ -260,29 +258,21 @@ function compute_convergence!(
260258
261259 # Remove all 0s from plot2_cur_avg_errs because they cannot be plotted on a
262260 # logarithmic scale. Record the extrema of plot2_cur_avg_errs to set ylim.
263- plot2_values = solve (
264- deepcopy (prob),
265- alg;
266- dt = default_dt,
267- save_everystep = ! only_endpoints,
268- callback = scb_cur_avg_sol_and_err,
269- )
270- if any (isnan, plot2_values)
271- error (" NaN found in plot2_values in problem $(test_name) " )
261+ plot2_data = solve (deepcopy (prob), alg; dt = default_dt, save_everystep = true )
262+ if any (isnan, plot2_data)
263+ error (" NaN found in plot2_data in problem $(test_name) " )
272264 end
273- out_dict[key1][key2][" plot2_values" ] = plot2_values
274-
265+ (; u, t) = plot2_data
266+ cur_sols_and_errs = cur_avg_sol_and_err .(u, t)
267+ out_dict[key1][key2][" plot2_data" ] = (; u = cur_sols_and_errs, t)
275268
276269 if ! isnothing (full_history_algorithm_name)
277270 history_alg = algorithm (full_history_algorithm_name)
278271 history_alg_name = string (nameof (typeof (full_history_algorithm_name)))
279- history_solve_results = solve (
280- deepcopy (prob),
281- history_alg;
282- dt = default_dt,
283- save_everystep = ! only_endpoints,
284- callback = make_saving_callback ((u, t, integrator) -> u .- ref_sol (t), prob. u0, t_end, nothing ),
285- )
272+ history_solve_sol = solve (deepcopy (prob), history_alg; dt = default_dt, save_everystep = true )
273+ (; u, t) = history_solve_sol
274+ history_solve_results = map (X -> X[1 ] .- ref_sol (X[2 ]), zip (u, t))
275+ history_solve_results = (; u = history_solve_results, t)
286276 out_dict[key1][key2][" history_solve_results" ] = history_solve_results
287277 end
288278 return out_dict
0 commit comments