Skip to content

Commit 4d6534a

Browse files
Merge pull request #80 from SciML/individual
Allow individualized work-precision tolerances and lengths
2 parents 8a15ddb + b2b5f98 commit 4d6534a

File tree

2 files changed

+77
-123
lines changed

2 files changed

+77
-123
lines changed

src/benchmark.jl

Lines changed: 71 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -244,17 +244,13 @@ function WorkPrecisionSet(prob,
244244
end
245245
for i in 1:N
246246
print_names && println(names[i])
247-
if haskey(setups[i],:dts)
248-
wps[i] = WorkPrecision(prob,setups[i][:alg],abstols,reltols,setups[i][:dts];
249-
appxsol=appxsol,
250-
error_estimate=error_estimate,
251-
name=names[i],kwargs...,setups[i]...)
252-
else
253-
wps[i] = WorkPrecision(prob,setups[i][:alg],abstols,reltols;
254-
appxsol=appxsol,
255-
error_estimate=error_estimate,
256-
name=names[i],kwargs...,setups[i]...)
257-
end
247+
_abstols = get(setups[i],:abstols,abstols)
248+
_reltols = get(setups[i],:reltols,reltols)
249+
_dts = get(setups[i],:dts,nothing)
250+
wps[i] = WorkPrecision(prob,setups[i][:alg],_abstols,_reltols,_dts;
251+
appxsol=appxsol,
252+
error_estimate=error_estimate,
253+
name=names[i],kwargs...,setups[i]...)
258254
end
259255
return WorkPrecisionSet(wps,N,abstols,reltols,prob,setups,names,error_estimate,nothing)
260256
end
@@ -272,36 +268,26 @@ end
272268
end
273269

274270
# Get a cache
275-
if !haskey(setups[1],:dts)
276-
sol = solve(_prob,setups[1][:alg];
277-
kwargs...,setups[1]...,
278-
abstol=abstols[1],
279-
reltol=reltols[1],
280-
timeseries_errors=false,
281-
dense_errors = false)
282-
else
283-
sol = solve(_prob,setups[1][:alg];
284-
kwargs...,setups[1]...,abstol=abstols[1],
285-
reltol=reltols[1],dt=setups[1][:dts][1],
286-
timeseries_errors=false,
287-
dense_errors = false)
288-
end
271+
_abstols = get(setups[1],:abstols,abstols)
272+
_reltols = get(setups[1],:reltols,reltols)
273+
_dts = get(setups[1],:dts,zeros(length(_abstols)))
274+
275+
sol = solve(_prob,setups[1][:alg];
276+
kwargs...,setups[1]...,abstol=_abstols[1],
277+
reltol=_reltols[1],dt=_dts[1],
278+
timeseries_errors=false,
279+
dense_errors = false)
289280

290281
for j in 1:M, k in 1:N
291-
if !haskey(setups[k],:dts)
292-
sol = solve(_prob,setups[k][:alg];
293-
kwargs...,setups[k]...,
294-
abstol=abstols[j],
295-
reltol=reltols[j],
296-
timeseries_errors=timeseries_errors,
297-
dense_errors = dense_errors)
298-
else
299-
sol = solve(_prob,setups[k][:alg];
300-
kwargs...,setups[k]...,abstol=abstols[j],
301-
reltol=reltols[j],dt=setups[k][:dts][j],
302-
timeseries_errors=timeseries_errors,
303-
dense_errors = dense_errors)
304-
end
282+
_abstols = get(setups[k],:abstols,abstols)
283+
_reltols = get(setups[k],:reltols,reltols)
284+
_dts = get(setups[k],:dts,zeros(length(_abstols)))
285+
286+
sol = solve(_prob,setups[k][:alg];
287+
kwargs...,setups[k]...,abstol=_abstols[j],
288+
reltol=_reltols[j],dt=_dts[j],
289+
timeseries_errors=timeseries_errors,
290+
dense_errors = dense_errors)
305291
DiffEqBase.has_analytic(prob.f) ? err_sol = sol : err_sol = appxtrue(sol,true_sol)
306292
tmp_solutions[i,j,k] = err_sol
307293
end
@@ -352,38 +338,24 @@ function WorkPrecisionSet(prob::AbstractRODEProblem,abstols,reltols,setups,test_
352338
for k in 1:N
353339
# precompile
354340
GC.gc()
355-
if !haskey(setups[k],:dts)
356-
_sol = solve(prob,setups[k][:alg];
357-
kwargs...,setups[k]...,
358-
abstol=abstols[1],
359-
reltol=reltols[1],
360-
timeseries_errors=false,
361-
dense_errors = false)
362-
else
363-
_sol = solve(prob,setups[k][:alg];
364-
kwargs...,setups[k]...,abstol=abstols[1],
365-
reltol=reltols[1],dt=setups[k][:dts][1],
366-
timeseries_errors=false,
367-
dense_errors = false)
368-
end
341+
_abstols = get(setups[k],:abstols,abstols)
342+
_reltols = get(setups[k],:reltols,reltols)
343+
_dts = get(setups[k],:dts,zeros(length(_abstols)))
344+
345+
_sol = solve(prob,setups[k][:alg];
346+
kwargs...,setups[k]...,abstol=_abstols[1],
347+
reltol=_reltols[1],dt=_dts[1],
348+
timeseries_errors=false,
349+
dense_errors = false)
369350
x = isempty(_sol.t) ? 0 : round(Int,mean(_sol.t) - sum(_sol.t)/length(_sol.t))
370351
GC.gc()
371352
for j in 1:M
372353
for i in 1:numruns
373-
time_tmp[i] = @elapsed if !haskey(setups[k],:dts)
374-
sol = solve(prob,setups[k][:alg];
375-
kwargs...,setups[k]...,
376-
abstol=abstols[j],
377-
reltol=reltols[j],
378-
timeseries_errors=false,
379-
dense_errors = false)
380-
else
381-
sol = solve(prob,setups[k][:alg];
382-
kwargs...,setups[k]...,abstol=abstols[j],
383-
reltol=reltols[j],dt=setups[k][:dts][j],
384-
timeseries_errors=false,
385-
dense_errors = false)
386-
end
354+
time_tmp[i] = @elapsed sol = solve(prob,setups[k][:alg];
355+
kwargs...,setups[k]...,abstol=_abstols[j],
356+
reltol=_reltols[j],dt=_dts[j],
357+
timeseries_errors=false,
358+
dense_errors = false)
387359
end
388360
times[j,k] = mean(time_tmp) + x
389361
GC.gc()
@@ -416,25 +388,18 @@ function WorkPrecisionSet(prob::AbstractEnsembleProblem,abstols,reltols,setups,t
416388

417389
# First calculate all of the errors
418390
for k in 1:N
391+
_abstols = get(setups[k],:abstols,abstols)
392+
_reltols = get(setups[k],:reltols,reltols)
393+
_dts = get(setups[k],:dts,zeros(length(_abstols)))
419394
for j in 1:M
420-
if !haskey(setups[1],:dts)
421-
sol = solve(prob,setups[k][:alg],ensemblealg;
422-
setups[k]...,
423-
abstol=abstols[j],
424-
reltol=reltols[j],
425-
timeseries_errors=false,
426-
dense_errors = false,
427-
trajectories=Int(trajectories),kwargs...)
428-
else
429-
sol = solve(prob,setups[k][:alg],ensemblealg;
430-
setups[k]...,
431-
abstol=abstols[j],
432-
reltol=reltols[j],
433-
dt=setups[k][:dts][j],
434-
timeseries_errors=false,
435-
dense_errors = false,
436-
trajectories=Int(trajectories),kwargs...)
437-
end
395+
sol = solve(prob,setups[k][:alg],ensemblealg;
396+
setups[k]...,
397+
abstol=_abstols[j],
398+
reltol=_reltols[j],
399+
dt=_dts[j],
400+
timeseries_errors=false,
401+
dense_errors = false,
402+
trajectories=Int(trajectories),kwargs...)
438403
solutions[j,k] = sol
439404
end
440405
@info "$(setups[k][:alg]) ($k/$N)"
@@ -449,7 +414,7 @@ function WorkPrecisionSet(prob::AbstractEnsembleProblem,abstols,reltols,setups,t
449414
errors = [[LinearAlgebra.norm(Statistics.mean(solutions[i,j] .- expected_value))
450415
for i in 1:M] for j in 1:N]
451416
else
452-
error("Error estimate $error_estimate is not implemented yet.")
417+
error("Error estimate $error_estimate is not implemented yet.")
453418
end
454419
else
455420
sol = solve(prob,appxsol_setup[:alg],ensemblealg;kwargs...,appxsol_setup...,
@@ -467,46 +432,30 @@ function WorkPrecisionSet(prob::AbstractEnsembleProblem,abstols,reltols,setups,t
467432
for k in 1:N
468433
# precompile
469434
GC.gc()
470-
if !haskey(setups[1],:dts)
471-
_sol = solve(prob,setups[k][:alg],ensemblealg;
472-
setups[k]...,
473-
abstol=abstols[1],
474-
reltol=reltols[1],
475-
timeseries_errors=false,
476-
dense_errors = false,
477-
trajectories=Int(trajectories),kwargs...)
478-
else
479-
_sol = solve(prob,setups[k][:alg],ensemblealg;
480-
setups[k]...,
481-
abstol=abstols[1],
482-
reltol=reltols[1],
483-
dt=setups[k][:dts][1],
484-
timeseries_errors=false,
485-
dense_errors = false,
486-
trajectories=Int(trajectories),kwargs...)
487-
end
435+
_abstols = get(setups[k],:abstols,abstols)
436+
_reltols = get(setups[k],:reltols,reltols)
437+
_dts = get(setups[k],:dts,zeros(length(_abstols)))
438+
439+
_sol = solve(prob,setups[k][:alg],ensemblealg;
440+
setups[k]...,
441+
abstol=_abstols[1],
442+
reltol=_reltols[1],
443+
dt=_dts[1],
444+
timeseries_errors=false,
445+
dense_errors = false,
446+
trajectories=Int(trajectories),kwargs...)
488447
#x = isempty(_sol.t) ? 0 : round(Int,mean(_sol.t) - sum(_sol.t)/length(_sol.t))
489448
GC.gc()
490449
for j in 1:M
491450
for i in 1:numruns
492-
time_tmp[i] = @elapsed if !haskey(setups[k],:dts)
493-
sol = solve(prob,setups[k][:alg],ensemblealg;
494-
setups[k]...,
495-
abstol=abstols[j],
496-
reltol=reltols[j],
497-
timeseries_errors=false,
498-
dense_errors = false,
499-
trajectories=Int(trajectories),kwargs...)
500-
else
501-
sol = solve(prob,setups[k][:alg],ensemblealg;
502-
setups[k]...,
503-
abstol=abstols[j],
504-
reltol=reltols[j],
505-
dt=setups[k][:dts][j],
506-
timeseries_errors=false,
507-
dense_errors = false,
508-
trajectories=Int(trajectories),kwargs...)
509-
end
451+
time_tmp[i] = @elapsed sol = solve(prob,setups[k][:alg],ensemblealg;
452+
setups[k]...,
453+
abstol=_abstols[j],
454+
reltol=_reltols[j],
455+
dt=_dts[j],
456+
timeseries_errors=false,
457+
dense_errors = false,
458+
trajectories=Int(trajectories),kwargs...)
510459
end
511460
times[j,k] = mean(time_tmp) #+ x
512461
GC.gc()

test/benchmark_tests.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,18 @@ abstols = 1 ./10 .^(3:7)
7272
reltols = 1 ./10 .^(0:4)
7373

7474
setups = [Dict(:alg=>DP5())
75-
Dict(:alg=>Tsit5())]
75+
Dict(:alg=>Tsit5(),
76+
:abstols => 1 ./10 .^(4:7),
77+
:reltols => 1 ./10 .^(1:4))]
7678

7779
sol = solve(prob,Vern7(),abstol=1/10^14,reltol=1/10^14)
7880
test_sol1 = TestSolution(sol)
7981
println("Test DP5 and Tsit5")
8082
wp = WorkPrecisionSet(prob,abstols,reltols,setups;save_everystep=false)
8183

84+
@test length(wp[1]) == 5
85+
@test length(wp[2]) == 4
86+
8287
function lotka(du,u,p,t)
8388
du[1] = 1.5 * u[1] - u[1]*u[2]
8489
du[2] = -3 * u[2] + u[1]*u[2]

0 commit comments

Comments
 (0)