Skip to content

Commit d01dd9b

Browse files
Merge pull request #142 from SciML/wps
Make work-precision benchmarks more robust to failures
2 parents c085b60 + 026bd4b commit d01dd9b

File tree

5 files changed

+66
-59
lines changed

5 files changed

+66
-59
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqDevTools"
22
uuid = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "2.44.3"
4+
version = "2.44.4"
55

66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
@@ -28,7 +28,7 @@ Distributed = "1.9"
2828
LinearAlgebra = "1.9"
2929
Logging = "1.9"
3030
NLsolve = "4.2"
31-
NonlinearSolve = "1, 2"
31+
NonlinearSolve = "3.13"
3232
ODEProblemLibrary = "0.1"
3333
OrdinaryDiffEq = "6"
3434
ParameterizedFunctions = "5"

src/benchmark.jl

Lines changed: 60 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -217,69 +217,75 @@ function WorkPrecision(prob, alg, abstols, reltols, dts = nothing;
217217

218218
stats[i] = sol.stats
219219

220-
if haskey(kwargs, :prob_choice)
221-
cur_appxsol = appxsol[kwargs[:prob_choice]]
222-
elseif prob isa AbstractArray
223-
cur_appxsol = appxsol[1]
224-
else
225-
cur_appxsol = appxsol
226-
end
227-
228-
if cur_appxsol !== nothing
229-
errsol = appxtrue(sol, cur_appxsol)
230-
errors[i] = Dict{Symbol, Float64}()
231-
for err in keys(errsol.errors)
232-
errors[i][err] = mean(errsol.errors[err])
233-
end
234-
else
235-
errors[i] = Dict{Symbol, Float64}()
236-
for err in keys(sol.errors)
237-
errors[i][err] = mean(sol.errors[err])
220+
if SciMLBase.successful_retcode(sol)
221+
if haskey(kwargs, :prob_choice)
222+
cur_appxsol = appxsol[kwargs[:prob_choice]]
223+
elseif prob isa AbstractArray
224+
cur_appxsol = appxsol[1]
225+
else
226+
cur_appxsol = appxsol
238227
end
239-
end
240-
241-
benchmark_f = let dts = dts, _prob = _prob, alg = alg, sol = sol,
242-
abstols = abstols, reltols = reltols, kwargs = kwargs
243228

244-
if dts === nothing
245-
if _prob isa DAEProblem
246-
() -> @elapsed solve(_prob, alg, sol.u, sol.t;
247-
abstol = abstols[i],
248-
reltol = reltols[i],
249-
timeseries_errors = false,
250-
dense_errors = false, kwargs...)
251-
else
252-
() -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k;
253-
abstol = abstols[i],
254-
reltol = reltols[i],
255-
timeseries_errors = false,
256-
dense_errors = false, kwargs...)
229+
if cur_appxsol !== nothing
230+
errsol = appxtrue(sol, cur_appxsol)
231+
errors[i] = Dict{Symbol, Float64}()
232+
for err in keys(errsol.errors)
233+
errors[i][err] = mean(errsol.errors[err])
257234
end
258235
else
259-
if _prob isa DAEProblem
260-
() -> @elapsed solve(_prob, alg, sol.u, sol.t;
261-
abstol = abstols[i],
262-
reltol = reltols[i],
263-
dt = dts[i],
264-
timeseries_errors = false,
265-
dense_errors = false, kwargs...)
236+
errors[i] = Dict{Symbol, Float64}()
237+
for err in keys(sol.errors)
238+
errors[i][err] = mean(sol.errors[err])
239+
end
240+
end
241+
242+
benchmark_f = let dts = dts, _prob = _prob, alg = alg, sol = sol,
243+
abstols = abstols, reltols = reltols, kwargs = kwargs
244+
245+
if dts === nothing
246+
if _prob isa DAEProblem
247+
() -> @elapsed solve(_prob, alg, sol.u, sol.t;
248+
abstol = abstols[i],
249+
reltol = reltols[i],
250+
timeseries_errors = false,
251+
dense_errors = false, kwargs...)
252+
else
253+
() -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k;
254+
abstol = abstols[i],
255+
reltol = reltols[i],
256+
timeseries_errors = false,
257+
dense_errors = false, kwargs...)
258+
end
266259
else
267-
() -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k;
268-
abstol = abstols[i],
269-
reltol = reltols[i],
270-
dt = dts[i],
271-
timeseries_errors = false,
272-
dense_errors = false, kwargs...)
260+
if _prob isa DAEProblem
261+
() -> @elapsed solve(_prob, alg, sol.u, sol.t;
262+
abstol = abstols[i],
263+
reltol = reltols[i],
264+
dt = dts[i],
265+
timeseries_errors = false,
266+
dense_errors = false, kwargs...)
267+
else
268+
() -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k;
269+
abstol = abstols[i],
270+
reltol = reltols[i],
271+
dt = dts[i],
272+
timeseries_errors = false,
273+
dense_errors = false, kwargs...)
274+
end
273275
end
274276
end
275-
end
276-
benchmark_f() # pre-compile
277+
benchmark_f() # pre-compile
277278

278-
b_t = benchmark_f()
279-
if b_t > seconds
280-
times[i] = b_t
279+
b_t = benchmark_f()
280+
if b_t > seconds
281+
times[i] = b_t
282+
else
283+
times[i] = mapreduce(i -> benchmark_f(), min, 2:numruns; init = b_t)
284+
end
281285
else
282-
times[i] = mapreduce(i -> benchmark_f(), min, 2:numruns; init = b_t)
286+
# Unsuccessful retcode, give NaN time
287+
errors[i] = Dict(:l∞ => NaN, :L2 => NaN, :final => NaN, :l2 => NaN, :L∞ => NaN)
288+
times[i] = NaN
283289
end
284290
end
285291
end

src/plotrecipes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ end
110110
ys = [get_val_from_wp(wp, y) for wp in wp_set.wps]
111111
xguide --> key_to_label(x)
112112
yguide --> key_to_label(y)
113+
legend --> :outerright
113114
label --> reshape(wp_set.names, 1, length(wp_set))
114115
return xs, ys
115116
elseif view == :dt_convergence

src/test_solution.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@ function appxtrue(sim::EnsembleSolution, appx_setup; kwargs...)
117117
for i in eachindex(sim)
118118
prob = sim[i].prob
119119
prob2 = SDEProblem(prob.f, prob.g, prob.u0, prob.tspan,
120-
noise = NoiseWrapper(sim[i].W))
120+
noise = NoiseWrapper(sim.u[i].W))
121121
true_sol = solve(prob2, appx_setup[:alg]; appx_setup...)
122-
_new_sols[i] = appxtrue(sim[i], true_sol)
122+
_new_sols[i] = appxtrue(sim.u[i], true_sol)
123123
end
124124
new_sols = convert(Vector{typeof(_new_sols[1])}, _new_sols)
125125
calculate_ensemble_errors(new_sols; converged = sim.converged,

test/benchmark_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ setups = [Dict(:alg => RK4()); Dict(:alg => Euler()); Dict(:alg => BS3());
2323
t1 = @elapsed sol = solve(prob, RK4(), dt = 1 / 2^(4))
2424
t2 = @elapsed sol2 = solve(prob, setups[1][:alg], dt = 1 / 2^(4))
2525

26-
@test (sol2[end] == sol[end])
26+
@test (sol2.u[end] == sol.u[end])
2727

2828
test_sol_2Dlinear = TestSolution(
2929
solve(prob_ode_2Dlinear, Vern7(), abstol = 1 / 10^14, reltol = 1 / 10^14))

0 commit comments

Comments
 (0)