Skip to content

Commit 09b7323

Browse files
Implement the new tuning API for DispatchSegmentedSort (#7874)
Fixes: #7643 Co-authored-by: Michael Schellenberger Costa <miscco@nvidia.com>
1 parent 7880d6c commit 09b7323

File tree

10 files changed

+1242
-341
lines changed

10 files changed

+1242
-341
lines changed

c/parallel/src/radix_sort.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ try
229229
const auto policy_sel = cub::detail::radix_sort::policy_selector{
230230
static_cast<int>(input_keys_it.value_type.size),
231231
// FIXME(bgruber): input_values_it.value_type.size is 4 when it represents cub::NullType, which is very odd
232+
// TODO(bgruber): instead of 0 we should probably use int{sizeof(cub::NullType)}
232233
keys_only ? 0 : static_cast<int>(input_values_it.value_type.size),
233234
int{sizeof(OffsetT)},
234235
key_type};

c/parallel/src/segmented_sort.cu

Lines changed: 58 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
#include <nvrtc/ltoir_list_appender.h>
3838
#include <util/build_utils.h>
3939

40-
struct device_segmented_sort_policy;
40+
struct device_segmented_sort_policy_selector;
4141
struct device_three_way_partition_policy;
4242
using OffsetT = ptrdiff_t;
4343
static_assert(std::is_same_v<cub::detail::choose_signed_offset_t<OffsetT>, OffsetT>, "OffsetT must be long");
@@ -55,15 +55,15 @@ std::string get_device_segmented_sort_fallback_kernel_name(
5555
std::string_view value_t,
5656
cccl_sort_order_t sort_order)
5757
{
58-
std::string chained_policy_t;
59-
check(cccl_type_name_from_nvrtc<device_segmented_sort_policy>(&chained_policy_t));
58+
std::string policy_selector_t;
59+
check(cccl_type_name_from_nvrtc<device_segmented_sort_policy_selector>(&policy_selector_t));
6060

6161
std::string offset_t;
6262
check(cccl_type_name_from_nvrtc<OffsetT>(&offset_t));
6363

6464
/*
6565
template <SortOrder Order, // 0 (ascending)
66-
typename ChainedPolicyT, // 1
66+
typename PolicySelector, // 1
6767
typename KeyT, // 2
6868
typename ValueT, // 3
6969
typename BeginOffsetIteratorT, // 4
@@ -74,7 +74,7 @@ std::string get_device_segmented_sort_fallback_kernel_name(
7474
return std::format(
7575
"cub::detail::segmented_sort::DeviceSegmentedSortFallbackKernel<{0}, {1}, {2}, {3}, {4}, {5}, {6}>",
7676
(sort_order == CCCL_ASCENDING) ? "cub::SortOrder::Ascending" : "cub::SortOrder::Descending",
77-
chained_policy_t, // 0
77+
policy_selector_t, // 0
7878
key_t, // 1
7979
value_t, // 2
8080
start_offset_iterator_t, // 3
@@ -90,7 +90,7 @@ std::string get_device_segmented_sort_kernel_small_name(
9090
cccl_sort_order_t sort_order)
9191
{
9292
std::string chained_policy_t;
93-
check(cccl_type_name_from_nvrtc<device_segmented_sort_policy>(&chained_policy_t));
93+
check(cccl_type_name_from_nvrtc<device_segmented_sort_policy_selector>(&chained_policy_t));
9494

9595
std::string offset_t;
9696
check(cccl_type_name_from_nvrtc<OffsetT>(&offset_t));
@@ -124,7 +124,7 @@ std::string get_device_segmented_sort_kernel_large_name(
124124
cccl_sort_order_t sort_order)
125125
{
126126
std::string chained_policy_t;
127-
check(cccl_type_name_from_nvrtc<device_segmented_sort_policy>(&chained_policy_t));
127+
check(cccl_type_name_from_nvrtc<device_segmented_sort_policy_selector>(&chained_policy_t));
128128

129129
std::string offset_t;
130130
check(cccl_type_name_from_nvrtc<OffsetT>(&offset_t));
@@ -349,97 +349,6 @@ struct partition_kernel_source
349349
}
350350
};
351351

352-
struct segmented_sort_runtime_tuning_policy
353-
{
354-
cub::detail::RuntimeRadixSortDownsweepAgentPolicy large_segment;
355-
cub::detail::RuntimeSubWarpMergeSortAgentPolicy small_segment;
356-
cub::detail::RuntimeSubWarpMergeSortAgentPolicy medium_segment;
357-
int partitioning_threshold;
358-
359-
auto LargeSegment() const
360-
{
361-
return large_segment;
362-
}
363-
364-
auto SmallSegment() const
365-
{
366-
return small_segment;
367-
}
368-
369-
auto MediumSegment() const
370-
{
371-
return medium_segment;
372-
}
373-
374-
int PartitioningThreshold() const
375-
{
376-
return partitioning_threshold;
377-
}
378-
379-
int LargeSegmentRadixBits() const
380-
{
381-
return large_segment.RadixBits();
382-
}
383-
384-
int SegmentsPerSmallBlock() const
385-
{
386-
return small_segment.SegmentsPerBlock();
387-
}
388-
389-
int SegmentsPerMediumBlock() const
390-
{
391-
return medium_segment.SegmentsPerBlock();
392-
}
393-
394-
int SmallPolicyItemsPerTile() const
395-
{
396-
return small_segment.ItemsPerTile();
397-
}
398-
399-
int MediumPolicyItemsPerTile() const
400-
{
401-
return medium_segment.ItemsPerTile();
402-
}
403-
404-
cub::CacheLoadModifier LargeSegmentLoadModifier() const
405-
{
406-
return large_segment.LoadModifier();
407-
}
408-
409-
cub::BlockLoadAlgorithm LargeSegmentLoadAlgorithm() const
410-
{
411-
return large_segment.LoadAlgorithm();
412-
}
413-
414-
cub::WarpLoadAlgorithm MediumSegmentLoadAlgorithm() const
415-
{
416-
return medium_segment.LoadAlgorithm();
417-
}
418-
419-
cub::WarpLoadAlgorithm SmallSegmentLoadAlgorithm() const
420-
{
421-
return small_segment.LoadAlgorithm();
422-
}
423-
424-
cub::WarpStoreAlgorithm MediumSegmentStoreAlgorithm() const
425-
{
426-
return medium_segment.StoreAlgorithm();
427-
}
428-
429-
cub::WarpStoreAlgorithm SmallSegmentStoreAlgorithm() const
430-
{
431-
return small_segment.StoreAlgorithm();
432-
}
433-
434-
using MaxPolicy = segmented_sort_runtime_tuning_policy;
435-
436-
template <typename F>
437-
cudaError_t Invoke(int, F& op)
438-
{
439-
return op.template Invoke<segmented_sort_runtime_tuning_policy>(*this);
440-
}
441-
};
442-
443352
struct partition_runtime_tuning_policy
444353
{
445354
cub::detail::RuntimeThreeWayPartitionAgentPolicy three_way_partition;
@@ -534,26 +443,22 @@ try
534443

535444
const char* name = "device_segmented_sort";
536445

537-
const int cc = cc_major * 10 + cc_minor;
538-
539446
const auto [keys_in_iterator_name, keys_in_iterator_src] =
540447
get_specialization<segmented_sort_keys_input_iterator_tag>(template_id<input_iterator_traits>(), keys_in_it);
541448

542449
const bool keys_only = values_in_it.type == cccl_iterator_kind_t::CCCL_POINTER && values_in_it.state == nullptr;
543450

544-
std::string values_in_iterator_name, values_in_iterator_src;
451+
std::string values_in_iterator_src;
545452

546453
if (!keys_only)
547454
{
548455
const auto [vi_name, vi_src] =
549456
get_specialization<segmented_sort_values_input_iterator_tag>(template_id<input_iterator_traits>(), values_in_it);
550-
values_in_iterator_name = vi_name;
551-
values_in_iterator_src = vi_src;
457+
values_in_iterator_src = vi_src;
552458
}
553459
else
554460
{
555-
values_in_iterator_name = "cub::NullType*";
556-
values_in_iterator_src = "";
461+
values_in_iterator_src = "";
557462
}
558463

559464
const auto [start_offset_iterator_name, start_offset_iterator_src] =
@@ -624,8 +529,17 @@ try
624529
const auto [small_selector_name, small_selector_src] = get_specialization<segmented_sort_small_selector_tag>(
625530
template_id<user_operation_traits>(), small_selector_op, selector_result_t, selector_input_t);
626531

627-
const auto segmented_sort_policy_hub_expr = std::format(
628-
"cub::detail::segmented_sort::policy_hub<{0}, {1}>",
532+
const auto policy_sel = cub::detail::segmented_sort::policy_selector{
533+
static_cast<int>(keys_in_it.value_type.size),
534+
keys_only ? int{sizeof(cub::NullType)} : static_cast<int>(values_in_it.value_type.size),
535+
keys_only};
536+
537+
// TODO(bgruber): drop this if tuning policies become formattable
538+
std::stringstream segmented_sort_policy_sel_str;
539+
segmented_sort_policy_sel_str << policy_sel(cuda::to_arch_id(cuda::compute_capability{cc_major, cc_minor}));
540+
541+
const auto segmented_sort_policy_expr = std::format(
542+
"cub::detail::segmented_sort::policy_selector_from_types<{}, {}>",
629543
key_t, // 0
630544
value_t); // 1
631545

@@ -658,14 +572,15 @@ struct __align__({4}) items_storage_t {{
658572
{8}
659573
{9}
660574
{10}
661-
using device_segmented_sort_policy = {11}::MaxPolicy;
662-
using device_three_way_partition_policy = {12}::MaxPolicy;
575+
using device_segmented_sort_policy_selector = {11};
576+
using namespace cub;
577+
using namespace cub::detail::segmented_sort;
578+
static_assert(
579+
device_segmented_sort_policy_selector()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {12},
580+
"Host generated and JIT compiled policy mismatch");
581+
using device_three_way_partition_policy = {13}::MaxPolicy;
663582
664583
#include <cub/detail/ptx-json/json.cuh>
665-
__device__ consteval auto& segmented_sort_policy_generator() {{
666-
return ptx_json::id<ptx_json::string("device_segmented_sort_policy")>()
667-
= cub::detail::segmented_sort::SegmentedSortPolicyWrapper<device_segmented_sort_policy::ActivePolicy>::EncodedPolicy();
668-
}}
669584
__device__ consteval auto& three_way_partition_policy_generator() {{
670585
return ptx_json::id<ptx_json::string("device_three_way_partition_policy")>()
671586
= cub::detail::three_way_partition::ThreeWayPartitionPolicyWrapper<device_three_way_partition_policy::ActivePolicy>::EncodedPolicy();
@@ -682,8 +597,15 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
682597
end_offset_iterator_src, // 8
683598
large_selector_src, // 9
684599
small_selector_src, // 10
685-
segmented_sort_policy_hub_expr, // 11
686-
three_way_partition_policy_hub_expr); // 12
600+
segmented_sort_policy_expr, // 11
601+
segmented_sort_policy_sel_str.view(), // 12
602+
three_way_partition_policy_hub_expr); // 13
603+
604+
#if false // CCCL_DEBUGGING_SWITCH
605+
fflush(stderr);
606+
printf("\nCODE4NVRTC BEGIN\n%sCODE4NVRTC END\n", final_src.c_str());
607+
fflush(stdout);
608+
#endif
687609

688610
std::vector<const char*> args = {
689611
arch.c_str(),
@@ -695,7 +617,7 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
695617
"-dlto",
696618
"-default-device",
697619
"-DCUB_DISABLE_CDP",
698-
"-DCUB_ENABLE_POLICY_PTX_JSON",
620+
"-DCUB_ENABLE_POLICY_PTX_JSON", // TODO(bgruber): remove after we ported three way partition to the new tuning API
699621
"-std=c++20"};
700622

701623
cccl::detail::extend_args_with_build_config(args, config);
@@ -769,35 +691,23 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
769691
check(cuLibraryGetKernel(
770692
&build_ptr->three_way_partition_kernel, build_ptr->library, three_way_partition_kernel_lowered_name.c_str()));
771693

772-
nlohmann::json runtime_policy =
773-
cub::detail::ptx_json::parse("device_segmented_sort_policy", {result.data.get(), result.size});
774-
775-
using cub::detail::RuntimeRadixSortDownsweepAgentPolicy;
776-
auto large_segment_policy = RuntimeRadixSortDownsweepAgentPolicy::from_json(runtime_policy, "LargeSegmentPolicy");
777-
778-
using cub::detail::RuntimeSubWarpMergeSortAgentPolicy;
779-
auto small_segment_policy = RuntimeSubWarpMergeSortAgentPolicy::from_json(runtime_policy, "SmallSegmentPolicy");
780-
781-
auto medium_segment_policy = RuntimeSubWarpMergeSortAgentPolicy::from_json(runtime_policy, "MediumSegmentPolicy");
782-
783-
int partitioning_threshold = runtime_policy["PartitioningThreshold"].get<int>();
694+
// TODO(bgruber): convert to the new tuning API
784695
nlohmann::json partition_policy =
785696
cub::detail::ptx_json::parse("device_three_way_partition_policy", {result.data.get(), result.size});
786697

787698
using cub::detail::RuntimeThreeWayPartitionAgentPolicy;
788699
auto three_way_partition_policy =
789700
RuntimeThreeWayPartitionAgentPolicy::from_json(partition_policy, "ThreeWayPartitionPolicy");
790701

791-
build_ptr->cc = cc;
702+
build_ptr->cc = cc_major * 10 + cc_minor;
792703
build_ptr->large_segments_selector_op = large_selector_op;
793704
build_ptr->small_segments_selector_op = small_selector_op;
794705
build_ptr->cubin = (void*) result.data.release();
795706
build_ptr->cubin_size = result.size;
796707
build_ptr->key_type = keys_in_it.value_type;
797708
build_ptr->offset_type = cccl_type_info{sizeof(OffsetT), alignof(OffsetT), cccl_type_enum::CCCL_INT64};
798709
// Use the runtime policy extracted via from_json
799-
build_ptr->runtime_policy = new segmented_sort::segmented_sort_runtime_tuning_policy{
800-
large_segment_policy, small_segment_policy, medium_segment_policy, partitioning_threshold};
710+
build_ptr->runtime_policy = new cub::detail::segmented_sort::policy_selector{policy_sel};
801711
build_ptr->partition_runtime_policy = new segmented_sort::partition_runtime_tuning_policy{three_way_partition_policy};
802712
build_ptr->order = sort_order;
803713

@@ -888,35 +798,34 @@ CUresult cccl_device_segmented_sort_impl(
888798
cub::DoubleBuffer<indirect_arg_t> d_values_double_buffer(
889799
*static_cast<indirect_arg_t**>(&val_arg_in), *static_cast<indirect_arg_t**>(&val_arg_out));
890800

891-
auto exec_status = cub::DispatchSegmentedSort<
801+
// TODO(bgruber): remove all template arguments except the first two (the others can be deduced)
802+
auto exec_status = cub::detail::segmented_sort::dispatch<
892803
Order,
804+
OffsetT, // OffsetT
893805
indirect_arg_t, // KeyT
894806
indirect_arg_t, // ValueT
895-
OffsetT, // OffsetT
896807
indirect_iterator_t, // BeginOffsetIteratorT
897808
indirect_iterator_t, // EndOffsetIteratorT
898-
segmented_sort::segmented_sort_runtime_tuning_policy, // PolicyHub
809+
cub::detail::segmented_sort::policy_selector, // PolicySelector
899810
segmented_sort::segmented_sort_kernel_source, // KernelSource
900-
segmented_sort::partition_runtime_tuning_policy, // PartitionPolicyHub
901-
segmented_sort::partition_kernel_source, // PartitionKernelSource
902-
cub::detail::CudaDriverLauncherFactory>:: // KernelLaunchFactory
903-
Dispatch(
904-
d_temp_storage,
811+
segmented_sort::partition_runtime_tuning_policy // PartitionPolicyHub
812+
>(d_temp_storage,
905813
*temp_storage_bytes,
906814
d_keys_double_buffer,
907815
d_values_double_buffer,
908816
num_items,
909817
num_segments,
910-
start_offset_in,
911-
end_offset_in,
818+
indirect_iterator_t{start_offset_in},
819+
indirect_iterator_t{end_offset_in},
912820
is_overwrite_okay,
913821
stream,
914-
/* kernel_source */ {build},
915-
/* partition_kernel_source */ {build},
916-
/* launcher_factory */ cub::detail::CudaDriverLauncherFactory{cu_device, build.cc},
917-
/* policy */ *reinterpret_cast<segmented_sort::segmented_sort_runtime_tuning_policy*>(build.runtime_policy),
918-
/* partition_policy */
919-
*reinterpret_cast<segmented_sort::partition_runtime_tuning_policy*>(build.partition_runtime_policy));
822+
/* policy_selector */
823+
*static_cast<cub::detail::segmented_sort::policy_selector*>(build.runtime_policy),
824+
/* partition_max_policy */
825+
*static_cast<segmented_sort::partition_runtime_tuning_policy*>(build.partition_runtime_policy),
826+
/* kernel_source */ segmented_sort::segmented_sort_kernel_source{build},
827+
/* partition_kernel_source */ segmented_sort::partition_kernel_source{build},
828+
/* launcher_factory */ cub::detail::CudaDriverLauncherFactory{cu_device, build.cc});
920829

921830
*selector = d_keys_double_buffer.selector;
922831
error = static_cast<CUresult>(exec_status);
@@ -998,8 +907,8 @@ try
998907
std::unique_ptr<char[]> small_code(const_cast<char*>(build_ptr->small_segments_selector_op.code));
999908

1000909
// Clean up the runtime policies
1001-
std::unique_ptr<segmented_sort::segmented_sort_runtime_tuning_policy> rtp(
1002-
static_cast<segmented_sort::segmented_sort_runtime_tuning_policy*>(build_ptr->runtime_policy));
910+
std::unique_ptr<cub::detail::segmented_sort::policy_selector> rtp(
911+
static_cast<cub::detail::segmented_sort::policy_selector*>(build_ptr->runtime_policy));
1003912
std::unique_ptr<segmented_sort::partition_runtime_tuning_policy> prtp(
1004913
static_cast<segmented_sort::partition_runtime_tuning_policy*>(build_ptr->partition_runtime_policy));
1005914
check(cuLibraryUnload(build_ptr->library));

0 commit comments

Comments
 (0)