Skip to content

Commit 435624f

Browse files
Merge pull request #21 from JuliaDiffEq/myb/timing
Use BenchmarkTools for timing
2 parents 42aad6f + 39557dd commit 435624f

File tree

2 files changed

+79
-50
lines changed

2 files changed

+79
-50
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ DiffEqPDEBase 0.4.0
66
NLsolve 0.14.1
77
DiffEqMonteCarlo
88
DiffEqNoiseProcess
9+
BenchmarkTools

src/benchmark.jl

Lines changed: 78 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using BenchmarkTools
2+
using Statistics
13
## Shootouts
24

35
mutable struct Shootout
@@ -26,7 +28,9 @@ function ode_shootout(args...;kwargs...)
2628
ShootOut(args...;kwargs...)
2729
end
2830

29-
function Shootout(prob,setups;appxsol=nothing,numruns=20,names=nothing,error_estimate=:final,kwargs...)
31+
benchtime(bench) = BenchmarkTools.time(mean(bench))/1e9
32+
33+
function Shootout(prob,setups;appxsol=nothing,names=nothing,error_estimate=:final,numruns=20,seconds=2,kwargs...)
3034
N = length(setups)
3135
errors = Vector{Float64}(undef,N)
3236
solutions = Vector{Any}(undef,N)
@@ -43,10 +47,19 @@ function Shootout(prob,setups;appxsol=nothing,numruns=20,names=nothing,error_est
4347
dense_errors = dense_errors,kwargs...,setups[i]...) # Compile and get result
4448
sol = solve(prob,setups[i][:alg],sol.u,sol.t,sol.k;timeseries_errors=timeseries_errors,
4549
dense_errors = dense_errors,kwargs...,setups[i]...) # Compile and get result
46-
GC.gc()
47-
t = @elapsed for j in 1:numruns
48-
sol = solve(prob,setups[i][:alg],sol.u,sol.t,sol.k;
49-
kwargs...,setups[i]...,timeseries_errors=false,dense_errors=false)
50+
fails = 0
51+
local benchable
52+
@label START
53+
try
54+
benchable = @benchmarkable(solve($prob,$(setups[i][:alg]),$(sol.u),$(sol.t),$(sol.k);
55+
$kwargs...,$(setups[i])...,timeseries_errors=false,dense_errors=false))
56+
catch
57+
# sometimes BenchmarkTools errors with
58+
# `ERROR: syntax: function argument and static parameter names must be distinct`
59+
# so, we are catching that error and try a few times.
60+
fails += 1
61+
fails > 4 && rethrow()
62+
@goto START
5063
end
5164
if appxsol != nothing
5265
errsol = appxtrue(sol,appxsol)
@@ -56,8 +69,9 @@ function Shootout(prob,setups;appxsol=nothing,numruns=20,names=nothing,error_est
5669
errors[i] = sol.errors[error_estimate]
5770
solutions[i] = sol
5871
end
72+
bench = run(benchable, samples=numruns, seconds=seconds)
73+
t = benchtime(bench)
5974
effs[i] = 1/(errors[i]*t)
60-
t = t/numruns
6175
times[i] = t
6276
end
6377
for j in 1:N, i in 1:N
@@ -72,7 +86,7 @@ function ode_shootoutset(args...;kwargs...)
7286
ShootoutSet(args...;kwargs...)
7387
end
7488

75-
function ShootoutSet(probs,setups;probaux=nothing,numruns=20,
89+
function ShootoutSet(probs,setups;probaux=nothing,
7690
names=nothing,print_names=false,kwargs...)
7791
N = length(probs)
7892
shootouts = Vector{Shootout}(undef,N)
@@ -88,7 +102,7 @@ function ShootoutSet(probs,setups;probaux=nothing,numruns=20,
88102
end
89103
for i in eachindex(probs)
90104
print_names && println(names[i])
91-
shootouts[i] = Shootout(probs[i],setups;numruns=numruns,names=names,kwargs...,probaux[i]...)
105+
shootouts[i] = Shootout(probs[i],setups;names=names,kwargs...,probaux[i]...)
92106
winners[i] = shootouts[i].winner
93107
end
94108
return ShootoutSet(shootouts,probs,probaux,N,winners)
@@ -140,8 +154,7 @@ mutable struct WorkPrecisionSet
140154
end
141155

142156
function WorkPrecision(prob,alg,abstols,reltols,dts=nothing;
143-
name=nothing,numruns=20,
144-
appxsol=nothing,error_estimate=:final,kwargs...)
157+
name=nothing,appxsol=nothing,error_estimate=:final,numruns=20,seconds=2,kwargs...)
145158
N = length(abstols)
146159
errors = Vector{Float64}(undef,N)
147160
times = Vector{Float64}(undef,N)
@@ -159,15 +172,13 @@ function WorkPrecision(prob,alg,abstols,reltols,dts=nothing;
159172
sol = solve(prob,alg,sol.u,sol.t,sol.k;kwargs...,abstol=abstols[i],
160173
reltol=reltols[i],timeseries_errors=timeseries_errors,
161174
dense_errors = dense_errors) # Compile and get result
162-
GC.gc()
163175
else
164176
sol = solve(prob,alg;kwargs...,abstol=abstols[i],
165177
reltol=reltols[i],dt=dts[i],timeseries_errors=timeseries_errors,
166178
dense_errors = dense_errors) # Compile and get result
167179
sol = solve(prob,alg,sol.u,sol.t,sol.k;kwargs...,abstol=abstols[i],
168180
reltol=reltols[i],dt=dts[i],timeseries_errors=timeseries_errors,
169181
dense_errors = dense_errors) # Compile and get result
170-
GC.gc()
171182
end
172183

173184
if appxsol != nothing
@@ -177,32 +188,41 @@ function WorkPrecision(prob,alg,abstols,reltols,dts=nothing;
177188
errors[i] = mean(sol.errors[error_estimate])
178189
end
179190

180-
t = 0.0
181-
for j in 1:numruns
182-
t_tmp = @elapsed if dts == nothing
183-
solve(prob,alg,sol.u,sol.t,sol.k;kwargs...,
184-
abstol=abstols[i],
185-
reltol=reltols[i],
186-
timeseries_errors=false,
187-
dense_errors = false)
191+
fails = 0
192+
local benchable
193+
@label START
194+
try
195+
benchable = if dts == nothing
196+
@benchmarkable(solve($prob,$alg,$(sol.u),$(sol.t),$(sol.k);
197+
abstol=$(abstols[i]),
198+
reltol=$(reltols[i]),
199+
timeseries_errors = false,
200+
dense_errors = false, $kwargs...))
188201
else
189-
solve(prob,alg,sol.u,sol.t,sol.k;
190-
kwargs...,abstol=abstols[i],
191-
reltol=reltols[i],dt=dts[i],
192-
timeseries_errors=false,
193-
dense_errors = false)
202+
@benchmarkable(solve($prob,$alg,$(sol.u),$(sol.t),$(sol.k);
203+
abstol=$(abstols[i]),
204+
reltol=$(reltols[i]),
205+
dt=$(dts[i]),
206+
timeseries_errors = false,
207+
dense_errors = false, $kwargs...))
194208
end
195-
t += t_tmp
196-
GC.gc()
209+
catch e
210+
# sometimes BenchmarkTools errors with
211+
# `ERROR: syntax: function argument and static parameter names must be distinct`
212+
# so, we are catching that error and try a few times.
213+
fails += 1
214+
fails > 4 && rethrow()
215+
@goto START
197216
end
198-
times[i] = t/numruns
217+
bench = run(benchable, samples=numruns, seconds=seconds)
218+
times[i] = benchtime(bench)
199219
end
200220
return WorkPrecision(prob,abstols,reltols,errors,times,name,N)
201221
end
202222

203223
function WorkPrecisionSet(prob::Union{AbstractODEProblem,AbstractDDEProblem,
204224
AbstractDAEProblem},
205-
abstols,reltols,setups;numruns=20,
225+
abstols,reltols,setups;
206226
print_names=false,names=nothing,appxsol=nothing,
207227
error_estimate=:final,
208228
test_dt=nothing,kwargs...)
@@ -215,17 +235,17 @@ function WorkPrecisionSet(prob::Union{AbstractODEProblem,AbstractDDEProblem,
215235
print_names && println(names[i])
216236
if haskey(setups[i],:dts)
217237
wps[i] = WorkPrecision(prob,setups[i][:alg],abstols,reltols,setups[i][:dts];
218-
numruns=numruns,appxsol=appxsol,
238+
appxsol=appxsol,
219239
error_estimate=error_estimate,
220240
name=names[i],kwargs...,setups[i]...)
221241
else
222242
wps[i] = WorkPrecision(prob,setups[i][:alg],abstols,reltols;
223-
numruns=numruns,appxsol=appxsol,
243+
appxsol=appxsol,
224244
error_estimate=error_estimate,
225245
name=names[i],kwargs...,setups[i]...)
226246
end
227247
end
228-
return WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,nothing,error_estimate,numruns)
248+
return WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,nothing,error_estimate,nothing)
229249
end
230250

231251
@def error_calculation begin
@@ -279,9 +299,9 @@ end
279299
end
280300

281301
function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_dt=nothing;
282-
numruns=20,numruns_error = 20,
302+
numruns_error = 20,
283303
print_names=false,names=nothing,appxsol_setup=nothing,
284-
error_estimate=:final,parallel_type = :none,kwargs...)
304+
error_estimate=:final,parallel_type = :none,numruns=20,seconds=Inf,kwargs...)
285305

286306
timeseries_errors = DiffEqBase.has_analytic(prob.f) && error_estimate TIMESERIES_ERRORS
287307
weak_timeseries_errors = error_estimate WEAK_TIMESERIES_ERRORS
@@ -293,7 +313,6 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
293313
if names == nothing
294314
names = [string(parameterless_type(setups[i][:alg])) for i=1:length(setups)]
295315
end
296-
time_tmp = Vector{Float64}(undef,numruns)
297316

298317
# First calculate all of the errors
299318
if parallel_type == :threads
@@ -333,28 +352,37 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
333352
dense_errors = dense_errors)
334353
end
335354
end
336-
GC.gc()
337355
# Now time it
338356
for k in 1:N
339357
for j in 1:M
340-
for i in 1:numruns
341-
time_tmp[i] = @elapsed if !haskey(setups[k],:dts)
342-
sol = solve(prob,setups[k][:alg];
343-
kwargs...,
344-
abstol=abstols[j],
345-
reltol=reltols[j],
358+
fails = 0
359+
local benchable
360+
@label START
361+
try
362+
benchable = if !haskey(setups[k],:dts)
363+
@benchmarkable(solve($prob,$(setups[k][:alg]);
364+
$kwargs...,
365+
abstol=$(abstols[j]),
366+
reltol=$(reltols[j]),
346367
timeseries_errors=false,
347-
dense_errors = false)
368+
dense_errors = false))
348369
else
349-
sol = solve(prob,setups[k][:alg];
350-
kwargs...,abstol=abstols[j],
351-
reltol=reltols[j],dt=setups[k][:dts][j],
352-
timeseries_errors=false,
353-
dense_errors = false)
370+
@benchmarkable(solve($prob,$(setups[k][:alg]);
371+
$kwargs...,abstol=$(abstols[j]),
372+
reltol=$(reltols[j]),dt=$(setups[k][:dts][j]),
373+
timeseries_errors=false,
374+
dense_errors = false))
354375
end
376+
catch
377+
# sometimes BenchmarkTools errors with
378+
# `ERROR: syntax: function argument and static parameter names must be distinct`
379+
# so, we are catching that error and try a few times.
380+
fails += 1
381+
fails > 4 && rethrow()
382+
@goto START
355383
end
356-
times[j,k] = mean(time_tmp)
357-
GC.gc()
384+
bench = run(benchable, samples=numruns, seconds=seconds)
385+
times[j,k] = benchtime(bench)
358386
end
359387
end
360388

0 commit comments

Comments
 (0)