Skip to content

Commit 44d9d80

Browse files
committed
Use __query_result_or_t to query tuning environment
1 parent e78a0d7 commit 44d9d80

File tree

1 file changed

+32
-24
lines changed

1 file changed

+32
-24
lines changed

cub/cub/device/device_segmented_reduce.cuh

Lines changed: 32 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -114,10 +114,12 @@ private:
114114
{
115115
return detail::dispatch_with_env(
116116
env, [&](auto tuning, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) {
117-
using segmented_reduce_tuning_t = ::cuda::__call_result_or_t<
117+
using default_policy_selector_t =
118+
detail::segmented_reduce::policy_selector_from_types<AccumT, OffsetT, ReductionOpT>;
119+
using policy_selector_t = ::cuda::std::execution::__query_result_or_t<
120+
decltype(tuning),
118121
detail::segmented_reduce::get_tuning_query_t,
119-
detail::segmented_reduce::policy_selector_from_types<AccumT, OffsetT, ReductionOpT>,
120-
decltype(tuning)>;
122+
default_policy_selector_t>;
121123
return detail::segmented_reduce::dispatch<AccumT, OffsetT>(
122124
d_temp_storage,
123125
temp_storage_bytes,
@@ -129,7 +131,7 @@ private:
129131
reduction_op,
130132
initial_value,
131133
stream,
132-
segmented_reduce_tuning_t{});
134+
default_policy_selector_t{});
133135
});
134136
}
135137
_CCCL_UNREACHABLE();
@@ -574,17 +576,16 @@ public:
574576

575577
return detail::dispatch_with_env(
576578
env, [&]([[maybe_unused]] auto tuning, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) {
577-
return detail::reduce::
578-
DispatchFixedSizeSegmentedReduce<InputIteratorT, OutputIteratorT, int, ReductionOpT, T>::Dispatch(
579-
d_temp_storage,
580-
temp_storage_bytes,
581-
d_in,
582-
d_out,
583-
num_segments,
584-
segment_size,
585-
reduction_op,
586-
initial_value,
587-
stream);
579+
return detail::reduce::DispatchFixedSizeSegmentedReduce<InputIteratorT, OutputIteratorT, int, ReductionOpT, T>::
580+
Dispatch(d_temp_storage,
581+
temp_storage_bytes,
582+
d_in,
583+
d_out,
584+
num_segments,
585+
segment_size,
586+
reduction_op,
587+
initial_value,
588+
stream);
588589
});
589590
}
590591

@@ -956,9 +957,16 @@ public:
956957
return detail::dispatch_with_env(
957958
env, [&]([[maybe_unused]] auto tuning, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) {
958959
return detail::reduce::
959-
DispatchFixedSizeSegmentedReduce<InputIteratorT, OutputIteratorT, int, ::cuda::std::plus<>, output_t>::
960-
Dispatch(
961-
d_temp_storage, temp_storage_bytes, d_in, d_out, num_segments, segment_size, ::cuda::std::plus{}, output_t{}, stream);
960+
DispatchFixedSizeSegmentedReduce<InputIteratorT, OutputIteratorT, int, ::cuda::std::plus<>, output_t>::Dispatch(
961+
d_temp_storage,
962+
temp_storage_bytes,
963+
d_in,
964+
d_out,
965+
num_segments,
966+
segment_size,
967+
::cuda::std::plus{},
968+
output_t{},
969+
stream);
962970
});
963971
}
964972

@@ -1809,9 +1817,9 @@ public:
18091817

18101818
static_assert(::cuda::std::is_same_v<int, output_key_t>, "Output key type must be int.");
18111819

1812-
using arg_index_input_iterator_t = THRUST_NS_QUALIFIER::transform_iterator<
1813-
detail::reduce::generate_idx_value<InputIteratorT, output_value_t>,
1814-
THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t>>;
1820+
using arg_index_input_iterator_t =
1821+
THRUST_NS_QUALIFIER::transform_iterator<detail::reduce::generate_idx_value<InputIteratorT, output_value_t>,
1822+
THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t>>;
18151823

18161824
arg_index_input_iterator_t d_indexed_in = THRUST_NS_QUALIFIER::make_transform_iterator(
18171825
THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t>{0},
@@ -2680,9 +2688,9 @@ public:
26802688

26812689
static_assert(::cuda::std::is_same_v<int, output_key_t>, "Output key type must be int.");
26822690

2683-
using arg_index_input_iterator_t = THRUST_NS_QUALIFIER::transform_iterator<
2684-
detail::reduce::generate_idx_value<InputIteratorT, output_value_t>,
2685-
THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t>>;
2691+
using arg_index_input_iterator_t =
2692+
THRUST_NS_QUALIFIER::transform_iterator<detail::reduce::generate_idx_value<InputIteratorT, output_value_t>,
2693+
THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t>>;
26862694

26872695
arg_index_input_iterator_t d_indexed_in = THRUST_NS_QUALIFIER::make_transform_iterator(
26882696
THRUST_NS_QUALIFIER::counting_iterator<::cuda::std::int64_t>{0},

0 commit comments

Comments
 (0)