Skip to content

Commit db78d42

Browse files
authored
c.parallel: reuse CUB agent policies for histogram (#6974)
1 parent 34f5839 commit db78d42

File tree

4 files changed

+110
-68
lines changed

4 files changed

+110
-68
lines changed

c/parallel/include/cccl/c/histogram.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ typedef struct cccl_device_histogram_build_result_t
3636
bool may_overflow;
3737
CUkernel init_kernel;
3838
CUkernel sweep_kernel;
39+
void* runtime_policy;
3940
} cccl_device_histogram_build_result_t;
4041

4142
CCCL_C_API CUresult cccl_device_histogram_build(

c/parallel/src/histogram.cu

Lines changed: 68 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include <cub/detail/launcher/cuda_driver.cuh>
12+
#include <cub/detail/ptx-json-parser.cuh>
1213
#include <cub/device/device_histogram.cuh>
1314

1415
#include <cuda/std/algorithm>
@@ -27,6 +28,8 @@
2728
#include <nvrtc/ltoir_list_appender.h>
2829
#include <util/build_utils.h>
2930

31+
struct device_histogram_policy;
32+
3033
// int32_t is generally faster. Depending on the number of samples we
3134
// instantiate the kernels below with int32 or int64, but we set this to int64
3235
// here because it's needed for host computation as well.
@@ -38,34 +41,31 @@ namespace histogram
3841
{
3942
struct histogram_runtime_tuning_policy
4043
{
41-
int block_threads;
42-
int pixels_per_thread;
44+
cub::detail::RuntimeHistogramAgentPolicy histogram;
4345

44-
int BlockThreads() const
46+
auto Histogram() const
4547
{
46-
return block_threads;
48+
return histogram;
4749
}
4850

49-
int PixelsPerThread() const
51+
CUB_RUNTIME_FUNCTION int BlockThreads() const
5052
{
51-
return pixels_per_thread;
53+
return histogram.BlockThreads();
5254
}
53-
};
5455

55-
template <auto* GetPolicy>
56-
struct dynamic_histogram_policy_t
57-
{
58-
using MaxPolicy = dynamic_histogram_policy_t;
56+
CUB_RUNTIME_FUNCTION int PixelsPerThread() const
57+
{
58+
return histogram.PixelsPerThread();
59+
}
60+
61+
using HistogramPolicy = cub::detail::RuntimeHistogramAgentPolicy;
62+
using MaxPolicy = histogram_runtime_tuning_policy;
5963

6064
template <typename F>
61-
cudaError_t Invoke(int device_ptx_version, F& op)
65+
cudaError_t Invoke(int, F& op)
6266
{
63-
return op.template Invoke<histogram_runtime_tuning_policy>(
64-
GetPolicy(device_ptx_version, sample_t, num_active_channels));
67+
return op.template Invoke<histogram_runtime_tuning_policy>(*this);
6568
}
66-
67-
cccl_type_info sample_t;
68-
int num_active_channels;
6969
};
7070

7171
struct histogram_kernel_source
@@ -105,19 +105,11 @@ struct histogram_kernel_source
105105
}
106106
};
107107

108-
histogram_runtime_tuning_policy get_policy(int /*cc*/, cccl_type_info sample_t, int num_active_channels)
108+
std::string get_init_kernel_name(int num_active_channels, std::string_view counter_t, std::string_view offset_t)
109109
{
110-
const int v_scale = static_cast<int>(cuda::ceil_div(sample_t.size, sizeof(int)));
111-
constexpr int nominal_items_per_thread = 16;
112-
113-
int pixels_per_thread = (::cuda::std::max) (nominal_items_per_thread / num_active_channels / v_scale, 1);
114-
115-
return {384, pixels_per_thread};
116-
}
110+
std::string chained_policy_t;
111+
check(cccl_type_name_from_nvrtc<device_histogram_policy>(&chained_policy_t));
117112

