Skip to content

Commit 1bca376

Browse files
Fixes
1 parent 0d6f225 commit 1bca376

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

docs/src/plotting_utils.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ function convergence_order(dts, errs, confidence)
9191
return order, order_uncertainty
9292
end
9393

94+
function make_saving_callback(cb, u, t, integrator)
95+
DECB = CTS.DiffEqCallbacks
96+
savevalType = typeof(cb(u, t, integrator))
97+
return DECB.SavingCallback(cb, DECB.SavedValues(typeof(t), savevalType))
98+
end
99+
94100
function verify_convergence(
95101
title,
96102
algorithm_names,
@@ -127,8 +133,8 @@ function verify_convergence(
127133
solve(deepcopy(prob), ref_alg; dt = ref_dt, save_everystep = !only_endpoints)
128134
end
129135

130-
cur_avg_err(u, t) = average_function(abs.(u .- ref_sol(t)))
131-
cur_avg_sol_and_err(u, t) = (average_function(u), average_function(abs.(u .- ref_sol(t))))
136+
cur_avg_err(u, t, integrator) = average_function(abs.(u .- ref_sol(t)))
137+
cur_avg_sol_and_err(u, t, integrator) = (average_function(u), average_function(abs.(u .- ref_sol(t))))
132138

133139
float_str(x) = @sprintf "%.4f" x
134140
pow_str(x) = "10^{$(@sprintf "%.1f" log10(x))}"
@@ -182,6 +188,9 @@ function verify_convergence(
182188
plot_kwargs...,
183189
)
184190

191+
scb_cur_avg_err = make_saving_callback(cur_avg_err, prob.u0, t_end, nothing)
192+
scb_cur_avg_sol_and_err = make_saving_callback(cur_avg_sol_and_err, prob.u0, t_end, nothing)
193+
185194
for algorithm_name in algorithm_names
186195
alg = algorithm(algorithm_name)
187196
alg_str = string(nameof(typeof(algorithm_name)))
@@ -196,8 +205,7 @@ function verify_convergence(
196205
alg;
197206
dt = plot1_dt,
198207
save_everystep = !only_endpoints,
199-
save_func = cur_avg_err,
200-
kwargshandle = DiffEqBase.KeywordArgSilent,
208+
callback = scb_cur_avg_err,
201209
).u
202210
verbose && @info "RMS_error(dt = $plot1_dt) = $(average_function(cur_avg_errs))"
203211
return average_function(cur_avg_errs)
@@ -226,8 +234,7 @@ function verify_convergence(
226234
alg;
227235
dt = default_dt,
228236
save_everystep = !only_endpoints,
229-
save_func = cur_avg_sol_and_err,
230-
kwargshandle = DiffEqBase.KeywordArgSilent,
237+
callback = scb_cur_avg_sol_and_err,
231238
)
232239
plot2_ts = plot2_values.t
233240
plot2_cur_avg_sols = first.(plot2_values.u)
@@ -267,8 +274,7 @@ function verify_convergence(
267274
history_alg;
268275
dt = default_dt,
269276
save_everystep = !only_endpoints,
270-
save_func = (u, t) -> u .- ref_sol(t),
271-
kwargshandle = DiffEqBase.KeywordArgSilent,
277+
callback = make_saving_callback((u, t, integrator) -> u .- ref_sol(t), prob.u0, t_end, nothing),
272278
)
273279
history_array = hcat(history_solve_results.u...)
274280
history_plot_title = "Errors for $history_alg_name with \$dt = $(pow_str(default_dt))\$"

src/integrators.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ mutable struct DistributedODEIntegrator{
3737
# DiffEqBase.initialize! and DiffEqBase.finalize!
3838
cache::cacheType
3939
sol::solType
40+
tdir::tType # see https://docs.sciml.ai/DiffEqCallbacks/stable/output_saving/#DiffEqCallbacks.SavingCallback
4041
end
4142

4243
# helper function for setting up min/max heaps for tstops and saveat
@@ -64,6 +65,8 @@ function tstops_and_saveat_heaps(t0, tf, tstops, saveat)
6465
return tstops, saveat
6566
end
6667

68+
compute_tdir(ts) = ts[1] > ts[end] ? sign(ts[end] - ts[1]) : eltype(ts)(1)
69+
6770
# called by DiffEqBase.init and DiffEqBase.solve
6871
function DiffEqBase.__init(
6972
prob::DiffEqBase.AbstractODEProblem,
@@ -75,9 +78,10 @@ function DiffEqBase.__init(
7578
save_everystep = false,
7679
callback = nothing,
7780
advance_to_tstop = false,
78-
save_func = (u, t) -> copy(u), # custom kwarg
79-
dtchangeable = true, # custom kwarg
80-
stepstop = -1, # custom kwarg
81+
save_func = (u, t) -> copy(u), # custom kwarg
82+
dtchangeable = true, # custom kwarg
83+
stepstop = -1, # custom kwarg
84+
tdir = compute_tdir(prob.tspan), #
8185
kwargs...,
8286
)
8387
(; u0, p) = prob
@@ -116,6 +120,7 @@ function DiffEqBase.__init(
116120
false,
117121
init_cache(prob, alg; dt, kwargs...),
118122
sol,
123+
tdir,
119124
)
120125
if prob.f isa ClimaODEFunction
121126
(; post_explicit!) = prob.f

0 commit comments

Comments
 (0)