Skip to content

Commit 6b3ba9c

Browse files
Merge pull request #124 from nathanaelbosch/save_stats_and_dts
Save `dts` and `sol.stats` into `WorkPrecision`
2 parents f23cf22 + 66823ba commit 6b3ba9c

File tree

3 files changed

+48
-40
lines changed

3 files changed

+48
-40
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ jobs:
1515
- Core
1616
version:
1717
- '1'
18-
- '1.6'
1918
steps:
2019
- uses: actions/checkout@v4
2120
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,28 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1818

1919
[compat]
2020
BVProblemLibrary = "0.1"
21-
BoundaryValueDiffEq = "4, 5"
21+
BoundaryValueDiffEq = "5"
2222
DDEProblemLibrary = "0.1"
2323
DelayDiffEq = "5.20"
2424
DiffEqBase = "6.94.4"
2525
DiffEqNoiseProcess = "5.0"
26-
Distributed = "1.6"
27-
LinearAlgebra = "1.6"
28-
Logging = "1.6"
26+
Distributed = "1.9"
27+
LinearAlgebra = "1.9"
28+
Logging = "1.9"
2929
NLsolve = "4.2"
3030
NonlinearSolve = "1, 2"
3131
ODEProblemLibrary = "0.1"
3232
OrdinaryDiffEq = "6"
3333
ParameterizedFunctions = "5"
34-
RecipesBase = "0.7, 0.8, 1.0"
34+
RecipesBase = "1"
3535
RecursiveArrayTools = "2"
36-
RootedTrees = "1, 2"
36+
RootedTrees = "2"
3737
SDEProblemLibrary = "0.1"
38-
SciMLBase = "1.74, 2"
38+
SciMLBase = "2"
3939
Statistics = "1"
4040
StochasticDelayDiffEq = "1"
4141
StochasticDiffEq = "6"
42-
julia = "1.6"
42+
julia = "1.9"
4343

4444
[extras]
4545
BVProblemLibrary = "ded0fc24-dfea-4565-b1d9-79c027d14d84"

