namespace xgboost::common {
namespace detail {
+
+#if CUB_VERSION >= 300000
+constexpr auto kCubSortOrderAscending = cub::SortOrder::Ascending;
+constexpr auto kCubSortOrderDescending = cub::SortOrder::Descending;
+#else
+constexpr bool kCubSortOrderAscending = false;
+constexpr bool kCubSortOrderDescending = true;
+#endif
+
// Wrapper around cub sort to define is_descending
template <bool IS_DESCENDING, typename KeyT, typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
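Note on the shim above: from CUB 3.0 the radix-sort dispatch layer takes a `cub::SortOrder` enum where older releases took a `bool` descending flag in the same template slot, which is what the version gate tracks. A comment-only sketch of the two spellings (signatures abbreviated, not the exact ones):

```cpp
// CUB < 3.0: sort direction is a bool non-type template parameter.
//   cub::DispatchRadixSort<true /* descending */, KeyT, ValueT, OffsetT>::Dispatch(/* ... */);
// CUB >= 3.0: the same template slot takes the cub::SortOrder enum.
//   cub::DispatchRadixSort<cub::SortOrder::Descending, KeyT, ValueT, OffsetT>::Dispatch(/* ... */);
// The kCubSortOrder* constants give both APIs one spelling:
//   cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(/* ... */);
```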
@@ -42,8 +51,9 @@ static void DeviceSegmentedRadixSortKeys(CUDAContext const *ctx, void *d_temp_st
  cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
  cub::DoubleBuffer<cub::NullType> d_values;

+  constexpr auto kCubSortOrder = IS_DESCENDING ? kCubSortOrderDescending : kCubSortOrderAscending;
  dh::safe_cuda((cub::DispatchSegmentedRadixSort<
-                 IS_DESCENDING, KeyT, cub::NullType, BeginOffsetIteratorT, EndOffsetIteratorT,
+                 kCubSortOrder, KeyT, cub::NullType, BeginOffsetIteratorT, EndOffsetIteratorT,
                 OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
                                    num_segments, d_begin_offsets, d_end_offsets, begin_bit,
                                    end_bit, false, ctx->Stream(), debug_synchronous)));
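The ternary added in this hunk is a valid constant expression because, in either configuration, both constants share one type (`cub::SortOrder` under CUB >= 3.0, `bool` otherwise), so the result can feed the non-type template argument directly. A minimal standalone sketch of the pattern (hypothetical helper, not from this file):

```cpp
template <bool IS_DESCENDING>
constexpr auto SortOrderOf() {
  // Resolves to cub::SortOrder on CUB >= 3.0 and to bool on older CUB.
  constexpr auto kOrder =
      IS_DESCENDING ? detail::kCubSortOrderDescending : detail::kCubSortOrderAscending;
  return kOrder;
}
static_assert(SortOrderOf<false>() == detail::kCubSortOrderAscending);
```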
@@ -68,21 +78,22 @@ void DeviceSegmentedRadixSortPair(void *d_temp_storage,
  CHECK_LE(num_items, std::numeric_limits<OffsetT>::max());
  // For Thrust >= 1.12 or CUDA >= 11.4, we require system cub installation

+  constexpr auto kCubSortOrder = descending ? kCubSortOrderDescending : kCubSortOrderAscending;
#if THRUST_MAJOR_VERSION >= 2
  dh::safe_cuda((cub::DispatchSegmentedRadixSort<
-                 descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
+                 kCubSortOrder, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
                 OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
                                    num_segments, d_begin_offsets, d_end_offsets, begin_bit,
                                    end_bit, false, stream)));
#elif (THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION >= 13)
  dh::safe_cuda((cub::DispatchSegmentedRadixSort<
-                 descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
+                 kCubSortOrder, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
                 OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
                                    num_segments, d_begin_offsets, d_end_offsets, begin_bit,
                                    end_bit, false, stream, false)));
#else
  dh::safe_cuda(
-      (cub::DispatchSegmentedRadixSort<descending, KeyT, ValueT, BeginOffsetIteratorT,
+      (cub::DispatchSegmentedRadixSort<kCubSortOrder, KeyT, ValueT, BeginOffsetIteratorT,
                                       OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes,
                                                          d_keys, d_values, num_items, num_segments,
                                                          d_begin_offsets, d_end_offsets, begin_bit,
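For `constexpr auto kCubSortOrder = descending ? ...` to compile, `descending` must itself be a compile-time value, i.e. a non-type template parameter of `DeviceSegmentedRadixSortPair`. A hedged sketch of that shape (parameter list elided; the real signature in this file may differ):

```cpp
template <bool descending, typename KeyT, typename ValueT,
          typename BeginOffsetIteratorT, typename EndOffsetIteratorT>
void DeviceSegmentedRadixSortPair(
    void *d_temp_storage,
    std::size_t &temp_storage_bytes
    /* , key/value buffers, segment offsets, bit range, stream */);
```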
@@ -207,47 +218,48 @@ void ArgSort(Context const *ctx, Span<U> keys, Span<IdxT> sorted_idx) {
  // track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
  using OffsetT = std::conditional_t<!dh::BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
  CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
+
  if (accending) {
    void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
        cuctx->Stream())));
#else
-    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
        nullptr, false)));
#endif
    dh::TemporaryArray<char> storage(bytes);
    d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
        cuctx->Stream())));
#else
-    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
        nullptr, false)));
#endif
  } else {
    void *d_temp_storage = nullptr;
#if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
        cuctx->Stream())));
#else
-    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
        nullptr, false)));
#endif
    dh::TemporaryArray<char> storage(bytes);
    d_temp_storage = storage.data().get();
#if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
        cuctx->Stream())));
#else
-    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
        d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
        nullptr, false)));
#endif
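Each branch of `ArgSort` issues the same `Dispatch` twice; that is CUB's standard two-phase protocol, not duplication. Sketch of the pattern as it appears in the hunk (dispatch lines elided into comments):

```cpp
std::size_t bytes = 0;
void *d_temp_storage = nullptr;  // phase 1: null storage => size query only
// cub::DispatchRadixSort<...>::Dispatch(d_temp_storage, bytes, ...);  // writes `bytes`
dh::TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
// cub::DispatchRadixSort<...>::Dispatch(d_temp_storage, bytes, ...);  // phase 2: sorts
```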
@@ -277,6 +289,10 @@ void CopyIf(CUDAContext const *cuctx, InIt in_first, InIt in_second, OutIt out_f
template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename OffsetT>
void InclusiveScan(xgboost::Context const *ctx, InputIteratorT d_in, OutputIteratorT d_out,
                   ScanOpT scan_op, OffsetT num_items) {
+#if CUB_VERSION >= 300000
+  static_assert(std::is_unsigned_v<OffsetT>, "OffsetT must be unsigned");
+  static_assert(sizeof(OffsetT) >= 4, "OffsetT must be at least 4 bytes long");
+#endif
  auto cuctx = ctx->CUDACtx();
  std::size_t bytes = 0;
#if THRUST_MAJOR_VERSION >= 2
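The two `static_assert`s surface what CUB >= 3.0 accepts as a scan offset type, so a signed or narrow `num_items` now fails at the call site instead of deep inside CUB. A hypothetical conforming call (`in_ptr`/`out_ptr` assumed):

```cpp
std::size_t n = 1024;  // unsigned and 8 bytes wide: satisfies both asserts
InclusiveScan(ctx, in_ptr, out_ptr, cuda::std::plus{}, n);
// InclusiveScan(ctx, in_ptr, out_ptr, cuda::std::plus{}, 1024);
//   ^ an int literal would trip the std::is_unsigned_v assert under CUB >= 3.0
```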
@@ -304,7 +320,11 @@ void InclusiveScan(xgboost::Context const *ctx, InputIteratorT d_in, OutputItera
template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
void InclusiveSum(Context const *ctx, InputIteratorT d_in, OutputIteratorT d_out,
                  OffsetT num_items) {
+#if CUB_VERSION >= 300000
+  InclusiveScan(ctx, d_in, d_out, cuda::std::plus{}, num_items);
+#else
  InclusiveScan(ctx, d_in, d_out, cub::Sum{}, num_items);
+#endif
}
}  // namespace xgboost::common
#endif  // XGBOOST_COMMON_ALGORITHM_CUH_
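The `InclusiveSum` change tracks CUB 3.0's retirement of the `cub::Sum` functor (per the version gate above); `cuda::std::plus` from libcu++ is the drop-in replacement and is callable from both host and device code. A minimal check of the swap-in (assumes the `<cuda/std/functional>` header):

```cpp
#include <cuda/std/functional>
// Transparent functor, constexpr-callable like the old cub::Sum:
static_assert(cuda::std::plus{}(2, 3) == 5);
```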