Commit 57f42c8

trivialfis and hcho3 authored
[backport] Support latest cccl (dmlc#11504) (dmlc#11583)
Co-authored-by: Philip Hyunsu Cho <[email protected]>
1 parent e645e7d commit 57f42c8

File tree

4 files changed: +53 −17 lines


src/common/algorithm.cuh

Lines changed: 32 additions & 12 deletions
@@ -26,6 +26,15 @@
 
 namespace xgboost::common {
 namespace detail {
+
+#if CUB_VERSION >= 300000
+constexpr auto kCubSortOrderAscending = cub::SortOrder::Ascending;
+constexpr auto kCubSortOrderDescending = cub::SortOrder::Descending;
+#else
+constexpr bool kCubSortOrderAscending = false;
+constexpr bool kCubSortOrderDescending = true;
+#endif
+
 // Wrapper around cub sort to define is_decending
 template <bool IS_DESCENDING, typename KeyT, typename BeginOffsetIteratorT,
           typename EndOffsetIteratorT>
@@ -42,8 +51,9 @@ static void DeviceSegmentedRadixSortKeys(CUDAContext const *ctx, void *d_temp_st
   cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
   cub::DoubleBuffer<cub::NullType> d_values;
 
+  constexpr auto kCubSortOrder = IS_DESCENDING ? kCubSortOrderDescending : kCubSortOrderAscending;
   dh::safe_cuda((cub::DispatchSegmentedRadixSort<
-      IS_DESCENDING, KeyT, cub::NullType, BeginOffsetIteratorT, EndOffsetIteratorT,
+      kCubSortOrder, KeyT, cub::NullType, BeginOffsetIteratorT, EndOffsetIteratorT,
       OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
                          num_segments, d_begin_offsets, d_end_offsets, begin_bit,
                          end_bit, false, ctx->Stream(), debug_synchronous)));
@@ -68,21 +78,22 @@ void DeviceSegmentedRadixSortPair(void *d_temp_storage,
   CHECK_LE(num_items, std::numeric_limits<OffsetT>::max());
   // For Thrust >= 1.12 or CUDA >= 11.4, we require system cub installation
 
+  constexpr auto kCubSortOrder = descending ? kCubSortOrderDescending : kCubSortOrderAscending;
 #if THRUST_MAJOR_VERSION >= 2
   dh::safe_cuda((cub::DispatchSegmentedRadixSort<
-      descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
+      kCubSortOrder, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
       OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
                          num_segments, d_begin_offsets, d_end_offsets, begin_bit,
                          end_bit, false, stream)));
 #elif (THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION >= 13)
   dh::safe_cuda((cub::DispatchSegmentedRadixSort<
-      descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
+      kCubSortOrder, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
       OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
                          num_segments, d_begin_offsets, d_end_offsets, begin_bit,
                          end_bit, false, stream, false)));
 #else
   dh::safe_cuda(
-      (cub::DispatchSegmentedRadixSort<descending, KeyT, ValueT, BeginOffsetIteratorT,
+      (cub::DispatchSegmentedRadixSort<kCubSortOrder, KeyT, ValueT, BeginOffsetIteratorT,
                                        OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes,
                                                           d_keys, d_values, num_items, num_segments,
                                                           d_begin_offsets, d_end_offsets, begin_bit,
@@ -207,47 +218,48 @@ void ArgSort(Context const *ctx, Span<U> keys, Span<IdxT> sorted_idx) {
   // track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
   using OffsetT = std::conditional_t<!dh::BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
   CHECK_LE(sorted_idx.size(), std::numeric_limits<OffsetT>::max());
+
   if (accending) {
     void *d_temp_storage = nullptr;
 #if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
         d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
         cuctx->Stream())));
 #else
-    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
         d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
         nullptr, false)));
 #endif
     dh::TemporaryArray<char> storage(bytes);
     d_temp_storage = storage.data().get();
 #if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
         d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
         cuctx->Stream())));
 #else
-    dh::safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderAscending, KeyT, ValueT, OffsetT>::Dispatch(
         d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
         nullptr, false)));
 #endif
   } else {
     void *d_temp_storage = nullptr;
 #if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
         d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
         cuctx->Stream())));
 #else
-    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
         d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
         nullptr, false)));
 #endif
     dh::TemporaryArray<char> storage(bytes);
     d_temp_storage = storage.data().get();
 #if THRUST_MAJOR_VERSION >= 2
-    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
         d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
         cuctx->Stream())));
 #else
-    dh::safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, OffsetT>::Dispatch(
+    dh::safe_cuda((cub::DispatchRadixSort<detail::kCubSortOrderDescending, KeyT, ValueT, OffsetT>::Dispatch(
         d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, sizeof(KeyT) * 8, false,
         nullptr, false)));
 #endif
@@ -277,6 +289,10 @@ void CopyIf(CUDAContext const *cuctx, InIt in_first, InIt in_second, OutIt out_f
 template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename OffsetT>
 void InclusiveScan(xgboost::Context const *ctx, InputIteratorT d_in, OutputIteratorT d_out,
                    ScanOpT scan_op, OffsetT num_items) {
+#if CUB_VERSION >= 300000
+  static_assert(std::is_unsigned_v<OffsetT>, "OffsetT must be unsigned");
+  static_assert(sizeof(OffsetT) >= 4, "OffsetT must be at least 4 bytes long");
+#endif
   auto cuctx = ctx->CUDACtx();
   std::size_t bytes = 0;
 #if THRUST_MAJOR_VERSION >= 2
@@ -304,7 +320,11 @@ void InclusiveScan(xgboost::Context const *ctx, InputIteratorT d_in, OutputItera
 template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
 void InclusiveSum(Context const *ctx, InputIteratorT d_in, OutputIteratorT d_out,
                   OffsetT num_items) {
+#if CUB_VERSION >= 300000
+  InclusiveScan(ctx, d_in, d_out, cuda::std::plus{}, num_items);
+#else
   InclusiveScan(ctx, d_in, d_out, cub::Sum{}, num_items);
+#endif
 }
 }  // namespace xgboost::common
 #endif  // XGBOOST_COMMON_ALGORITHM_CUH_
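
Context for the change above: CUB 3.0, shipped with recent CCCL releases, replaces the boolean IS_DESCENDING template parameter of the radix-sort dispatchers with a cub::SortOrder enumeration, which is why the commit introduces version-gated constants. The following is a minimal, self-contained sketch of the same gate; it is not code from the commit, and the shortened constant names are illustrative only.

// Sketch of the version-gated sort-order constants, assuming only what the diff
// above shows: CUB >= 3.0 exposes cub::SortOrder, while older CUB takes a bool.
#include <cub/cub.cuh>  // defines CUB_VERSION and, for >= 3.0, cub::SortOrder

#if CUB_VERSION >= 300000
constexpr auto kAscending  = cub::SortOrder::Ascending;   // enum in CUB >= 3.0
constexpr auto kDescending = cub::SortOrder::Descending;
#else
constexpr bool kAscending  = false;  // pre-3.0 dispatchers take bool IS_DESCENDING
constexpr bool kDescending = true;
#endif

// Either form is then forwarded as the first template argument of a dispatcher, e.g.
//   cub::DispatchRadixSort<kAscending, KeyT, ValueT, OffsetT>::Dispatch(...);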

src/common/hist_util.cuh

Lines changed: 8 additions & 0 deletions
@@ -45,7 +45,11 @@ __global__ void GetColumnSizeSharedMemKernel(IterSpan<BatchIt> batch_iter,
 
   dh::BlockFill(smem_cs_ptr, out_column_size.size(), 0);
 
+#if CUB_VERSION >= 300000
+  __syncthreads();
+#else
   cub::CTA_SYNC();
+#endif
 
   auto n = batch_iter.size();
 
@@ -56,7 +60,11 @@
     }
   }
 
+#if CUB_VERSION >= 300000
+  __syncthreads();
+#else
   cub::CTA_SYNC();
+#endif
 
   auto out_global_ptr = out_column_size;
   for (auto i : dh::BlockStrideRange(static_cast<std::size_t>(0), out_column_size.size())) {
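
The hist_util.cuh change follows the same pattern: the cub::CTA_SYNC() helper is not usable under CUB >= 3.0 (that is what the version gate above implies), so the kernel falls back to the plain CUDA barrier. Below is a minimal sketch of that gated barrier in a hypothetical kernel; FillSharedThenRead is not part of the commit.

#include <cub/cub.cuh>  // CUB_VERSION; older CUB also provides cub::CTA_SYNC()

// Hypothetical kernel: every thread fills shared memory, the block synchronizes,
// then one thread reads the result back. The barrier choice mirrors the diff.
__global__ void FillSharedThenRead(int* out, int n) {
  extern __shared__ int smem[];
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    smem[i] = i;
  }
#if CUB_VERSION >= 300000
  __syncthreads();   // CUB 3.0 builds use the plain CUDA intrinsic
#else
  cub::CTA_SYNC();   // older CUB still ships the wrapper
#endif
  if (threadIdx.x == 0) {
    out[blockIdx.x] = smem[n - 1];
  }
}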

src/tree/gpu_hist/evaluate_splits.cu

Lines changed: 9 additions & 1 deletion
@@ -115,7 +115,11 @@ class EvaluateSplitAgent {
     bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
     GradientPairInt64 bin = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
                                           : GradientPairInt64();
-    BlockScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
+#if CUB_VERSION >= 300000
+    BlockScanT(temp_storage->scan).ExclusiveScan(bin, bin, cuda::std::plus{}, prefix_op);
+#else
+    BlockScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum{}, prefix_op);
+#endif
     // Whether the gradient of missing values is put to the left side.
     bool missing_left = true;
     float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
@@ -292,7 +296,11 @@ __global__ __launch_bounds__(kBlockSize) void EvaluateSplitsKernel(
     agent.Numerical(&best_split);
   }
 
+#if CUB_VERSION >= 300000
+  __syncthreads();
+#else
   cub::CTA_SYNC();
+#endif
   if (threadIdx.x == 0) {
     // Record best loss for each feature
     out_candidates[blockIdx.x] = best_split;
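
The scan-operator swap above (cub::Sum to the libcu++ functor cuda::std::plus under CUB >= 3.0) can be shown in isolation. This is a minimal sketch with a hypothetical block-wide prefix-sum kernel; BlockPrefixSum is not from the commit, and the operator selection simply reuses the commit's CUB_VERSION gate.

#include <cub/cub.cuh>
#include <cuda/std/functional>  // cuda::std::plus

// Hypothetical kernel: exclusive prefix sum over one int per thread, with the
// scan operator chosen by the same CUB_VERSION gate used in the diff above.
template <int kBlockSize>
__global__ void BlockPrefixSum(int const* in, int* out) {
  using BlockScanT = cub::BlockScan<int, kBlockSize>;
  __shared__ typename BlockScanT::TempStorage temp;
  int v = in[threadIdx.x];
#if CUB_VERSION >= 300000
  BlockScanT(temp).ExclusiveScan(v, v, 0, cuda::std::plus{});  // generic libcu++ functor
#else
  BlockScanT(temp).ExclusiveScan(v, v, 0, cub::Sum{});         // legacy cub::Sum
#endif
  out[threadIdx.x] = v;
}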

src/tree/gpu_hist/row_partitioner.cuh

Lines changed: 4 additions & 4 deletions
@@ -173,21 +173,21 @@ void SortPositionBatch(Context const* ctx, common::Span<const PerNodeData<OpData
     // the iteration.
     auto ret =
         cub::DispatchScan<decltype(input_iterator), decltype(discard_write_iterator), IndexFlagOp,
-                          cub::NullType, std::int64_t>::Dispatch(nullptr, n_bytes, input_iterator,
+                          cub::NullType, std::uint64_t>::Dispatch(nullptr, n_bytes, input_iterator,
                                                                  discard_write_iterator,
                                                                  IndexFlagOp{}, cub::NullType{},
-                                                                 total_rows,
+                                                                 static_cast<std::uint64_t>(total_rows),
                                                                  ctx->CUDACtx()->Stream());
     dh::safe_cuda(ret);
     tmp->resize(n_bytes);
   }
   n_bytes = tmp->size();
   auto ret =
       cub::DispatchScan<decltype(input_iterator), decltype(discard_write_iterator), IndexFlagOp,
-                        cub::NullType, std::int64_t>::Dispatch(tmp->data(), n_bytes, input_iterator,
+                        cub::NullType, std::uint64_t>::Dispatch(tmp->data(), n_bytes, input_iterator,
                                                                discard_write_iterator,
                                                                IndexFlagOp{}, cub::NullType{},
-                                                               total_rows,
+                                                               static_cast<std::uint64_t>(total_rows),
                                                                ctx->CUDACtx()->Stream());
   dh::safe_cuda(ret);
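
The row_partitioner.cuh change swaps the scan offset type from std::int64_t to std::uint64_t, matching the static_asserts the commit adds to InclusiveScan: under CUB >= 3.0 the offset type must be unsigned and at least 32 bits wide. Below is a minimal sketch of that compile-time guard using a hypothetical helper that is not present in the commit.

#include <cstdint>
#include <type_traits>
#include <cub/cub.cuh>  // CUB_VERSION

// Hypothetical guard mirroring the new static_asserts: reject signed or narrow
// offset types before they are handed to a CUB 3.0 scan dispatcher.
template <typename OffsetT>
void CheckScanOffsetType() {
#if CUB_VERSION >= 300000
  static_assert(std::is_unsigned_v<OffsetT>, "OffsetT must be unsigned");
  static_assert(sizeof(OffsetT) >= 4, "OffsetT must be at least 4 bytes long");
#endif
}

// Usage sketch: the signed row count is widened before the dispatch, as in the diff.
//   std::int64_t total_rows = /* ... */;
//   CheckScanOffsetType<std::uint64_t>();
//   auto n = static_cast<std::uint64_t>(total_rows);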
193193
