3737#include < nvrtc/ltoir_list_appender.h>
3838#include < util/build_utils.h>
3939
40- struct device_segmented_sort_policy ;
40+ struct device_segmented_sort_policy_selector ;
4141struct device_three_way_partition_policy ;
4242using OffsetT = ptrdiff_t ;
4343static_assert (std::is_same_v<cub::detail::choose_signed_offset_t <OffsetT>, OffsetT>, " OffsetT must be long" );
@@ -55,15 +55,15 @@ std::string get_device_segmented_sort_fallback_kernel_name(
5555 std::string_view value_t ,
5656 cccl_sort_order_t sort_order)
5757{
58- std::string chained_policy_t ;
59- check (cccl_type_name_from_nvrtc<device_segmented_sort_policy >(&chained_policy_t ));
58+ std::string policy_selector_t ;
59+ check (cccl_type_name_from_nvrtc<device_segmented_sort_policy_selector >(&policy_selector_t ));
6060
6161 std::string offset_t ;
6262 check (cccl_type_name_from_nvrtc<OffsetT>(&offset_t ));
6363
6464 /*
6565 template <SortOrder Order, // 0 (ascending)
66- typename ChainedPolicyT , // 1
66+ typename PolicySelector , // 1
6767 typename KeyT, // 2
6868 typename ValueT, // 3
6969 typename BeginOffsetIteratorT, // 4
@@ -74,7 +74,7 @@ std::string get_device_segmented_sort_fallback_kernel_name(
7474 return std::format (
7575 " cub::detail::segmented_sort::DeviceSegmentedSortFallbackKernel<{0}, {1}, {2}, {3}, {4}, {5}, {6}>" ,
7676 (sort_order == CCCL_ASCENDING) ? " cub::SortOrder::Ascending" : " cub::SortOrder::Descending" ,
77- chained_policy_t , // 0
77+ policy_selector_t , // 0
7878 key_t , // 1
7979 value_t , // 2
8080 start_offset_iterator_t , // 3
@@ -90,7 +90,7 @@ std::string get_device_segmented_sort_kernel_small_name(
9090 cccl_sort_order_t sort_order)
9191{
9292 std::string chained_policy_t ;
93- check (cccl_type_name_from_nvrtc<device_segmented_sort_policy >(&chained_policy_t ));
93+ check (cccl_type_name_from_nvrtc<device_segmented_sort_policy_selector >(&chained_policy_t ));
9494
9595 std::string offset_t ;
9696 check (cccl_type_name_from_nvrtc<OffsetT>(&offset_t ));
@@ -124,7 +124,7 @@ std::string get_device_segmented_sort_kernel_large_name(
124124 cccl_sort_order_t sort_order)
125125{
126126 std::string chained_policy_t ;
127- check (cccl_type_name_from_nvrtc<device_segmented_sort_policy >(&chained_policy_t ));
127+ check (cccl_type_name_from_nvrtc<device_segmented_sort_policy_selector >(&chained_policy_t ));
128128
129129 std::string offset_t ;
130130 check (cccl_type_name_from_nvrtc<OffsetT>(&offset_t ));
@@ -349,97 +349,6 @@ struct partition_kernel_source
349349 }
350350};
351351
352- struct segmented_sort_runtime_tuning_policy
353- {
354- cub::detail::RuntimeRadixSortDownsweepAgentPolicy large_segment;
355- cub::detail::RuntimeSubWarpMergeSortAgentPolicy small_segment;
356- cub::detail::RuntimeSubWarpMergeSortAgentPolicy medium_segment;
357- int partitioning_threshold;
358-
359- auto LargeSegment () const
360- {
361- return large_segment;
362- }
363-
364- auto SmallSegment () const
365- {
366- return small_segment;
367- }
368-
369- auto MediumSegment () const
370- {
371- return medium_segment;
372- }
373-
374- int PartitioningThreshold () const
375- {
376- return partitioning_threshold;
377- }
378-
379- int LargeSegmentRadixBits () const
380- {
381- return large_segment.RadixBits ();
382- }
383-
384- int SegmentsPerSmallBlock () const
385- {
386- return small_segment.SegmentsPerBlock ();
387- }
388-
389- int SegmentsPerMediumBlock () const
390- {
391- return medium_segment.SegmentsPerBlock ();
392- }
393-
394- int SmallPolicyItemsPerTile () const
395- {
396- return small_segment.ItemsPerTile ();
397- }
398-
399- int MediumPolicyItemsPerTile () const
400- {
401- return medium_segment.ItemsPerTile ();
402- }
403-
404- cub::CacheLoadModifier LargeSegmentLoadModifier () const
405- {
406- return large_segment.LoadModifier ();
407- }
408-
409- cub::BlockLoadAlgorithm LargeSegmentLoadAlgorithm () const
410- {
411- return large_segment.LoadAlgorithm ();
412- }
413-
414- cub::WarpLoadAlgorithm MediumSegmentLoadAlgorithm () const
415- {
416- return medium_segment.LoadAlgorithm ();
417- }
418-
419- cub::WarpLoadAlgorithm SmallSegmentLoadAlgorithm () const
420- {
421- return small_segment.LoadAlgorithm ();
422- }
423-
424- cub::WarpStoreAlgorithm MediumSegmentStoreAlgorithm () const
425- {
426- return medium_segment.StoreAlgorithm ();
427- }
428-
429- cub::WarpStoreAlgorithm SmallSegmentStoreAlgorithm () const
430- {
431- return small_segment.StoreAlgorithm ();
432- }
433-
434- using MaxPolicy = segmented_sort_runtime_tuning_policy;
435-
436- template <typename F>
437- cudaError_t Invoke (int , F& op)
438- {
439- return op.template Invoke <segmented_sort_runtime_tuning_policy>(*this );
440- }
441- };
442-
443352struct partition_runtime_tuning_policy
444353{
445354 cub::detail::RuntimeThreeWayPartitionAgentPolicy three_way_partition;
@@ -534,26 +443,22 @@ try
534443
535444 const char * name = " device_segmented_sort" ;
536445
537- const int cc = cc_major * 10 + cc_minor;
538-
539446 const auto [keys_in_iterator_name, keys_in_iterator_src] =
540447 get_specialization<segmented_sort_keys_input_iterator_tag>(template_id<input_iterator_traits>(), keys_in_it);
541448
542449 const bool keys_only = values_in_it.type == cccl_iterator_kind_t ::CCCL_POINTER && values_in_it.state == nullptr ;
543450
544- std::string values_in_iterator_name, values_in_iterator_src;
451+ std::string values_in_iterator_src;
545452
546453 if (!keys_only)
547454 {
548455 const auto [vi_name, vi_src] =
549456 get_specialization<segmented_sort_values_input_iterator_tag>(template_id<input_iterator_traits>(), values_in_it);
550- values_in_iterator_name = vi_name;
551- values_in_iterator_src = vi_src;
457+ values_in_iterator_src = vi_src;
552458 }
553459 else
554460 {
555- values_in_iterator_name = " cub::NullType*" ;
556- values_in_iterator_src = " " ;
461+ values_in_iterator_src = " " ;
557462 }
558463
559464 const auto [start_offset_iterator_name, start_offset_iterator_src] =
624529 const auto [small_selector_name, small_selector_src] = get_specialization<segmented_sort_small_selector_tag>(
625530 template_id<user_operation_traits>(), small_selector_op, selector_result_t , selector_input_t );
626531
627- const auto segmented_sort_policy_hub_expr = std::format (
628- " cub::detail::segmented_sort::policy_hub<{0}, {1}>" ,
532+ const auto policy_sel = cub::detail::segmented_sort::policy_selector{
533+ static_cast <int >(keys_in_it.value_type .size ),
534+ keys_only ? int {sizeof (cub::NullType)} : static_cast <int >(values_in_it.value_type .size ),
535+ keys_only};
536+
537+ // TODO(bgruber): drop this if tuning policies become formattable
538+ std::stringstream segmented_sort_policy_sel_str;
539+ segmented_sort_policy_sel_str << policy_sel (cuda::to_arch_id (cuda::compute_capability{cc_major, cc_minor}));
540+
541+ const auto segmented_sort_policy_expr = std::format (
542+ " cub::detail::segmented_sort::policy_selector_from_types<{}, {}>" ,
629543 key_t , // 0
630544 value_t ); // 1
631545
@@ -658,14 +572,15 @@ struct __align__({4}) items_storage_t {{
658572{8}
659573{9}
660574{10}
661- using device_segmented_sort_policy = {11}::MaxPolicy;
662- using device_three_way_partition_policy = {12}::MaxPolicy;
575+ using device_segmented_sort_policy_selector = {11};
576+ using namespace cub;
577+ using namespace cub::detail::segmented_sort;
578+ static_assert(
579+ device_segmented_sort_policy_selector()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {12},
580+ "Host generated and JIT compiled policy mismatch");
581+ using device_three_way_partition_policy = {13}::MaxPolicy;
663582
664583#include <cub/detail/ptx-json/json.cuh>
665- __device__ consteval auto& segmented_sort_policy_generator() {{
666- return ptx_json::id<ptx_json::string("device_segmented_sort_policy")>()
667- = cub::detail::segmented_sort::SegmentedSortPolicyWrapper<device_segmented_sort_policy::ActivePolicy>::EncodedPolicy();
668- }}
669584__device__ consteval auto& three_way_partition_policy_generator() {{
670585 return ptx_json::id<ptx_json::string("device_three_way_partition_policy")>()
671586 = cub::detail::three_way_partition::ThreeWayPartitionPolicyWrapper<device_three_way_partition_policy::ActivePolicy>::EncodedPolicy();
@@ -682,8 +597,15 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
682597 end_offset_iterator_src, // 8
683598 large_selector_src, // 9
684599 small_selector_src, // 10
685- segmented_sort_policy_hub_expr, // 11
686- three_way_partition_policy_hub_expr); // 12
600+ segmented_sort_policy_expr, // 11
601+ segmented_sort_policy_sel_str.view (), // 12
602+ three_way_partition_policy_hub_expr); // 13
603+
604+ #if false // CCCL_DEBUGGING_SWITCH
605+ fflush (stderr);
606+ printf (" \n CODE4NVRTC BEGIN\n %sCODE4NVRTC END\n " , final_src.c_str ());
607+ fflush (stdout);
608+ #endif
687609
688610 std::vector<const char *> args = {
689611 arch.c_str (),
@@ -695,7 +617,7 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
695617 " -dlto" ,
696618 " -default-device" ,
697619 " -DCUB_DISABLE_CDP" ,
698- " -DCUB_ENABLE_POLICY_PTX_JSON" ,
620+ " -DCUB_ENABLE_POLICY_PTX_JSON" , // TODO(bgruber): remove after we ported three way partition to the new tuning API
699621 " -std=c++20" };
700622
701623 cccl::detail::extend_args_with_build_config (args, config);
@@ -769,35 +691,23 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
769691 check (cuLibraryGetKernel (
770692 &build_ptr->three_way_partition_kernel , build_ptr->library , three_way_partition_kernel_lowered_name.c_str ()));
771693
772- nlohmann::json runtime_policy =
773- cub::detail::ptx_json::parse (" device_segmented_sort_policy" , {result.data .get (), result.size });
774-
775- using cub::detail::RuntimeRadixSortDownsweepAgentPolicy;
776- auto large_segment_policy = RuntimeRadixSortDownsweepAgentPolicy::from_json (runtime_policy, " LargeSegmentPolicy" );
777-
778- using cub::detail::RuntimeSubWarpMergeSortAgentPolicy;
779- auto small_segment_policy = RuntimeSubWarpMergeSortAgentPolicy::from_json (runtime_policy, " SmallSegmentPolicy" );
780-
781- auto medium_segment_policy = RuntimeSubWarpMergeSortAgentPolicy::from_json (runtime_policy, " MediumSegmentPolicy" );
782-
783- int partitioning_threshold = runtime_policy[" PartitioningThreshold" ].get <int >();
694+ // TODO(bgruber): convert to the new tuning API
784695 nlohmann::json partition_policy =
785696 cub::detail::ptx_json::parse (" device_three_way_partition_policy" , {result.data .get (), result.size });
786697
787698 using cub::detail::RuntimeThreeWayPartitionAgentPolicy;
788699 auto three_way_partition_policy =
789700 RuntimeThreeWayPartitionAgentPolicy::from_json (partition_policy, " ThreeWayPartitionPolicy" );
790701
791- build_ptr->cc = cc ;
702+ build_ptr->cc = cc_major * 10 + cc_minor ;
792703 build_ptr->large_segments_selector_op = large_selector_op;
793704 build_ptr->small_segments_selector_op = small_selector_op;
794705 build_ptr->cubin = (void *) result.data .release ();
795706 build_ptr->cubin_size = result.size ;
796707 build_ptr->key_type = keys_in_it.value_type ;
797708 build_ptr->offset_type = cccl_type_info{sizeof (OffsetT), alignof (OffsetT), cccl_type_enum::CCCL_INT64};
798709 // Store a copy of the host-built policy selector; only the partition policy below is still extracted via from_json
799- build_ptr->runtime_policy = new segmented_sort::segmented_sort_runtime_tuning_policy{
800- large_segment_policy, small_segment_policy, medium_segment_policy, partitioning_threshold};
710+ build_ptr->runtime_policy = new cub::detail::segmented_sort::policy_selector{policy_sel};
801711 build_ptr->partition_runtime_policy = new segmented_sort::partition_runtime_tuning_policy{three_way_partition_policy};
802712 build_ptr->order = sort_order;
803713
@@ -888,35 +798,34 @@ CUresult cccl_device_segmented_sort_impl(
888798 cub::DoubleBuffer<indirect_arg_t > d_values_double_buffer (
889799 *static_cast <indirect_arg_t **>(&val_arg_in), *static_cast <indirect_arg_t **>(&val_arg_out));
890800
891- auto exec_status = cub::DispatchSegmentedSort<
801+ // TODO(bgruber): remove all template arguments except the first two (the others can be deduced)
802+ auto exec_status = cub::detail::segmented_sort::dispatch<
892803 Order,
804+ OffsetT, // OffsetT
893805 indirect_arg_t , // KeyT
894806 indirect_arg_t , // ValueT
895- OffsetT, // OffsetT
896807 indirect_iterator_t , // BeginOffsetIteratorT
897808 indirect_iterator_t , // EndOffsetIteratorT
898- segmented_sort::segmented_sort_runtime_tuning_policy , // PolicyHub
809+ cub::detail:: segmented_sort::policy_selector , // PolicySelector
899810 segmented_sort::segmented_sort_kernel_source, // KernelSource
900- segmented_sort::partition_runtime_tuning_policy, // PartitionPolicyHub
901- segmented_sort::partition_kernel_source, // PartitionKernelSource
902- cub::detail::CudaDriverLauncherFactory>:: // KernelLaunchFactory
903- Dispatch (
904- d_temp_storage,
811+ segmented_sort::partition_runtime_tuning_policy // PartitionPolicyHub
812+ >(d_temp_storage,
905813 *temp_storage_bytes,
906814 d_keys_double_buffer,
907815 d_values_double_buffer,
908816 num_items,
909817 num_segments,
910- start_offset_in,
911- end_offset_in,
818+ indirect_iterator_t { start_offset_in} ,
819+ indirect_iterator_t { end_offset_in} ,
912820 is_overwrite_okay,
913821 stream,
914- /* kernel_source */ {build},
915- /* partition_kernel_source */ {build},
916- /* launcher_factory */ cub::detail::CudaDriverLauncherFactory{cu_device, build.cc },
917- /* policy */ *reinterpret_cast <segmented_sort::segmented_sort_runtime_tuning_policy*>(build.runtime_policy ),
918- /* partition_policy */
919- *reinterpret_cast <segmented_sort::partition_runtime_tuning_policy*>(build.partition_runtime_policy ));
822+ /* policy_selector */
823+ *static_cast <cub::detail::segmented_sort::policy_selector*>(build.runtime_policy ),
824+ /* partition_max_policy */
825+ *static_cast <segmented_sort::partition_runtime_tuning_policy*>(build.partition_runtime_policy ),
826+ /* kernel_source */ segmented_sort::segmented_sort_kernel_source{build},
827+ /* partition_kernel_source */ segmented_sort::partition_kernel_source{build},
828+ /* launcher_factory */ cub::detail::CudaDriverLauncherFactory{cu_device, build.cc });
920829
921830 *selector = d_keys_double_buffer.selector ;
922831 error = static_cast <CUresult>(exec_status);
998907 std::unique_ptr<char []> small_code (const_cast <char *>(build_ptr->small_segments_selector_op .code ));
999908
1000909 // Clean up the runtime policies
1001- std::unique_ptr<segmented_sort::segmented_sort_runtime_tuning_policy > rtp (
1002- static_cast <segmented_sort::segmented_sort_runtime_tuning_policy *>(build_ptr->runtime_policy ));
910+ std::unique_ptr<cub::detail:: segmented_sort::policy_selector > rtp (
911+ static_cast <cub::detail:: segmented_sort::policy_selector *>(build_ptr->runtime_policy ));
1003912 std::unique_ptr<segmented_sort::partition_runtime_tuning_policy> prtp (
1004913 static_cast <segmented_sort::partition_runtime_tuning_policy*>(build_ptr->partition_runtime_policy ));
1005914 check (cuLibraryUnload (build_ptr->library ));
0 commit comments