src/benchmark.jl

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ mutable struct WorkPrecision
161161
reltols::Any
162162
errors::Any
163163
times::Any
164+
dts::Any
165+
stats::Any
164166
name::Any
165167
N::Int
166168
end
@@ -183,6 +185,7 @@ function WorkPrecision(prob, alg, abstols, reltols, dts = nothing;
183185
N = length(abstols)
184186
errors = Vector{Float64}(undef, N)
185187
times = Vector{Float64}(undef, N)
188+
stats = Vector{Any}(undef, N)
186189
if name === nothing
187190
name = "WP-Alg"
188191
end
@@ -210,6 +213,8 @@ function WorkPrecision(prob, alg, abstols, reltols, dts = nothing;
210213
dense_errors = dense_errors)
211214
end
212215

216+
stats[i] = sol.stats
217+
213218
if haskey(kwargs, :prob_choice)
214219
cur_appxsol = appxsol[kwargs[:prob_choice]]
215220
elseif prob isa AbstractArray
@@ -270,7 +275,7 @@ function WorkPrecision(prob, alg, abstols, reltols, dts = nothing;
270275
end
271276
end
272277
end
273-
return WorkPrecision(prob, abstols, reltols, errors, times, name, N)
278+
return WorkPrecision(prob, abstols, reltols, errors, times, dts, stats, name, N)
274279
end
275280

276281
# Work precision information for a BVP
@@ -280,6 +285,7 @@ function WorkPrecision(prob::AbstractBVProblem, alg, abstols, reltols, dts = not
280285
N = length(abstols)
281286
errors = Vector{Float64}(undef, N)
282287
times = Vector{Float64}(undef, N)
288+
stats = Vector{Any}(undef, N)
283289
if name === nothing
284290
name = "WP-Alg"
285291
end
@@ -307,6 +313,8 @@ function WorkPrecision(prob::AbstractBVProblem, alg, abstols, reltols, dts = not
307313
dense_errors = dense_errors)
308314
end
309315

316+
stats[i] = sol.stats
317+
310318
if haskey(kwargs, :prob_choice)
311319
cur_appxsol = appxsol[kwargs[:prob_choice]]
312320
elseif prob isa AbstractArray
@@ -367,14 +375,15 @@ function WorkPrecision(prob::AbstractBVProblem, alg, abstols, reltols, dts = not
367375
end
368376
end
369377
end
370-
return WorkPrecision(prob, abstols, reltols, errors, times, name, N)
378+
return WorkPrecision(prob, abstols, reltols, errors, times, dts, stats, name, N)
371379
end
372380

373381
# Work precision information for a nonlinear problem.
374382
function WorkPrecision(prob::NonlinearProblem, alg, abstols, reltols, dts = nothing; name = nothing, appxsol = nothing, error_estimate = :l2, numruns = 20, seconds = 2, kwargs...)
375383
N = length(abstols)
376384
errors = Vector{Float64}(undef, N)
377385
times = Vector{Float64}(undef, N)
386+
stats = Vector{Any}(undef, N)
378387
if name === nothing
379388
name = "WP-Alg"
380389
end
@@ -391,9 +400,11 @@ function WorkPrecision(prob::NonlinearProblem, alg, abstols, reltols, dts = noth
391400
for i in 1:N
392401
sol = solve(_prob, alg; kwargs..., abstol = abstols[i], reltol = reltols[i])
393402

403+
stats[i] = sol.stats
404+
394405
if error_estimate == :l2
395406
if isnothing(appxsol)
396-
errors[i] = sqrt(sum(abs2, sol.resid))
407+
errors[i] = sqrt(sum(abs2, sol.resid))
397408
else
398409
errors[i] = sqrt(sum(abs2, sol .- appxsol))
399410
end
@@ -419,7 +430,7 @@ function WorkPrecision(prob::NonlinearProblem, alg, abstols, reltols, dts = noth
419430
end
420431
end
421432
end
422-
return WorkPrecision(prob, abstols, reltols, errors, times, name, N)
433+
return WorkPrecision(prob, abstols, reltols, errors, times, dts, stats, name, N)
423434
end
424435

425436
function WorkPrecisionSet(prob,
@@ -533,25 +544,25 @@ function WorkPrecisionSet(prob::AbstractRODEProblem, abstols, reltols, setups,
533544
weak_dense_errors = weak_dense_errors)
534545
for sim in sol_k] for sol_k in _solutions_k]
535546
if error_estimate WEAK_ERRORS
536-
errors = [[solutions[j][i].weak_errors[error_estimate] for i in 1:M] for j in 1:N]
547+
errors = [[solutions[j][i].weak_errors for i in 1:M] for j in 1:N]
537548
else
538-
errors = [[solutions[j][i].error_means[error_estimate] for i in 1:M] for j in 1:N]
549+
errors = [[solutions[j][i].error_means for i in 1:M] for j in 1:N]
539550
end
540551

541552
local _sol
542553

543554
# Now time it
555+
_abstols = [get(setups[k], :abstols, abstols) for k in 1:N]
556+
_reltols = [get(setups[k], :reltols, reltols) for k in 1:N]
557+
_dts = [get(setups[k], :dts, zeros(length(_abstols))) for k in 1:N]
544558
for k in 1:N
545559
# precompile
546560
GC.gc()
547-
_abstols = get(setups[k], :abstols, abstols)
548-
_reltols = get(setups[k], :reltols, reltols)
549-
_dts = get(setups[k], :dts, zeros(length(_abstols)))
550561
filtered_setup = filter(p -> p.first in DiffEqBase.allowedkeywords, setups[k])
551562

552563
_sol = solve(prob, setups[k][:alg];
553-
kwargs..., filtered_setup..., abstol = _abstols[1],
554-
reltol = _reltols[1], dt = _dts[1],
564+
kwargs..., filtered_setup..., abstol = _abstols[k][1],
565+
reltol = _reltols[k][1], dt = _dts[k][1],
555566
timeseries_errors = false,
556567
dense_errors = false)
557568
x = isempty(_sol.t) ? 0 : round(Int, mean(_sol.t) - sum(_sol.t) / length(_sol.t))
@@ -560,8 +571,8 @@ function WorkPrecisionSet(prob::AbstractRODEProblem, abstols, reltols, setups,
560571
for i in 1:numruns
561572
time_tmp[i] = @elapsed sol = solve(prob, setups[k][:alg];
562573
kwargs..., filtered_setup...,
563-
abstol = _abstols[j],
564-
reltol = _reltols[j], dt = _dts[j],
574+
abstol = _abstols[k][j],
575+
reltol = _reltols[k][j], dt = _dts[k][j],
565576
timeseries_errors = false,
566577
dense_errors = false)
567578
end
@@ -570,7 +581,8 @@ function WorkPrecisionSet(prob::AbstractRODEProblem, abstols, reltols, setups,
570581
end
571582
end
572583

573-
wps = [WorkPrecision(prob, abstols, reltols, errors[i], times[:, i], names[i], N)
584+
stats = nothing
585+
wps = [WorkPrecision(prob, _abstols[i], _reltols[i], errors[i], times[:, i], _dts[i], stats, names[i], N)
574586
for i in 1:N]
575587
WorkPrecisionSet(wps, N, abstols, reltols, prob, setups, names, error_estimate,
576588
numruns_error)
@@ -598,18 +610,18 @@ function WorkPrecisionSet(prob::AbstractEnsembleProblem, abstols, reltols, setup
598610
time_tmp = Vector{Float64}(undef, numruns)
599611

600612
# First calculate all of the errors
613+
_abstols = [get(setups[k], :abstols, abstols) for k in 1:N]
614+
_reltols = [get(setups[k], :reltols, reltols) for k in 1:N]
615+
_dts = [get(setups[k], :dts, zeros(length(_abstols))) for k in 1:N]
601616
for k in 1:N
602-
_abstols = get(setups[k], :abstols, abstols)
603-
_reltols = get(setups[k], :reltols, reltols)
604-
_dts = get(setups[k], :dts, zeros(length(_abstols)))
605617
filtered_setup = filter(p -> p.first in DiffEqBase.allowedkeywords, setups[k])
606618

607619
for j in 1:M
608620
sol = solve(prob, setups[k][:alg], ensemblealg;
609621
filtered_setup...,
610-
abstol = _abstols[j],
611-
reltol = _reltols[j],
612-
dt = _dts[j],
622+
abstol = _abstols[k][j],
623+
reltol = _reltols[k][j],
624+
dt = _dts[k][j],
613625
timeseries_errors = false,
614626
dense_errors = false,
615627
trajectories = Int(trajectories), kwargs...)
@@ -648,16 +660,13 @@ function WorkPrecisionSet(prob::AbstractEnsembleProblem, abstols, reltols, setup
648660
for k in 1:N
649661
# precompile
650662
GC.gc()
651-
_abstols = get(setups[k], :abstols, abstols)
652-
_reltols = get(setups[k], :reltols, reltols)
653-
_dts = get(setups[k], :dts, zeros(length(_abstols)))
654663
filtered_setup = filter(p -> p.first in DiffEqBase.allowedkeywords, setups[k])
655664

656665
_sol = solve(prob, setups[k][:alg], ensemblealg;
657666
filtered_setup...,
658-
abstol = _abstols[1],
659-
reltol = _reltols[1],
660-
dt = _dts[1],
667+
abstol = _abstols[k][1],
668+
reltol = _reltols[k][1],
669+
dt = _dts[k][1],
661670
timeseries_errors = false,
662671
dense_errors = false,
663672
trajectories = Int(trajectories), kwargs...)
@@ -667,9 +676,9 @@ function WorkPrecisionSet(prob::AbstractEnsembleProblem, abstols, reltols, setup
667676
for i in 1:numruns
668677
time_tmp[i] = @elapsed sol = solve(prob, setups[k][:alg], ensemblealg;
669678
filtered_setup...,
670-
abstol = _abstols[j],
671-
reltol = _reltols[j],
672-
dt = _dts[j],
679+
abstol = _abstols[k][j],
680+
reltol = _reltols[k][j],
681+
dt = _dts[k][j],
673682
timeseries_errors = false,
674683
dense_errors = false,
675684
trajectories = Int(trajectories),
@@ -679,8 +688,8 @@ function WorkPrecisionSet(prob::AbstractEnsembleProblem, abstols, reltols, setup
679688
GC.gc()
680689
end
681690
end
682-
683-
wps = [WorkPrecision(prob, abstols, reltols, errors[i], times[:, i], names[i], N)
691+
stats = nothing
692+
wps = [WorkPrecision(prob, _abstols[i], _reltols[i], errors[i], times[:, i], _dts[i], stats, names[i], N)
684693
for i in 1:N]
685694
WorkPrecisionSet(wps, N, abstols, reltols, prob, setups, names, error_estimate,
686695
Int(trajectories))

0 commit comments

Comments
 (0)