@@ -37,45 +37,76 @@ n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict(
3737 " step!" => 1 ,
3838)
3939
"""
    maybe_push!(trials₀, name, f!, args, kwargs, only)

Store the result of `get_trial(f!, args, name; kwargs...)` in `trials₀[name]`,
but only when `name` is selected: either `only` is `nothing` (no filtering) or
`name` is an element of `only`.

Returns the stored trial when `name` is selected and `nothing` otherwise.
"""
function maybe_push!(trials₀, name, f!, args, kwargs, only)
    # `only === nothing` short-circuits, so `in` is never applied to `nothing`.
    selected = only === nothing || name in only
    selected || return nothing
    return trials₀[name] = get_trial(f!, args, name; kwargs...)
end
45+
"""
    allowed_names

Names accepted in the `only` keyword of `benchmark_step`; each entry of `only`
is checked for membership in this list. The spelling must match the names
listed in the `benchmark_step` docstring exactly (no surrounding whitespace).
"""
const allowed_names = [
    "Wfact",
    "ldiv!",
    "T_imp!",
    "T_exp_T_lim!",
    "lim!",
    "dss!",
    "post_explicit!",
    "post_implicit!",
    "step!",
]
4048
4149"""
4250 benchmark_step(
4351 integrator::DistributedODEIntegrator,
4452 device::ClimaComms.AbstractDevice;
45- with_cu_prof = :bfrofile, # [:bprofile, :profile]
46- trace = false
53+ with_cu_prof::Symbol = :bprofile, # [:bprofile, :profile]
54+ trace::Bool = false,
55+ crop::Bool = false,
56+ only::Union{Nothing, Vector{String}} = nothing,
4757 )
4858
49- Benchmark a DistributedODEIntegrator
59+ Benchmark a DistributedODEIntegrator given:
60+ - `integrator` the `DistributedODEIntegrator`.
61+ - `device` the `ClimaComms` device.
62+ - `with_cu_prof`, `:profile` or `:bprofile`, to call `CUDA.@profile` or `CUDA.@bprofile` respectively.
63+ - `trace`, Bool passed to `CUDA.@profile` (see CUDA docs)
64+ - `crop`, Bool indicating whether or not to crop the `CUDA.@profile` printed table.
65+ - `only`, list of functions to benchmark (benchmark all by default)
66+
67+ `only` may contain:
68+ - "Wfact"
69+ - "ldiv!"
70+ - "T_imp!"
71+ - "T_exp_T_lim!"
72+ - "lim!"
73+ - "dss!"
74+ - "post_explicit!"
75+ - "post_implicit!"
76+ - "step!"
5077"""
5178function CTS. benchmark_step (
5279 integrator:: CTS.DistributedODEIntegrator ,
5380 device:: ClimaComms.AbstractDevice ;
54- with_cu_prof = :bprofile ,
55- trace = false ,
56- crop = false ,
81+ with_cu_prof:: Symbol = :bprofile ,
82+ trace:: Bool = false ,
83+ crop:: Bool = false ,
84+ only:: Union{Nothing, Vector{String}} = nothing ,
5785)
5886 (; u, p, t, dt, sol, alg) = integrator
5987 (; f) = sol. prob
6088 if f isa CTS. ClimaODEFunction
89+ if ! isnothing (only)
90+ @assert all (x -> x in allowed_names, only) " Allowed names in `only` are: $allowed_names "
91+ end
6192
6293 W = get_W (integrator)
6394 X = similar (u)
6495 Xlim = similar (u)
6596 @. X = u
6697 @. Xlim = u
6798 trials₀ = OrderedCollections. OrderedDict ()
68-
99+ kwargs = (; device, with_cu_prof, trace, crop)
69100# ! format: off
70- trials₀[ " Wfact " ] = get_trial ( wfact_fun (integrator), (W, u, p, dt, t), " Wfact " , device; with_cu_prof, trace, crop);
71- trials₀[ " ldiv! " ] = get_trial ( LA. ldiv!, (X, W, u), " ldiv! " , device; with_cu_prof, trace, crop);
72- trials₀[ " T_imp! " ] = get_trial ( implicit_fun (integrator), implicit_args (integrator), " T_imp! " , device; with_cu_prof, trace, crop);
73- trials₀[ " T_exp_T_lim!" ] = get_trial ( remaining_fun (integrator), remaining_args (integrator), " T_exp_T_lim! " , device; with_cu_prof, trace, crop);
74- trials₀[ " lim!" ] = get_trial ( f. lim!, (Xlim, p, t, u), " lim! " , device; with_cu_prof, trace, crop);
75- trials₀[ " dss! " ] = get_trial ( f. dss!, (u, p, t), " dss! " , device; with_cu_prof, trace, crop);
76- trials₀[ " post_explicit!" ] = get_trial ( f. post_explicit!, (u, p, t), " post_explicit! " , device; with_cu_prof, trace, crop);
77- trials₀[ " post_implicit!" ] = get_trial ( f. post_implicit!, (u, p, t), " post_implicit! " , device; with_cu_prof, trace, crop);
78- trials₀[ " step! " ] = get_trial ( SciMLBase. step!, (integrator, ), " step! " , device; with_cu_prof, trace, crop);
101+ maybe_push! (trials₀, " Wfact " , wfact_fun (integrator), (W, u, p, dt, t), kwargs, only)
102+ maybe_push! (trials₀, " ldiv! " , LA. ldiv!, (X, W, u), kwargs, only)
103+ maybe_push! (trials₀, " T_imp! " , implicit_fun (integrator), implicit_args (integrator), kwargs, only)
104+ maybe_push! ( trials₀, " T_exp_T_lim!" , remaining_fun (integrator), remaining_args (integrator), kwargs, only)
105+ maybe_push! ( trials₀, " lim!" , f. lim!, (Xlim, p, t, u), kwargs, only)
106+ maybe_push! (trials₀, " dss! " , f. dss!, (u, p, t), kwargs, only)
107+ maybe_push! ( trials₀, " post_explicit!" , f. post_explicit!, (u, p, t), kwargs, only)
108+ maybe_push! ( trials₀, " post_implicit!" , f. post_implicit!, (u, p, t), kwargs, only)
109+ maybe_push! (trials₀, " step! " , SciMLBase. step!, (integrator, ), kwargs, only)
79110# ! format: on
80111
81112 trials = OrderedCollections. OrderedDict ()
@@ -92,6 +123,9 @@ function CTS.benchmark_step(
92123 table_summary[k] = get_summary (trials[k], trials[" step!" ])
93124 end
94125
126+ if ! isnothing (only)
127+ @warn " Percentages are only based on $only , pass `only = nothing` for accurately reported percentages"
128+ end
95129 tabulate_summary (table_summary; n_calls_per_step)
96130
97131 return (; table_summary, trials)
0 commit comments