@@ -40,10 +40,7 @@ template <typename Derived, typename InputIt, typename Size, typename OutputIt,
4040_CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl (
4141 thrust::cuda_cub::execution_policy<Derived>& policy, InputIt first, Size num_items, OutputIt result, ScanOp scan_op)
4242{
43- using AccumT = thrust::detail::it_value_t <InputIt>;
44- using Dispatch32 = cub::DispatchScan<InputIt, OutputIt, ScanOp, cub::NullType, std::uint32_t , AccumT>;
45- using Dispatch64 = cub::DispatchScan<InputIt, OutputIt, ScanOp, cub::NullType, std::uint64_t , AccumT>;
46-
43+ using AccumT = thrust::detail::it_value_t <InputIt>;
4744 cudaStream_t stream = thrust::cuda_cub::stream (policy);
4845 cudaError_t status;
4946
@@ -56,10 +53,9 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
5653 // Determine temporary storage requirements:
5754 size_t tmp_size = 0 ;
5855 {
59- THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2 (
56+ THRUST_UNSIGNED_INDEX_TYPE_DISPATCH (
6057 status,
61- Dispatch32::Dispatch,
62- Dispatch64::Dispatch,
58+ cub::detail::scan::dispatch_with_accum<AccumT>,
6359 num_items,
6460 (nullptr , tmp_size, first, result, scan_op, cub::NullType{}, num_items_fixed, stream));
6561 thrust::cuda_cub::throw_on_error (
@@ -72,10 +68,9 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
7268 {
7369 // Allocate temporary storage:
7470 thrust::detail::temporary_array<std::uint8_t , Derived> tmp{policy, tmp_size};
75- THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2 (
71+ THRUST_UNSIGNED_INDEX_TYPE_DISPATCH (
7672 status,
77- Dispatch32::Dispatch,
78- Dispatch64::Dispatch,
73+ cub::detail::scan::dispatch_with_accum<AccumT>,
7974 num_items,
8075 (tmp.data ().get (), tmp_size, first, result, scan_op, cub::NullType{}, num_items_fixed, stream));
8176 thrust::cuda_cub::throw_on_error (status, " after dispatching inclusive_scan kernel" );
@@ -96,15 +91,9 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
9691 InitValueT init,
9792 ScanOp scan_op)
9893{
99- using InputValueT = cub::detail::InputValue<InitValueT>;
100- using ValueT = cub::detail::it_value_t <InputIt>;
101- using AccumT = ::cuda::std::__accumulator_t <ScanOp, ValueT, InitValueT>;
102-
103- using Dispatch32 =
104- cub::DispatchScan<InputIt, OutputIt, ScanOp, InputValueT, std::uint32_t , AccumT, cub::ForceInclusive::Yes>;
105- using Dispatch64 =
106- cub::DispatchScan<InputIt, OutputIt, ScanOp, InputValueT, std::uint64_t , AccumT, cub::ForceInclusive::Yes>;
107-
94+ using InputValueT = cub::detail::InputValue<InitValueT>;
95+ using ValueT = cub::detail::it_value_t <InputIt>;
96+ using AccumT = ::cuda::std::__accumulator_t <ScanOp, ValueT, InitValueT>;
10897 cudaStream_t stream = thrust::cuda_cub::stream (policy);
10998 cudaError_t status;
11099
@@ -117,10 +106,9 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
117106 // Determine temporary storage requirements:
118107 size_t tmp_size = 0 ;
119108 {
120- THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2 (
109+ THRUST_UNSIGNED_INDEX_TYPE_DISPATCH (
121110 status,
122- Dispatch32::Dispatch,
123- Dispatch64::Dispatch,
111+ (cub::detail::scan::dispatch_with_accum<AccumT, cub::ForceInclusive::Yes>),
124112 num_items,
125113 (nullptr , tmp_size, first, result, scan_op, InputValueT (init), num_items_fixed, stream));
126114 thrust::cuda_cub::throw_on_error (
@@ -133,10 +121,9 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
133121 {
134122 // Allocate temporary storage:
135123 thrust::detail::temporary_array<std::uint8_t , Derived> tmp{policy, tmp_size};
136- THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2 (
124+ THRUST_UNSIGNED_INDEX_TYPE_DISPATCH (
137125 status,
138- Dispatch32::Dispatch,
139- Dispatch64::Dispatch,
126+ (cub::detail::scan::dispatch_with_accum<AccumT, cub::ForceInclusive::Yes>),
140127 num_items,
141128 (tmp.data ().get (), tmp_size, first, result, scan_op, InputValueT (init), num_items_fixed, stream));
142129 thrust::cuda_cub::throw_on_error (status, " after dispatching inclusive_scan kernel" );
@@ -157,10 +144,7 @@ _CCCL_HOST_DEVICE OutputIt exclusive_scan_n_impl(
157144 InitValueT init,
158145 ScanOp scan_op)
159146{
160- using InputValueT = cub::detail::InputValue<InitValueT>;
161- using Dispatch32 = cub::DispatchScan<InputIt, OutputIt, ScanOp, InputValueT, std::uint32_t , InitValueT>;
162- using Dispatch64 = cub::DispatchScan<InputIt, OutputIt, ScanOp, InputValueT, std::uint64_t , InitValueT>;
163-
147+ using InputValueT = cub::detail::InputValue<InitValueT>;
164148 cudaStream_t stream = thrust::cuda_cub::stream (policy);
165149 cudaError_t status;
166150
@@ -173,10 +157,9 @@ _CCCL_HOST_DEVICE OutputIt exclusive_scan_n_impl(
173157 // Determine temporary storage requirements:
174158 size_t tmp_size = 0 ;
175159 {
176- THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2 (
160+ THRUST_UNSIGNED_INDEX_TYPE_DISPATCH (
177161 status,
178- Dispatch32::Dispatch,
179- Dispatch64::Dispatch,
162+ cub::detail::scan::dispatch_with_accum<InitValueT>,
180163 num_items,
181164 (nullptr , tmp_size, first, result, scan_op, InputValueT (init), num_items_fixed, stream));
182165 thrust::cuda_cub::throw_on_error (
@@ -189,10 +172,9 @@ _CCCL_HOST_DEVICE OutputIt exclusive_scan_n_impl(
189172 {
190173 // Allocate temporary storage:
191174 thrust::detail::temporary_array<std::uint8_t , Derived> tmp{policy, tmp_size};
192- THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2 (
175+ THRUST_UNSIGNED_INDEX_TYPE_DISPATCH (
193176 status,
194- Dispatch32::Dispatch,
195- Dispatch64::Dispatch,
177+ cub::detail::scan::dispatch_with_accum<InitValueT>,
196178 num_items,
197179 (tmp.data ().get (), tmp_size, first, result, scan_op, InputValueT (init), num_items_fixed, stream));
198180 thrust::cuda_cub::throw_on_error (status, " after dispatching exclusive_scan kernel" );
0 commit comments