Skip to content

Commit 7b77ba4

Browse files
Merge pull request #146 from ErikQQY/qqy/handle_unsuccessful_solve
Reject non successful retcodes in BVP benchmarks
2 parents d121722 + bc8727c commit 7b77ba4

File tree

2 files changed

+62
-56
lines changed

2 files changed

+62
-56
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqDevTools"
22
uuid = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "2.45.0"
4+
version = "2.45.1"
55

66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

src/benchmark.jl

Lines changed: 61 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -330,70 +330,76 @@ function WorkPrecision(prob::AbstractBVProblem, alg, abstols, reltols, dts = not
330330
end
331331

332332
stats[i] = sol.stats
333-
334-
if haskey(kwargs, :prob_choice)
335-
cur_appxsol = appxsol[kwargs[:prob_choice]]
336-
elseif prob isa AbstractArray
337-
cur_appxsol = appxsol[1]
338-
else
339-
cur_appxsol = appxsol
340-
end
341-
342-
if cur_appxsol !== nothing
343-
errsol = appxtrue(sol, cur_appxsol)
344-
errors[i] = Dict{Symbol, Float64}()
345-
for err in keys(errsol.errors)
346-
errors[i][err] = mean(errsol.errors[err])
347-
end
348-
else
349-
errors[i] = Dict{Symbol, Float64}()
350-
for err in keys(errsol.errors)
351-
errors[i][err] = mean(errsol.errors[err])
333+
if SciMLBase.successful_retcode(sol)
334+
if haskey(kwargs, :prob_choice)
335+
cur_appxsol = appxsol[kwargs[:prob_choice]]
336+
elseif prob isa AbstractArray
337+
cur_appxsol = appxsol[1]
338+
else
339+
cur_appxsol = appxsol
352340
end
353-
end
354341

355-
benchmark_f = let dts = dts, _prob = _prob, alg = alg, sol = sol,
356-
abstols = abstols, reltols = reltols, kwargs = kwargs
357-
358-
if dts === nothing
359-
if _prob isa DAEProblem
360-
() -> @elapsed solve(_prob, alg;
361-
abstol = abstols[i],
362-
reltol = reltols[i],
363-
timeseries_errors = false,
364-
dense_errors = false, kwargs...)
365-
else
366-
() -> @elapsed solve(_prob, alg;
367-
abstol = abstols[i],
368-
reltol = reltols[i],
369-
timeseries_errors = false,
370-
dense_errors = false, kwargs...)
342+
if cur_appxsol !== nothing
343+
errsol = appxtrue(sol, cur_appxsol)
344+
errors[i] = Dict{Symbol, Float64}()
345+
for err in keys(errsol.errors)
346+
errors[i][err] = mean(errsol.errors[err])
371347
end
372348
else
373-
if _prob isa DAEProblem
374-
() -> @elapsed solve(_prob, alg;
375-
abstol = abstols[i],
376-
reltol = reltols[i],
377-
dt = dts[i],
378-
timeseries_errors = false,
379-
dense_errors = false, kwargs...)
349+
errors[i] = Dict{Symbol, Float64}()
350+
for err in keys(errsol.errors)
351+
errors[i][err] = mean(errsol.errors[err])
352+
end
353+
end
354+
355+
benchmark_f = let dts = dts, _prob = _prob, alg = alg, sol = sol,
356+
abstols = abstols, reltols = reltols, kwargs = kwargs
357+
358+
if dts === nothing
359+
if _prob isa DAEProblem
360+
() -> @elapsed solve(_prob, alg;
361+
abstol = abstols[i],
362+
reltol = reltols[i],
363+
timeseries_errors = false,
364+
dense_errors = false, kwargs...)
365+
else
366+
() -> @elapsed solve(_prob, alg;
367+
abstol = abstols[i],
368+
reltol = reltols[i],
369+
timeseries_errors = false,
370+
dense_errors = false, kwargs...)
371+
end
380372
else
381-
() -> @elapsed solve(_prob, alg;
382-
abstol = abstols[i],
383-
reltol = reltols[i],
384-
dt = dts[i],
385-
timeseries_errors = false,
386-
dense_errors = false, kwargs...)
373+
if _prob isa DAEProblem
374+
() -> @elapsed solve(_prob, alg;
375+
abstol = abstols[i],
376+
reltol = reltols[i],
377+
dt = dts[i],
378+
timeseries_errors = false,
379+
dense_errors = false, kwargs...)
380+
else
381+
() -> @elapsed solve(_prob, alg;
382+
abstol = abstols[i],
383+
reltol = reltols[i],
384+
dt = dts[i],
385+
timeseries_errors = false,
386+
dense_errors = false, kwargs...)
387+
end
387388
end
388389
end
389-
end
390-
benchmark_f() # pre-compile
390+
benchmark_f() # pre-compile
391391

392-
b_t = benchmark_f()
393-
if b_t > seconds
394-
times[i] = b_t
392+
b_t = benchmark_f()
393+
if b_t > seconds
394+
times[i] = b_t
395+
else
396+
times[i] = mapreduce(i -> benchmark_f(), min, 2:numruns; init = b_t)
397+
end
395398
else
396-
times[i] = mapreduce(i -> benchmark_f(), min, 2:numruns; init = b_t)
399+
# Unsuccessful retcode, give NaN error and time
400+
errors[i] = Dict(
401+
:l∞ => NaN, :L2 => NaN, :final => NaN, :l2 => NaN, :L∞ => NaN)
402+
times[i] = NaN
397403
end
398404
end
399405
end

0 commit comments

Comments
 (0)