118-
std::string get_init_kernel_name(
119-
std::string_view chained_policy_t, int num_active_channels, std::string_view counter_t, std::string_view offset_t)
120-
{
121113
return std::format(
122114
"cub::detail::histogram::DeviceHistogramInitKernel<{0}, {1}, {2}, {3}>",
123115
chained_policy_t,
@@ -127,7 +119,6 @@ std::string get_init_kernel_name(
127119
}
128120

129121
std::string get_sweep_kernel_name(
130-
std::string_view chained_policy_t,
131122
int privatized_smem_bins,
132123
int num_channels,
133124
int num_active_channels,
@@ -138,6 +129,9 @@ std::string get_sweep_kernel_name(
138129
bool is_evenly_segmented,
139130
bool is_byte_sample)
140131
{
132+
std::string chained_policy_t;
133+
check(cccl_type_name_from_nvrtc<device_histogram_policy>(&chained_policy_t));
134+
141135
std::string samples_iterator_name;
142136
check(cccl_type_name_from_nvrtc<samples_iterator_t>(&samples_iterator_name));
143137

@@ -286,7 +280,6 @@ CUresult cccl_device_histogram_build_ex(
286280
const char* name = "test";
287281

288282
const int cc = cc_major * 10 + cc_minor;
289-
const auto policy = histogram::get_policy(cc, d_samples.value_type, num_active_channels);
290283
const auto sample_cpp = cccl_type_enum_to_name(d_samples.value_type.type);
291284
const auto counter_cpp = cccl_type_enum_to_name(d_output_histograms.value_type.type);
292285
const auto level_cpp = cccl_type_enum_to_name(lower_level.type.type);
@@ -302,48 +295,42 @@ CUresult cccl_device_histogram_build_ex(
302295
const std::string samples_iterator_src =
303296
make_kernel_input_iterator(offset_cpp, samples_iterator_name, sample_cpp, d_samples);
304297

305-
constexpr std::string_view chained_policy_t = "device_histogram_policy";
298+
std::string policy_hub_expr = std::format(
299+
"cub::detail::histogram::policy_hub<{}, {}, {}, {}, {}>",
300+
sample_cpp,
301+
counter_cpp,
302+
num_channels,
303+
num_active_channels,
304+
is_evenly_segmented ? "true" : "false");
306305

307-
constexpr std::string_view src_template = R"XXX(
306+
std::string final_src = std::format(
307+
R"XXX(
308308
#include <cub/agent/agent_histogram.cuh>
309309
#include <cub/block/block_load.cuh>
310310
#include <cub/device/dispatch/kernels/kernel_histogram.cuh>
311+
#include <cub/device/dispatch/tuning/tuning_histogram.cuh>
311312
312313
struct __align__({1}) storage_t {{
313314
char data[{0}];
314315
}};
315316
{2}
316-
struct agent_policy_t {{
317-
static constexpr int BLOCK_THREADS = {3};
318-
static constexpr int PIXELS_PER_THREAD = {4};
319-
static constexpr bool IS_RLE_COMPRESS = true;
320-
static constexpr cub::BlockHistogramMemoryPreference MEM_PREFERENCE = cub::SMEM;
321-
static constexpr bool IS_WORK_STEALING = false;
322-
static constexpr int VEC_SIZE = 4;
323-
static constexpr cub::BlockLoadAlgorithm LOAD_ALGORITHM = cub::BLOCK_LOAD_DIRECT;
324-
static constexpr cub::CacheLoadModifier LOAD_MODIFIER = cub::LOAD_LDG;
325-
}};
326-
struct {5} {{
327-
struct ActivePolicy {{
328-
using AgentHistogramPolicyT = agent_policy_t;
329-
static constexpr int pdl_trigger_next_launch_in_init_kernel_max_bin_count = 2048;
330-
}};
331-
}};
332-
)XXX";
333-
334-
const std::string src = std::format(
335-
src_template,
317+
using device_histogram_policy = {3}::MaxPolicy;
318+
319+
#include <cub/detail/ptx-json/json.cuh>
320+
__device__ consteval auto& policy_generator() {{
321+
return ptx_json::id<ptx_json::string("device_histogram_policy")>()
322+
= cub::detail::histogram::HistogramPolicyWrapper<device_histogram_policy::ActivePolicy>::EncodedPolicy();
323+
}}
324+
)XXX",
336325
d_samples.value_type.size, // 0
337326
d_samples.value_type.alignment, // 1
338327
samples_iterator_src, // 2
339-
policy.block_threads, // 3
340-
policy.pixels_per_thread, // 4
341-
chained_policy_t // 5
328+
policy_hub_expr // 3
342329
);
343330

344331
#if false // CCCL_DEBUGGING_SWITCH
345332
fflush(stderr);
346-
printf("\nCODE4NVRTC BEGIN\n%sCODE4NVRTC END\n", src.c_str());
333+
printf("\nCODE4NVRTC BEGIN\n%sCODE4NVRTC END\n", final_src.c_str());
347334
fflush(stdout);
348335
#endif
349336

@@ -355,10 +342,8 @@ struct {5} {{
355342

356343
const bool is_byte_sample = d_samples.value_type.size == 1;
357344

358-
std::string init_kernel_name =
359-
histogram::get_init_kernel_name(chained_policy_t, num_active_channels, counter_cpp, offset_cpp);
345+
std::string init_kernel_name = histogram::get_init_kernel_name(num_active_channels, counter_cpp, offset_cpp);
360346
std::string sweep_kernel_name = histogram::get_sweep_kernel_name(
361-
chained_policy_t,
362347
privatized_smem_bins,
363348
num_channels,
364349
num_active_channels,
@@ -374,8 +359,20 @@ struct {5} {{
374359

375360
const std::string arch = std::format("-arch=sm_{0}{1}", cc_major, cc_minor);
376361

362+
// Note: `-default-device` is needed because of the constexpr functions in
363+
// tuning_histogram.cuh
377364
std::vector<const char*> args = {
378-
arch.c_str(), cub_path, thrust_path, libcudacxx_path, ctk_path, "-rdc=true", "-dlto", "-DCUB_DISABLE_CDP"};
365+
arch.c_str(),
366+
cub_path,
367+
thrust_path,
368+
libcudacxx_path,
369+
ctk_path,
370+
"-rdc=true",
371+
"-dlto",
372+
"-default-device",
373+
"-DCUB_DISABLE_CDP",
374+
"-DCUB_ENABLE_POLICY_PTX_JSON",
375+
"-std=c++20"};
379376

380377
cccl::detail::extend_args_with_build_config(args, config);
381378

@@ -390,7 +387,7 @@ struct {5} {{
390387

391388
nvrtc_link_result result =
392389
begin_linking_nvrtc_program(num_lto_args, lopts)
393-
->add_program(nvrtc_translation_unit({src.c_str(), name}))
390+
->add_program(nvrtc_translation_unit({final_src.c_str(), name}))
394391
->add_expression({init_kernel_name})
395392
->add_expression({sweep_kernel_name})
396393
->compile_program({args.data(), args.size()})
@@ -404,6 +401,12 @@ struct {5} {{
404401
check(cuLibraryGetKernel(&build_ptr->init_kernel, build_ptr->library, init_kernel_lowered_name.c_str()));
405402
check(cuLibraryGetKernel(&build_ptr->sweep_kernel, build_ptr->library, sweep_kernel_lowered_name.c_str()));
406403

404+
nlohmann::json runtime_policy =
405+
cub::detail::ptx_json::parse("device_histogram_policy", {result.data.get(), result.size});
406+
407+
using cub::detail::RuntimeHistogramAgentPolicy;
408+
auto histogram_policy = RuntimeHistogramAgentPolicy::from_json(runtime_policy, "HistogramPolicy");
409+
407410
build_ptr->cc = cc;
408411
build_ptr->cubin = (void*) result.data.release();
409412
build_ptr->cubin_size = result.size;
@@ -413,6 +416,7 @@ struct {5} {{
413416
build_ptr->num_active_channels = num_active_channels;
414417
build_ptr->may_overflow = false; // This is set in cccl_device_histogram_even_impl so that kernel source can access
415418
// it later.
419+
build_ptr->runtime_policy = new histogram::histogram_runtime_tuning_policy{histogram_policy};
416420
}
417421
catch (const std::exception& exc)
418422
{
@@ -477,7 +481,7 @@ CUresult cccl_device_histogram_even_impl(
477481
indirect_arg_t, // CounterT
478482
indirect_arg_t, // LevelT
479483
OffsetT, // OffsetT
480-
histogram::dynamic_histogram_policy_t<&histogram::get_policy>, // PolicyHub
484+
histogram::histogram_runtime_tuning_policy, // PolicyHub
481485
indirect_arg_t, // SampleT
482486
histogram::histogram_kernel_source, // KernelSource
483487
cub::detail::CudaDriverLauncherFactory // KernelLauncherFactory
@@ -497,7 +501,7 @@ CUresult cccl_device_histogram_even_impl(
497501
is_byte_sample{},
498502
{build},
499503
cub::detail::CudaDriverLauncherFactory{cu_device, build.cc},
500-
{d_samples.value_type, build.num_active_channels});
504+
*reinterpret_cast<histogram::histogram_runtime_tuning_policy*>(build.runtime_policy));
501505

502506
error = static_cast<CUresult>(exec_status);
503507
}
@@ -598,6 +602,7 @@ CUresult cccl_device_histogram_cleanup(cccl_device_histogram_build_result_t* bui
598602
}
599603

600604
std::unique_ptr<char[]> cubin(reinterpret_cast<char*>(build_ptr->cubin));
605+
std::unique_ptr<char[]> policy(reinterpret_cast<char*>(build_ptr->runtime_policy));
601606
check(cuLibraryUnload(build_ptr->library));
602607
}
603608
catch (const std::exception& exc)

cub/cub/agent/agent_histogram.cuh

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,29 @@ struct AgentHistogramPolicy
9595
static constexpr CacheLoadModifier LOAD_MODIFIER = LoadModifier;
9696
};
9797

98+
#if defined(CUB_DEFINE_RUNTIME_POLICIES) || defined(CUB_ENABLE_POLICY_PTX_JSON)
99+
namespace detail
100+
{
101+
// Only define this when needed.
102+
// Because of overload woes, this depends on C++20 concepts. util_device.h checks that concepts are available when
103+
// either runtime policies or PTX JSON information are enabled, so if they are, this is always valid. The generic
104+
// version is always defined, and that's the only one needed for regular CUB operations.
105+
//
106+
// TODO: enable this unconditionally once concepts are always available
107+
CUB_DETAIL_POLICY_WRAPPER_DEFINE(
108+
HistogramAgentPolicy,
109+
(always_true),
110+
(BLOCK_THREADS, BlockThreads, int),
111+
(PIXELS_PER_THREAD, PixelsPerThread, int),
112+
(IS_RLE_COMPRESS, IsRleCompress, bool),
113+
(MEM_PREFERENCE, MemPreference, BlockHistogramMemoryPreference),
114+
(IS_WORK_STEALING, IsWorkStealing, bool),
115+
(VEC_SIZE, VecSize, int),
116+
(LOAD_ALGORITHM, LoadAlgorithm, cub::BlockLoadAlgorithm),
117+
(LOAD_MODIFIER, LoadModifier, cub::CacheLoadModifier))
118+
} // namespace detail
119+
#endif
120+
98121
namespace detail::histogram
99122
{
100123
// Return a native pixel pointer (specialized for CacheModifiedInputIterator types)

cub/cub/device/dispatch/tuning/tuning_histogram.cuh

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ struct sm100_tuning<false, SampleT, 1, 1, counter_size::_4, primitive_sample::ye
145145
template <typename PolicyT, typename = void>
146146
struct HistogramPolicyWrapper : PolicyT
147147
{
148-
CUB_RUNTIME_FUNCTION HistogramPolicyWrapper(PolicyT base)
148+
_CCCL_HOST_DEVICE HistogramPolicyWrapper(PolicyT base)
149149
: PolicyT(base)
150150
{}
151151
};
@@ -155,23 +155,36 @@ struct HistogramPolicyWrapper<StaticPolicyT,
155155
::cuda::std::void_t<decltype(StaticPolicyT::AgentHistogramPolicyT::LOAD_MODIFIER)>>
156156
: StaticPolicyT
157157
{
158-
CUB_RUNTIME_FUNCTION HistogramPolicyWrapper(StaticPolicyT base)
158+
_CCCL_HOST_DEVICE HistogramPolicyWrapper(StaticPolicyT base)
159159
: StaticPolicyT(base)
160160
{}
161161

162-
CUB_RUNTIME_FUNCTION static constexpr int BlockThreads()
162+
_CCCL_HOST_DEVICE static constexpr auto Histogram()
163+
{
164+
return cub::detail::MakePolicyWrapper(typename StaticPolicyT::AgentHistogramPolicyT());
165+
}
166+
167+
_CCCL_HOST_DEVICE static constexpr int BlockThreads()
163168
{
164169
return StaticPolicyT::AgentHistogramPolicyT::BLOCK_THREADS;
165170
}
166171

167-
CUB_RUNTIME_FUNCTION static constexpr int PixelsPerThread()
172+
_CCCL_HOST_DEVICE static constexpr int PixelsPerThread()
168173
{
169174
return StaticPolicyT::AgentHistogramPolicyT::PIXELS_PER_THREAD;
170175
}
176+
177+
#if defined(CUB_ENABLE_POLICY_PTX_JSON)
178+
_CCCL_DEVICE static constexpr auto EncodedPolicy()
179+
{
180+
using namespace ptx_json;
181+
return object<key<"HistogramPolicy">() = Histogram().EncodedPolicy()>();
182+
}
183+
#endif
171184
};
172185

173186
template <typename PolicyT>
174-
CUB_RUNTIME_FUNCTION HistogramPolicyWrapper<PolicyT> MakeHistogramPolicyWrapper(PolicyT policy)
187+
_CCCL_HOST_DEVICE HistogramPolicyWrapper<PolicyT> MakeHistogramPolicyWrapper(PolicyT policy)
175188
{
176189
return HistogramPolicyWrapper<PolicyT>{policy};
177190
}

0 commit comments

Comments
 (0)