Skip to content

Commit cb3adbd

Browse files
Per reviewer suggestion change benchmark_base::optinal_ref to optional_ptr
benchmark_base::get_printer now return optional_ptr, i.e. std::optional<nvbench::printer_base *>. All use sites are updated accordingly.
1 parent dfb5045 commit cb3adbd

File tree

8 files changed

+51
-52
lines changed

8 files changed

+51
-52
lines changed

nvbench/benchmark_base.cuh

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
#include <nvbench/state.cuh>
3434
#include <nvbench/stopping_criterion.cuh>
3535

36-
#include <functional> // reference_wrapper, ref
3736
#include <memory>
3837
#include <optional>
3938
#include <vector>
@@ -57,13 +56,13 @@ struct runner;
5756
struct benchmark_base
5857
{
5958
template <typename T>
60-
using optional_ref = std::optional<std::reference_wrapper<T>>;
59+
using optional_ptr = std::optional<T *>;
6160

6261
template <typename TypeAxes>
6362
explicit benchmark_base(TypeAxes type_axes)
6463
: m_axes(type_axes)
6564
{
66-
this->init_printer_ref();
65+
this->init_printer_ptr();
6766
this->set_stopping_criterion(nvbench::detail::default_stopping_criterion());
6867
}
6968

@@ -161,7 +160,7 @@ struct benchmark_base
161160
void set_printer(nvbench::printer_base &printer);
162161
void clear_printer();
163162

164-
[[nodiscard]] optional_ref<nvbench::printer_base> get_printer() const;
163+
[[nodiscard]] optional_ptr<nvbench::printer_base> get_printer() const;
165164

166165
/// Execute at least this many trials per measurement. @{
167166
[[nodiscard]] nvbench::int64_t get_min_samples() const { return m_min_samples; }
@@ -338,16 +337,16 @@ protected:
338337
std::string m_stopping_criterion{};
339338

340339
private:
341-
struct printer_optional_ref_impl_t;
340+
struct printer_optional_ptr_impl_t;
342341

343-
struct printer_optional_ref_deleter_t
342+
struct printer_optional_ptr_deleter_t
344343
{
345-
void operator()(printer_optional_ref_impl_t *) const noexcept;
344+
void operator()(printer_optional_ptr_impl_t *) const noexcept;
346345
};
347346

348-
std::unique_ptr<printer_optional_ref_impl_t, printer_optional_ref_deleter_t> m_printer_wrapper;
347+
std::unique_ptr<printer_optional_ptr_impl_t, printer_optional_ptr_deleter_t> m_printer_wrapper;
349348

350-
void init_printer_ref();
349+
void init_printer_ptr();
351350

352351
// route these through virtuals so the templated subclass can inject type info
353352
virtual std::unique_ptr<benchmark_base> do_clone() const = 0;

nvbench/benchmark_base.cxx

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,34 +29,34 @@
2929
namespace nvbench
3030
{
3131

32-
struct benchmark_base::printer_optional_ref_impl_t
32+
struct benchmark_base::printer_optional_ptr_impl_t
3333
{
34-
benchmark_base::optional_ref<printer_base> optional_ref;
34+
benchmark_base::optional_ptr<printer_base> optional_ptr;
3535
};
3636

3737
benchmark_base::~benchmark_base() = default;
3838

39-
void benchmark_base::printer_optional_ref_deleter_t::operator()(
40-
printer_optional_ref_impl_t *p) const noexcept
39+
void benchmark_base::printer_optional_ptr_deleter_t::operator()(
40+
printer_optional_ptr_impl_t *p) const noexcept
4141
{
4242
delete p;
4343
}
4444

45-
void benchmark_base::init_printer_ref()
45+
void benchmark_base::init_printer_ptr()
4646
{
47-
m_printer_wrapper.reset(new printer_optional_ref_impl_t{});
47+
m_printer_wrapper.reset(new printer_optional_ptr_impl_t{});
4848
}
4949

5050
void benchmark_base::set_printer(nvbench::printer_base &printer)
5151
{
52-
m_printer_wrapper->optional_ref = std::ref(printer);
52+
m_printer_wrapper->optional_ptr = &printer;
5353
}
5454

55-
void benchmark_base::clear_printer() { m_printer_wrapper->optional_ref = std::nullopt; }
55+
void benchmark_base::clear_printer() { m_printer_wrapper->optional_ptr = std::nullopt; }
5656

57-
benchmark_base::optional_ref<nvbench::printer_base> benchmark_base::get_printer() const
57+
benchmark_base::optional_ptr<nvbench::printer_base> benchmark_base::get_printer() const
5858
{
59-
return m_printer_wrapper->optional_ref;
59+
return m_printer_wrapper->optional_ptr;
6060
}
6161

6262
std::unique_ptr<benchmark_base> benchmark_base::clone() const
@@ -68,8 +68,8 @@ std::unique_ptr<benchmark_base> benchmark_base::clone() const
6868
result->m_axes = m_axes;
6969
result->m_devices = m_devices;
7070

71-
result->m_printer_wrapper.reset(new printer_optional_ref_impl_t{m_printer_wrapper->optional_ref});
72-
result->m_printer_wrapper->optional_ref = m_printer_wrapper->optional_ref;
71+
result->m_printer_wrapper.reset(new printer_optional_ptr_impl_t{m_printer_wrapper->optional_ptr});
72+
result->m_printer_wrapper->optional_ptr = m_printer_wrapper->optional_ptr;
7373

7474
result->m_is_cpu_only = m_is_cpu_only;
7575
result->m_run_once = m_run_once;

nvbench/detail/measure_cold.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,9 @@ void measure_cold_base::record_measurements()
115115
0.5f);
116116
}
117117

