99// ===----------------------------------------------------------------------===//
1010
1111#include < cub/detail/launcher/cuda_driver.cuh>
12+ #include < cub/detail/ptx-json-parser.cuh>
1213#include < cub/device/device_histogram.cuh>
1314
1415#include < cuda/std/algorithm>
2728#include < nvrtc/ltoir_list_appender.h>
2829#include < util/build_utils.h>
2930
31+ struct device_histogram_policy ;
32+
3033// int32_t is generally faster. Depending on the number of samples we
3134// instantiate the kernels below with int32 or int64, but we set this to int64
3235// here because it's needed for host computation as well.
@@ -38,34 +41,31 @@ namespace histogram
3841{
3942struct histogram_runtime_tuning_policy
4043{
41- int block_threads;
42- int pixels_per_thread;
44+ cub::detail::RuntimeHistogramAgentPolicy histogram;
4345
44- int BlockThreads () const
46+ auto Histogram () const
4547 {
46- return block_threads ;
48+ return histogram ;
4749 }
4850
49- int PixelsPerThread () const
51+ CUB_RUNTIME_FUNCTION int BlockThreads () const
5052 {
51- return pixels_per_thread ;
53+ return histogram. BlockThreads () ;
5254 }
53- };
5455
55- template <auto * GetPolicy>
56- struct dynamic_histogram_policy_t
57- {
58- using MaxPolicy = dynamic_histogram_policy_t ;
56+ CUB_RUNTIME_FUNCTION int PixelsPerThread () const
57+ {
58+ return histogram.PixelsPerThread ();
59+ }
60+
61+ using HistogramPolicy = cub::detail::RuntimeHistogramAgentPolicy;
62+ using MaxPolicy = histogram_runtime_tuning_policy;
5963
6064 template <typename F>
61- cudaError_t Invoke (int device_ptx_version , F& op)
65+ cudaError_t Invoke (int , F& op)
6266 {
63- return op.template Invoke <histogram_runtime_tuning_policy>(
64- GetPolicy (device_ptx_version, sample_t , num_active_channels));
67+ return op.template Invoke <histogram_runtime_tuning_policy>(*this );
6568 }
66-
67- cccl_type_info sample_t ;
68- int num_active_channels;
6969};
7070
7171struct histogram_kernel_source
@@ -105,19 +105,11 @@ struct histogram_kernel_source
105105 }
106106};
107107
108- histogram_runtime_tuning_policy get_policy (int /* cc */ , cccl_type_info sample_t , int num_active_channels )
108+ std::string get_init_kernel_name (int num_active_channels, std::string_view counter_t , std::string_view offset_t )
109109{
110- const int v_scale = static_cast <int >(cuda::ceil_div (sample_t .size , sizeof (int )));
111- constexpr int nominal_items_per_thread = 16 ;
112-
113- int pixels_per_thread = (::cuda::std::max) (nominal_items_per_thread / num_active_channels / v_scale, 1 );
114-
115- return {384 , pixels_per_thread};
116- }
110+ std::string chained_policy_t ;
111+ check (cccl_type_name_from_nvrtc<device_histogram_policy>(&chained_policy_t ));
117112
118- std::string get_init_kernel_name (
119- std::string_view chained_policy_t , int num_active_channels, std::string_view counter_t , std::string_view offset_t )
120- {
121113 return std::format (
122114 " cub::detail::histogram::DeviceHistogramInitKernel<{0}, {1}, {2}, {3}>" ,
123115 chained_policy_t ,
@@ -127,7 +119,6 @@ std::string get_init_kernel_name(
127119}
128120
129121std::string get_sweep_kernel_name (
130- std::string_view chained_policy_t ,
131122 int privatized_smem_bins,
132123 int num_channels,
133124 int num_active_channels,
@@ -138,6 +129,9 @@ std::string get_sweep_kernel_name(
138129 bool is_evenly_segmented,
139130 bool is_byte_sample)
140131{
132+ std::string chained_policy_t ;
133+ check (cccl_type_name_from_nvrtc<device_histogram_policy>(&chained_policy_t ));
134+
141135 std::string samples_iterator_name;
142136 check (cccl_type_name_from_nvrtc<samples_iterator_t >(&samples_iterator_name));
143137
@@ -286,7 +280,6 @@ CUresult cccl_device_histogram_build_ex(
286280 const char * name = " test" ;
287281
288282 const int cc = cc_major * 10 + cc_minor;
289- const auto policy = histogram::get_policy (cc, d_samples.value_type , num_active_channels);
290283 const auto sample_cpp = cccl_type_enum_to_name (d_samples.value_type .type );
291284 const auto counter_cpp = cccl_type_enum_to_name (d_output_histograms.value_type .type );
292285 const auto level_cpp = cccl_type_enum_to_name (lower_level.type .type );
@@ -302,48 +295,42 @@ CUresult cccl_device_histogram_build_ex(
302295 const std::string samples_iterator_src =
303296 make_kernel_input_iterator (offset_cpp, samples_iterator_name, sample_cpp, d_samples);
304297
305- constexpr std::string_view chained_policy_t = " device_histogram_policy" ;
298+ std::string policy_hub_expr = std::format (
299+ " cub::detail::histogram::policy_hub<{}, {}, {}, {}, {}>" ,
300+ sample_cpp,
301+ counter_cpp,
302+ num_channels,
303+ num_active_channels,
304+ is_evenly_segmented ? " true" : " false" );
306305
307- constexpr std::string_view src_template = R"XXX(
306+ std::string final_src = std::format (
307+ R"XXX(
308308#include <cub/agent/agent_histogram.cuh>
309309#include <cub/block/block_load.cuh>
310310#include <cub/device/dispatch/kernels/kernel_histogram.cuh>
311+ #include <cub/device/dispatch/tuning/tuning_histogram.cuh>
311312
312313struct __align__({1}) storage_t {{
313314 char data[{0}];
314315}};
315316{2}
316- struct agent_policy_t {{
317- static constexpr int BLOCK_THREADS = {3};
318- static constexpr int PIXELS_PER_THREAD = {4};
319- static constexpr bool IS_RLE_COMPRESS = true;
320- static constexpr cub::BlockHistogramMemoryPreference MEM_PREFERENCE = cub::SMEM;
321- static constexpr bool IS_WORK_STEALING = false;
322- static constexpr int VEC_SIZE = 4;
323- static constexpr cub::BlockLoadAlgorithm LOAD_ALGORITHM = cub::BLOCK_LOAD_DIRECT;
324- static constexpr cub::CacheLoadModifier LOAD_MODIFIER = cub::LOAD_LDG;
325- }};
326- struct {5} {{
327- struct ActivePolicy {{
328- using AgentHistogramPolicyT = agent_policy_t;
329- static constexpr int pdl_trigger_next_launch_in_init_kernel_max_bin_count = 2048;
330- }};
331- }};
332- )XXX" ;
333-
334- const std::string src = std::format (
335- src_template,
317+ using device_histogram_policy = {3}::MaxPolicy;
318+
319+ #include <cub/detail/ptx-json/json.cuh>
320+ __device__ consteval auto& policy_generator() {{
321+ return ptx_json::id<ptx_json::string("device_histogram_policy")>()
322+ = cub::detail::histogram::HistogramPolicyWrapper<device_histogram_policy::ActivePolicy>::EncodedPolicy();
323+ }}
324+ )XXX" ,
336325 d_samples.value_type .size , // 0
337326 d_samples.value_type .alignment , // 1
338327 samples_iterator_src, // 2
339- policy.block_threads , // 3
340- policy.pixels_per_thread , // 4
341- chained_policy_t // 5
328+ policy_hub_expr // 3
342329 );
343330
344331#if false // CCCL_DEBUGGING_SWITCH
345332 fflush (stderr);
346- printf (" \n CODE4NVRTC BEGIN\n %sCODE4NVRTC END\n " , src .c_str ());
333+ printf (" \n CODE4NVRTC BEGIN\n %sCODE4NVRTC END\n " , final_src .c_str ());
347334 fflush (stdout);
348335#endif
349336
@@ -355,10 +342,8 @@ struct {5} {{
355342
356343 const bool is_byte_sample = d_samples.value_type .size == 1 ;
357344
358- std::string init_kernel_name =
359- histogram::get_init_kernel_name (chained_policy_t , num_active_channels, counter_cpp, offset_cpp);
345+ std::string init_kernel_name = histogram::get_init_kernel_name (num_active_channels, counter_cpp, offset_cpp);
360346 std::string sweep_kernel_name = histogram::get_sweep_kernel_name (
361- chained_policy_t ,
362347 privatized_smem_bins,
363348 num_channels,
364349 num_active_channels,
@@ -374,8 +359,20 @@ struct {5} {{
374359
375360 const std::string arch = std::format (" -arch=sm_{0}{1}" , cc_major, cc_minor);
376361
362+ // Note: `-default-device` is needed because of the constexpr functions in
363+ // tuning_histogram.cuh
377364 std::vector<const char *> args = {
378- arch.c_str (), cub_path, thrust_path, libcudacxx_path, ctk_path, " -rdc=true" , " -dlto" , " -DCUB_DISABLE_CDP" };
365+ arch.c_str (),
366+ cub_path,
367+ thrust_path,
368+ libcudacxx_path,
369+ ctk_path,
370+ " -rdc=true" ,
371+ " -dlto" ,
372+ " -default-device" ,
373+ " -DCUB_DISABLE_CDP" ,
374+ " -DCUB_ENABLE_POLICY_PTX_JSON" ,
375+ " -std=c++20" };
379376
380377 cccl::detail::extend_args_with_build_config (args, config);
381378
@@ -390,7 +387,7 @@ struct {5} {{
390387
391388 nvrtc_link_result result =
392389 begin_linking_nvrtc_program (num_lto_args, lopts)
393- ->add_program (nvrtc_translation_unit ({src .c_str (), name}))
390+ ->add_program (nvrtc_translation_unit ({final_src .c_str (), name}))
394391 ->add_expression ({init_kernel_name})
395392 ->add_expression ({sweep_kernel_name})
396393 ->compile_program ({args.data (), args.size ()})
@@ -404,6 +401,12 @@ struct {5} {{
404401 check (cuLibraryGetKernel (&build_ptr->init_kernel , build_ptr->library , init_kernel_lowered_name.c_str ()));
405402 check (cuLibraryGetKernel (&build_ptr->sweep_kernel , build_ptr->library , sweep_kernel_lowered_name.c_str ()));
406403
404+ nlohmann::json runtime_policy =
405+ cub::detail::ptx_json::parse (" device_histogram_policy" , {result.data .get (), result.size });
406+
407+ using cub::detail::RuntimeHistogramAgentPolicy;
408+ auto histogram_policy = RuntimeHistogramAgentPolicy::from_json (runtime_policy, " HistogramPolicy" );
409+
407410 build_ptr->cc = cc;
408411 build_ptr->cubin = (void *) result.data .release ();
409412 build_ptr->cubin_size = result.size ;
@@ -413,6 +416,7 @@ struct {5} {{
413416 build_ptr->num_active_channels = num_active_channels;
414417 build_ptr->may_overflow = false ; // This is set in cccl_device_histogram_even_impl so that kernel source can access
415418 // it later.
419+ build_ptr->runtime_policy = new histogram::histogram_runtime_tuning_policy{histogram_policy};
416420 }
417421 catch (const std::exception& exc)
418422 {
@@ -477,7 +481,7 @@ CUresult cccl_device_histogram_even_impl(
477481 indirect_arg_t , // CounterT
478482 indirect_arg_t , // LevelT
479483 OffsetT, // OffsetT
480- histogram::dynamic_histogram_policy_t <&histogram::get_policy> , // PolicyHub
484+ histogram::histogram_runtime_tuning_policy , // PolicyHub
481485 indirect_arg_t , // SampleT
482486 histogram::histogram_kernel_source, // KernelSource
483487 cub::detail::CudaDriverLauncherFactory // KernelLauncherFactory
@@ -497,7 +501,7 @@ CUresult cccl_device_histogram_even_impl(
497501 is_byte_sample{},
498502 {build},
499503 cub::detail::CudaDriverLauncherFactory{cu_device, build.cc },
500- {d_samples. value_type , build.num_active_channels } );
504+ * reinterpret_cast <histogram::histogram_runtime_tuning_policy*>( build.runtime_policy ) );
501505
502506 error = static_cast <CUresult>(exec_status);
503507 }
@@ -598,6 +602,7 @@ CUresult cccl_device_histogram_cleanup(cccl_device_histogram_build_result_t* bui
598602 }
599603
600604 std::unique_ptr<char []> cubin (reinterpret_cast <char *>(build_ptr->cubin ));
605+ std::unique_ptr<char []> policy (reinterpret_cast <char *>(build_ptr->runtime_policy ));
601606 check (cuLibraryUnload (build_ptr->library ));
602607 }
603608 catch (const std::exception& exc)
0 commit comments