Skip to content

Commit b1e61da

Browse files
Merge pull request #287 from CliMA/ck/filt_funcs
Allow benchmark_step to benchmark targeted functions
2 parents cf71922 + e3870d2 commit b1e61da

File tree

2 files changed

+52
-18
lines changed

2 files changed

+52
-18
lines changed

ext/ClimaTimeSteppersBenchmarkToolsExt.jl

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,45 +37,76 @@ n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict(
3737
"step!" => 1,
3838
)
3939

40+
# Conditionally benchmark one function and record the result.
#
# Calls `get_trial(f!, args, name; kwargs...)` and stores the result in
# `trials₀[name]`, but only when `only` is `nothing` (no filter requested) or
# `name` is an element of `only`. Otherwise `trials₀` is left untouched and
# `get_trial` is never invoked.
#
# - `trials₀`: container indexed by name; mutated in place.
# - `name`: key under which the trial is stored; also forwarded to `get_trial`.
# - `f!`, `args`: the callable and its arguments, forwarded to `get_trial`.
# - `kwargs`: splatted into `get_trial` as keyword arguments.
# - `only`: `nothing`, or a collection of names selecting what to benchmark.
function maybe_push!(trials₀, name, f!, args, kwargs, only)
41+
if isnothing(only) || name in only
42+
trials₀[name] = get_trial(f!, args, name; kwargs...)
43+
end
44+
end
45+
46+
# Names accepted in the `only` keyword of `benchmark_step`; when `only` is
# given, every entry is checked against this list by an `@assert` there.
const allowed_names =
47+
["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "post_explicit!", "post_implicit!", "step!"]
4048

4149
"""
4250
benchmark_step(
4351
integrator::DistributedODEIntegrator,
4452
device::ClimaComms.AbstractDevice;
45-
with_cu_prof = :bfrofile, # [:bprofile, :profile]
46-
trace = false
53+
with_cu_prof::Symbol = :bprofile, # [:bprofile, :profile]
54+
trace::Bool = false,
55+
crop::Bool = false,
56+
only::Union{Nothing, Vector{String}} = nothing,
4757
)
4858
49-
Benchmark a DistributedODEIntegrator
59+
Benchmark a DistributedODEIntegrator given:
60+
- `integrator` the `DistributedODEIntegrator`.
61+
- `device` the `ClimaComms` device.
62+
- `with_cu_prof`, `:profile` or `:bprofile`, to call `CUDA.@profile` or `CUDA.@bprofile` respectively.
63+
- `trace`, Bool passed to `CUDA.@profile` (see CUDA docs)
64+
- `crop`, Bool indicating whether or not to crop the `CUDA.@profile` printed table.
65+
- `only`, list of functions to benchmark (benchmark all by default)
66+
67+
`only` may contain:
68+
- "Wfact"
69+
- "ldiv!"
70+
- "T_imp!"
71+
- "T_exp_T_lim!"
72+
- "lim!"
73+
- "dss!"
74+
- "post_explicit!"
75+
- "post_implicit!"
76+
- "step!"
5077
"""
5178
function CTS.benchmark_step(
5279
integrator::CTS.DistributedODEIntegrator,
5380
device::ClimaComms.AbstractDevice;
54-
with_cu_prof = :bprofile,
55-
trace = false,
56-
crop = false,
81+
with_cu_prof::Symbol = :bprofile,
82+
trace::Bool = false,
83+
crop::Bool = false,
84+
only::Union{Nothing, Vector{String}} = nothing,
5785
)
5886
(; u, p, t, dt, sol, alg) = integrator
5987
(; f) = sol.prob
6088
if f isa CTS.ClimaODEFunction
89+
if !isnothing(only)
90+
@assert all(x -> x in allowed_names, only) "Allowed names in `only` are: $allowed_names"
91+
end
6192

6293
W = get_W(integrator)
6394
X = similar(u)
6495
Xlim = similar(u)
6596
@. X = u
6697
@. Xlim = u
6798
trials₀ = OrderedCollections.OrderedDict()
68-
99+
kwargs = (; device, with_cu_prof, trace, crop)
69100
#! format: off
70-
trials₀["Wfact"] = get_trial(wfact_fun(integrator), (W, u, p, dt, t), "Wfact", device; with_cu_prof, trace, crop);
71-
trials₀["ldiv!"] = get_trial(LA.ldiv!, (X, W, u), "ldiv!", device; with_cu_prof, trace, crop);
72-
trials₀["T_imp!"] = get_trial(implicit_fun(integrator), implicit_args(integrator), "T_imp!", device; with_cu_prof, trace, crop);
73-
trials₀["T_exp_T_lim!"] = get_trial(remaining_fun(integrator), remaining_args(integrator), "T_exp_T_lim!", device; with_cu_prof, trace, crop);
74-
trials₀["lim!"] = get_trial(f.lim!, (Xlim, p, t, u), "lim!", device; with_cu_prof, trace, crop);
75-
trials₀["dss!"] = get_trial(f.dss!, (u, p, t), "dss!", device; with_cu_prof, trace, crop);
76-
trials₀["post_explicit!"] = get_trial(f.post_explicit!, (u, p, t), "post_explicit!", device; with_cu_prof, trace, crop);
77-
trials₀["post_implicit!"] = get_trial(f.post_implicit!, (u, p, t), "post_implicit!", device; with_cu_prof, trace, crop);
78-
trials₀["step!"] = get_trial(SciMLBase.step!, (integrator, ), "step!", device; with_cu_prof, trace, crop);
101+
maybe_push!(trials₀, "Wfact", wfact_fun(integrator), (W, u, p, dt, t), kwargs, only)
102+
maybe_push!(trials₀, "ldiv!", LA.ldiv!, (X, W, u), kwargs, only)
103+
maybe_push!(trials₀, "T_imp!", implicit_fun(integrator), implicit_args(integrator), kwargs, only)
104+
maybe_push!(trials₀, "T_exp_T_lim!", remaining_fun(integrator), remaining_args(integrator), kwargs, only)
105+
maybe_push!(trials₀, "lim!", f.lim!, (Xlim, p, t, u), kwargs, only)
106+
maybe_push!(trials₀, "dss!", f.dss!, (u, p, t), kwargs, only)
107+
maybe_push!(trials₀, "post_explicit!", f.post_explicit!, (u, p, t), kwargs, only)
108+
maybe_push!(trials₀, "post_implicit!", f.post_implicit!, (u, p, t), kwargs, only)
109+
maybe_push!(trials₀, "step!", SciMLBase.step!, (integrator, ), kwargs, only)
79110
#! format: on
80111

81112
trials = OrderedCollections.OrderedDict()
@@ -92,6 +123,9 @@ function CTS.benchmark_step(
92123
table_summary[k] = get_summary(trials[k], trials["step!"])
93124
end
94125

126+
if !isnothing(only)
127+
@warn "Percentages are only based on $only, pass `only = nothing` for accurately reported percentages"
128+
end
95129
tabulate_summary(table_summary; n_calls_per_step)
96130

97131
return (; table_summary, trials)

ext/benchmark_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ function tabulate_summary(summary; n_calls_per_step)
4848
)
4949
end
5050

51-
get_trial(f::Nothing, args, name, device; with_cu_prof = :bprofile, trace = false, crop = false) = nothing
52-
function get_trial(f, args, name, device; with_cu_prof = :bprofile, trace = false, crop = false)
51+
"""
    get_trial(f::Nothing, args, name; device, with_cu_prof = :bprofile, trace = false, crop = false)

Return `nothing` when there is no function to benchmark (`f === nothing`).

The remaining positional and keyword arguments are accepted but ignored;
`device` is a required keyword.
"""
function get_trial(::Nothing, args, name; device, with_cu_prof = :bprofile, trace = false, crop = false)
    return nothing
end
52+
function get_trial(f, args, name; device, with_cu_prof = :bprofile, trace = false, crop = false)
5353
f(args...) # compile first
5454
b = if device isa ClimaComms.CUDADevice
5555
BenchmarkTools.@benchmarkable CUDA.@sync $f($(args)...)

0 commit comments

Comments
 (0)