@@ -114,10 +114,12 @@ private:
114114 {
115115 return detail::dispatch_with_env (
116116 env, [&](auto tuning, void * d_temp_storage, size_t & temp_storage_bytes, cudaStream_t stream) {
117- using segmented_reduce_tuning_t = ::cuda::__call_result_or_t <
117+ using default_policy_selector_t =
118+ detail::segmented_reduce::policy_selector_from_types<AccumT, OffsetT, ReductionOpT>;
119+ using policy_selector_t = ::cuda::std::execution::__query_result_or_t <
120+ decltype (tuning),
118121 detail::segmented_reduce::get_tuning_query_t ,
119- detail::segmented_reduce::policy_selector_from_types<AccumT, OffsetT, ReductionOpT>,
120- decltype (tuning)>;
122+ default_policy_selector_t >;
121123 return detail::segmented_reduce::dispatch<AccumT, OffsetT>(
122124 d_temp_storage,
123125 temp_storage_bytes,
@@ -129,7 +131,7 @@ private:
129131 reduction_op,
130132 initial_value,
131133 stream,
132- segmented_reduce_tuning_t {});
134+ policy_selector_t {}); // selector resolved from `tuning`; passing the default alone would ignore the env's tuning
133135 });
134136 }
135137 _CCCL_UNREACHABLE ();
@@ -574,17 +576,16 @@ public:
574576
575577 return detail::dispatch_with_env (
576578 env, [&]([[maybe_unused]] auto tuning, void * d_temp_storage, size_t & temp_storage_bytes, cudaStream_t stream) {
577- return detail::reduce::
578- DispatchFixedSizeSegmentedReduce<InputIteratorT, OutputIteratorT, int , ReductionOpT, T>::Dispatch (
579- d_temp_storage,
580- temp_storage_bytes,
581- d_in,
582- d_out,
583- num_segments,
584- segment_size,
585- reduction_op,
586- initial_value,
587- stream);
579+ return detail::reduce::DispatchFixedSizeSegmentedReduce<InputIteratorT, OutputIteratorT, int , ReductionOpT, T>::
580+ Dispatch (d_temp_storage,
581+ temp_storage_bytes,
582+ d_in,
583+ d_out,
584+ num_segments,
585+ segment_size,
586+ reduction_op,
587+ initial_value,
588+ stream);
588589 });
589590 }
590591
@@ -956,9 +957,16 @@ public:
956957 return detail::dispatch_with_env (
957958 env, [&]([[maybe_unused]] auto tuning, void * d_temp_storage, size_t & temp_storage_bytes, cudaStream_t stream) {
958959 return detail::reduce::
959- DispatchFixedSizeSegmentedReduce<InputIteratorT, OutputIteratorT, int , ::cuda::std::plus<>, output_t >::
960- Dispatch (
961- d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, segment_size, ::cuda::std::plus{}, output_t {}, stream);
960+ DispatchFixedSizeSegmentedReduce<InputIteratorT, OutputIteratorT, int , ::cuda::std::plus<>, output_t >::Dispatch (
961+ d_temp_storage,
962+ temp_storage_bytes,
963+ d_in,
964+ d_out,
965+ num_segments,
966+ segment_size,
967+ ::cuda::std::plus{},
968+ output_t {},
969+ stream);
962970 });
963971 }
964972
@@ -1809,9 +1817,9 @@ public:
18091817
18101818 static_assert (::cuda::std::is_same_v<int , output_key_t >, " Output key type must be int." );
18111819
1812- using arg_index_input_iterator_t = THRUST_NS_QUALIFIER::transform_iterator<
1813- detail::reduce::generate_idx_value<InputIteratorT, output_value_t >,
1814- THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t >>;
1820+ using arg_index_input_iterator_t =
1821+ THRUST_NS_QUALIFIER::transform_iterator< detail::reduce::generate_idx_value<InputIteratorT, output_value_t >,
1822+ THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t >>;
18151823
18161824 arg_index_input_iterator_t d_indexed_in = THRUST_NS_QUALIFIER::make_transform_iterator (
18171825 THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t >{0 },
@@ -2680,9 +2688,9 @@ public:
26802688
26812689 static_assert (::cuda::std::is_same_v<int , output_key_t >, " Output key type must be int." );
26822690
2683- using arg_index_input_iterator_t = THRUST_NS_QUALIFIER::transform_iterator<
2684- detail::reduce::generate_idx_value<InputIteratorT, output_value_t >,
2685- THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t >>;
2691+ using arg_index_input_iterator_t =
2692+ THRUST_NS_QUALIFIER::transform_iterator< detail::reduce::generate_idx_value<InputIteratorT, output_value_t >,
2693+ THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t >>;
26862694
26872695 arg_index_input_iterator_t d_indexed_in = THRUST_NS_QUALIFIER::make_transform_iterator (
26882696 THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t >{0 },
0 commit comments