118-
if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref.has_value())
118+
if (auto printer_opt_ptr = m_state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
119119
{
120-
auto &printer = printer_opt_ref.value().get();
120+
auto &printer = *(printer_opt_ptr.value());
121121
printer.log(nvbench::log_level::warn,
122122
fmt::format("GPU throttled below threshold ({:0.2f} MHz / {:0.2f} MHz) "
123123
"({:0.0f}% < {:0.0f}%) on sample {}. Discarding previous trial "
@@ -386,9 +386,9 @@ void measure_cold_base::generate_summaries()
386386
}
387387

388388
// Log if a printer exists:
389-
if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref.has_value())
389+
if (auto printer_opt_ptr = m_state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
390390
{
391-
auto &printer = printer_opt_ref.value().get();
391+
auto &printer = *(printer_opt_ptr.value());
392392

393393
if (m_max_time_exceeded)
394394
{

nvbench/detail/measure_cpu_only.cxx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,9 @@ void measure_cpu_only_base::generate_summaries()
204204
}
205205

206206
// Log if a printer exists:
207-
if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref.has_value())
207+
if (auto printer_opt_ptr = m_state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
208208
{
209-
auto &printer = printer_opt_ref.value().get();
209+
auto &printer = *(printer_opt_ptr.value());
210210

211211
if (m_max_time_exceeded)
212212
{

nvbench/detail/measure_cupti.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ try
171171
// clang-format on
172172
catch (const std::exception &ex)
173173
{
174-
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer(); printer_opt_ref)
174+
if (auto printer_opt_ptr = exec_state.get_benchmark().get_printer(); printer_opt_ptr)
175175
{
176-
auto &printer = printer_opt_ref.value().get();
176+
auto &printer = *(printer_opt_ptr.value());
177177
printer.log(nvbench::log_level::warn,
178178
fmt::format("CUPTI failed to construct profiler: {}", ex.what()));
179179
}
@@ -247,9 +247,9 @@ try
247247
}
248248

249249
// Log if a printer exists:
250-
if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref.has_value())
250+
if (auto printer_opt_ptr = m_state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
251251
{
252-
auto &printer = printer_opt_ref.value().get();
252+
auto &printer = *(printer_opt_ptr.value());
253253
printer.log(nvbench::log_level::pass,
254254
fmt::format("CUPTI: {:0.2f}s total wall, {}x",
255255
m_walltime_timer.get_duration(),
@@ -258,9 +258,9 @@ try
258258
}
259259
catch (const std::exception &ex)
260260
{
261-
if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref)
261+
if (auto printer_opt_ptr = m_state.get_benchmark().get_printer(); printer_opt_ptr)
262262
{
263-
auto &printer = printer_opt_ref.value().get();
263+
auto &printer = *(printer_opt_ptr.value());
264264
printer.log(nvbench::log_level::warn,
265265
fmt::format("CUPTI failed to generate the summary: {}", ex.what()));
266266
}

nvbench/detail/measure_hot.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ void measure_hot_base::generate_summaries()
112112
}
113113

114114
// Log if a printer exists:
115-
if (auto printer_opt_ref = m_state.get_benchmark().get_printer(); printer_opt_ref.has_value())
115+
if (auto printer_opt_ptr = m_state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
116116
{
117-
auto &printer = printer_opt_ref.value().get();
117+
auto &printer = *(printer_opt_ptr.value());
118118

119119
// Warn if timed out:
120120
if (m_max_time_exceeded)

nvbench/json_printer.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,9 @@ void json_printer::do_process_bulk_data_float64(state &state,
247247
}
248248
catch (std::exception &e)
249249
{
250-
if (auto printer_opt_ref = state.get_benchmark().get_printer(); printer_opt_ref.has_value())
250+
if (auto printer_opt_ptr = state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
251251
{
252-
auto &printer = printer_opt_ref.value().get();
252+
auto &printer = *(printer_opt_ptr.value());
253253
printer.log(
254254
nvbench::log_level::warn,
255255
fmt::format("Error writing {} ({}) to {}: {}", tag, hint, result_path.string(), e.what()));
@@ -267,9 +267,9 @@ void json_printer::do_process_bulk_data_float64(state &state,
267267
summ.set_string("hide", "Not needed in table.");
268268

269269
timer.stop();
270-
if (auto printer_opt_ref = state.get_benchmark().get_printer(); printer_opt_ref.has_value())
270+
if (auto printer_opt_ptr = state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
271271
{
272-
auto &printer = printer_opt_ref.value().get();
272+
auto &printer = *(printer_opt_ptr.value());
273273
printer.log(
274274
nvbench::log_level::info,
275275
fmt::format("Wrote '{}' in {:>6.3f}ms", result_path.string(), timer.get_duration() * 1000));
@@ -307,9 +307,9 @@ void json_printer::do_process_bulk_data_float64(state &state,
307307
}
308308
catch (std::exception &e)
309309
{
310-
if (auto printer_opt_ref = state.get_benchmark().get_printer(); printer_opt_ref.has_value())
310+
if (auto printer_opt_ptr = state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
311311
{
312-
auto &printer = printer_opt_ref.value().get();
312+
auto &printer = *(printer_opt_ptr.value());
313313
printer.log(
314314
nvbench::log_level::warn,
315315
fmt::format("Error writing {} ({}) to {}: {}", tag, hint, result_path.string(), e.what()));
@@ -327,9 +327,9 @@ void json_printer::do_process_bulk_data_float64(state &state,
327327
summ.set_string("hide", "Not needed in table.");
328328

329329
timer.stop();
330-
if (auto printer_opt_ref = state.get_benchmark().get_printer(); printer_opt_ref.has_value())
330+
if (auto printer_opt_ptr = state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
331331
{
332-
auto &printer = printer_opt_ref.value().get();
332+
auto &printer = *(printer_opt_ptr.value());
333333
printer.log(
334334
nvbench::log_level::info,
335335
fmt::format("Wrote '{}' in {:>6.3f}ms", result_path.string(), timer.get_duration() * 1000));

nvbench/runner.cxx

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ void runner_base::handle_sampling_exception(const std::exception &e, state &exec
4646
{
4747
const auto reason = fmt::format("Unexpected error: {}", e.what());
4848

49-
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer();
50-
printer_opt_ref.has_value())
49+
if (auto printer_opt_ptr = exec_state.get_benchmark().get_printer();
50+
printer_opt_ptr.has_value())
5151
{
52-
auto &printer = printer_opt_ref.value().get();
52+
auto &printer = *(printer_opt_ptr.value());
5353
printer.log(nvbench::log_level::fail, reason);
5454
}
5555

@@ -60,28 +60,28 @@ void runner_base::handle_sampling_exception(const std::exception &e, state &exec
6060
void runner_base::run_state_prologue(nvbench::state &exec_state) const
6161
{
6262
// Log if a printer exists:
63-
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer(); printer_opt_ref.has_value())
63+
if (auto printer_opt_ptr = exec_state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
6464
{
65-
auto &printer = printer_opt_ref.value().get();
65+
auto &printer = *(printer_opt_ptr.value());
6666
printer.log_run_state(exec_state);
6767
}
6868
}
6969

7070
void runner_base::run_state_epilogue(state &exec_state) const
7171
{
7272
// Notify the printer that the state has completed::
73-
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer(); printer_opt_ref.has_value())
73+
if (auto printer_opt_ptr = exec_state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
7474
{
75-
auto &printer = printer_opt_ref.value().get();
75+
auto &printer = *(printer_opt_ptr.value());
7676
printer.add_completed_state();
7777
}
7878
}
7979

8080
void runner_base::print_skip_notification(state &exec_state) const
8181
{
82-
if (auto printer_opt_ref = exec_state.get_benchmark().get_printer(); printer_opt_ref.has_value())
82+
if (auto printer_opt_ptr = exec_state.get_benchmark().get_printer(); printer_opt_ptr.has_value())
8383
{
84-
auto &printer = printer_opt_ref.value().get();
84+
auto &printer = *(printer_opt_ptr.value());
8585
printer.log(nvbench::log_level::skip, exec_state.get_skip_reason());
8686
}
8787
}

0 commit comments

Comments
 (0)