/**
- * Copyright 2017-2024, XGBoost contributors
+ * Copyright 2017-2025, XGBoost contributors
 */
#pragma once
- #include <thrust/execution_policy.h>
#include <thrust/iterator/counting_iterator.h>          // for make_counting_iterator
#include <thrust/iterator/transform_output_iterator.h>  // for make_transform_output_iterator

- #include <algorithm>  // for max
- #include <cstddef>    // for size_t
- #include <cstdint>    // for int32_t, uint32_t
- #include <vector>     // for vector
+ #include <algorithm>        // for max
+ #include <cstddef>          // for size_t
+ #include <cstdint>          // for int32_t, uint32_t
+ #include <cuda/functional>  // for proclaim_return_type
+ #include <vector>           // for vector

#include "../../common/cuda_context.cuh"    // for CUDAContext
#include "../../common/device_helpers.cuh"  // for MakeTransformIterator
@@ -21,7 +21,7 @@ namespace xgboost::tree {
namespace cuda_impl {
using RowIndexT = std::uint32_t;
// TODO(Rory): Can be larger. To be tuned alongside other batch operations.
- static const std::int32_t kMaxUpdatePositionBatchSize = 32;
+ inline constexpr std::int32_t kMaxUpdatePositionBatchSize = 32;
}  // namespace cuda_impl

/**
@@ -37,7 +37,7 @@ struct Segment {
  Segment(cuda_impl::RowIndexT begin, cuda_impl::RowIndexT end) : begin(begin), end(end) {
    CHECK_GE(end, begin);
  }
-  __host__ __device__ bst_idx_t Size() const { return end - begin; }
+  [[nodiscard]] XGBOOST_DEVICE bst_idx_t Size() const { return end - begin; }
};

template <typename OpDataT>
@@ -46,28 +46,42 @@ struct PerNodeData {
  OpDataT data;
};

- template <typename BatchIterT>
- XGBOOST_DEV_INLINE void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
-                                     int* batch_idx, std::size_t* item_idx) {
+ /**
+  * @param global_thread_idx In practice, the row index within the total number of rows for
+  *        this node batch.
+  * @param batch_idx The nidx within this node batch (not the actual node index in a tree).
+  * @param item_idx The resulting global row index (without accounting for base_rowid). This
+  *        maps the row index within the node batch back to the global row index.
+  */
+ template <typename T>
+ XGBOOST_DEV_INLINE void AssignBatch(dh::LDGIterator<T> const& batch_info_iter,
+                                     std::size_t global_thread_idx, int* batch_idx,
+                                     std::size_t* item_idx) {
  cuda_impl::RowIndexT sum = 0;
-  for (int i = 0; i < cuda_impl::kMaxUpdatePositionBatchSize; i++) {
-    if (sum + batch_info[i].segment.Size() > global_thread_idx) {
+  // Search for the nidx in the batch and the corresponding global row index; exit once found.
+  for (std::int32_t i = 0; i < cuda_impl::kMaxUpdatePositionBatchSize; i++) {
+    if (sum + batch_info_iter[i].segment.Size() > global_thread_idx) {
      *batch_idx = i;
-      *item_idx = (global_thread_idx - sum) + batch_info[i].segment.begin;
+      // the beginning of the segment plus the offset into that segment
+      *item_idx = (global_thread_idx - sum) + batch_info_iter[i].segment.begin;
      break;
    }
-    sum += batch_info[i].segment.Size();
+    sum += batch_info_iter[i].segment.Size();
  }
}

+ /**
+  * @param total_rows The total number of rows for this batch of nodes.
+  */
template <int kBlockSize, typename OpDataT>
__global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
-    dh::LDGIterator<PerNodeData<OpDataT>> batch_info, common::Span<cuda_impl::RowIndexT> d_ridx,
-    const common::Span<const cuda_impl::RowIndexT> ridx_tmp, bst_idx_t total_rows) {
+    dh::LDGIterator<PerNodeData<OpDataT>> batch_info_iter,
+    common::Span<cuda_impl::RowIndexT> d_ridx,
+    common::Span<cuda_impl::RowIndexT const> const ridx_tmp, bst_idx_t total_rows) {
  for (auto idx : dh::GridStrideRange<std::size_t>(0, total_rows)) {
-    int batch_idx;
-    std::size_t item_idx;
-    AssignBatch(batch_info, idx, &batch_idx, &item_idx);
+    std::int32_t batch_idx;  // unused
+    std::size_t item_idx = std::numeric_limits<std::size_t>::max();
+    AssignBatch(batch_info_iter, idx, &batch_idx, &item_idx);
    d_ridx[item_idx] = ridx_tmp[item_idx];
  }
}
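
The AssignBatch search above is easier to follow on the host. Below is a minimal, host-only sketch of the same mapping from a flat index over concatenated node segments to the pair (position of the node within the batch, global row index). SegmentSketch and AssignBatchSketch are illustrative stand-ins; the real code reads PerNodeData through a dh::LDGIterator on the device.

// Host-only sketch of the AssignBatch search, under the assumptions stated above.
#include <cstddef>
#include <cstdint>
#include <iostream>

struct SegmentSketch {
  std::uint32_t begin, end;
  std::uint32_t Size() const { return end - begin; }
};

// Maps a flat index over the concatenated segments to (index of the segment in the batch,
// global row index inside that segment), mirroring the loop in AssignBatch.
void AssignBatchSketch(SegmentSketch const* segments, int n_segments,
                       std::size_t global_thread_idx, int* batch_idx, std::size_t* item_idx) {
  std::uint32_t sum = 0;
  for (int i = 0; i < n_segments; ++i) {
    if (sum + segments[i].Size() > global_thread_idx) {
      *batch_idx = i;
      *item_idx = (global_thread_idx - sum) + segments[i].begin;
      return;
    }
    sum += segments[i].Size();
  }
}

int main() {
  // Two nodes in the batch: rows [0, 3) and rows [7, 10).
  SegmentSketch segments[] = {{0, 3}, {7, 10}};
  int batch_idx = -1;
  std::size_t item_idx = 0;
  AssignBatchSketch(segments, 2, /*global_thread_idx=*/4, &batch_idx, &item_idx);
  std::cout << batch_idx << " " << item_idx << "\n";  // prints "1 8"
  return 0;
}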
@@ -141,18 +155,22 @@ void SortPositionBatch(Context const* ctx, common::Span<const PerNodeData<OpData
  auto discard_write_iterator =
      thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
  auto counting = thrust::make_counting_iterator(0llu);
-  auto input_iterator =
-      dh::MakeTransformIterator<IndexFlagTuple>(counting, [=] __device__(std::size_t idx) {
-        int nidx_in_batch;
+  auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
+      counting, cuda::proclaim_return_type<IndexFlagTuple>([=] __device__(std::size_t idx) {
+        std::int32_t nidx_in_batch;
        std::size_t item_idx;
        AssignBatch(batch_info_itr, idx, &nidx_in_batch, &item_idx);
        auto go_left = op(ridx[item_idx], nidx_in_batch, batch_info_itr[nidx_in_batch].data);
        return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), go_left, nidx_in_batch,
                              go_left};
-      });
-  // Avoid using int as the offset type
+      }));
+  // Reach down to the dispatch function to avoid using int as the offset type.
  std::size_t n_bytes = 0;
  if (tmp->empty()) {
+    // The size of temporary storage is calculated based on the total number of rows. Since the
+    // root node has all the rows, subsequent allocations must be smaller than the root node's.
+    // As a result, we can calculate this once and reuse it throughout the iteration.
    auto ret =
        cub::DispatchScan<decltype(input_iterator), decltype(discard_write_iterator), IndexFlagOp,
                          cub::NullType, std::int64_t>::Dispatch(nullptr, n_bytes, input_iterator,
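
For context, here is a minimal sketch of the two-phase CUB allocation pattern that the comment above relies on, using the public cub::DeviceScan::InclusiveSum entry point rather than cub::DispatchScan (which the diff calls directly to get std::int64_t offsets). The first call with a null temporary-storage pointer only reports the required scratch size; the convenience API's int item count is the limitation the dispatch call avoids.

// Sketch of the query-then-allocate CUB pattern; names here are illustrative.
#include <cstddef>
#include <cub/cub.cuh>
#include <thrust/device_vector.h>

void InclusiveSumSketch(thrust::device_vector<int> const& in, thrust::device_vector<int>* out) {
  std::size_t n_bytes = 0;
  // First pass: d_temp_storage == nullptr, so CUB only writes the required scratch size.
  cub::DeviceScan::InclusiveSum(nullptr, n_bytes, in.data().get(), out->data().get(),
                                static_cast<int>(in.size()));
  thrust::device_vector<char> tmp(n_bytes);
  // Second pass: the same call with real scratch memory performs the scan.
  cub::DeviceScan::InclusiveSum(tmp.data().get(), n_bytes, in.data().get(), out->data().get(),
                                static_cast<int>(in.size()));
}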
@@ -305,10 +323,10 @@ class RowPartitioner {
   * second. Returns true if this training instance goes on the left partition.
   */
  template <typename UpdatePositionOpT, typename OpDataT>
-  void UpdatePositionBatch(Context const* ctx, const std::vector<bst_node_t>& nidx,
-                           const std::vector<bst_node_t>& left_nidx,
-                           const std::vector<bst_node_t>& right_nidx,
-                           const std::vector<OpDataT>& op_data, UpdatePositionOpT op) {
+  void UpdatePositionBatch(Context const* ctx, std::vector<bst_node_t> const& nidx,
+                           std::vector<bst_node_t> const& left_nidx,
+                           std::vector<bst_node_t> const& right_nidx,
+                           std::vector<OpDataT> const& op_data, UpdatePositionOpT op) {
    if (nidx.empty()) {
      return;
    }
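
A hypothetical illustration of the predicate passed as op, inferred only from the call shape used in the transform iterator above, op(row_index, nidx_in_batch, per-node data) -> bool, where true sends the row to the left partition. ExampleSplit, GetFeature, and ExampleOp are made-up stand-ins, not XGBoost types; the real op is a __device__ lambda.

// Illustrative sketch of an UpdatePositionBatch predicate, under the assumptions above.
#include <cstdint>

struct ExampleSplit {
  std::int32_t findex;  // feature to test (illustrative)
  float fvalue;         // split threshold (illustrative)
};

// Stand-in for reading one feature value of one row from some feature matrix.
inline float GetFeature(std::uint32_t /*ridx*/, std::int32_t /*findex*/) { return 0.5f; }

// Returns true when the row should go to the left child; mirrors the op(...) call shape.
inline bool ExampleOp(std::uint32_t ridx, std::int32_t /*nidx_in_batch*/, ExampleSplit const& d) {
  return GetFeature(ridx, d.findex) < d.fvalue;
}

int main() {
  ExampleSplit split{0, 1.0f};
  return ExampleOp(3, 0, split) ? 0 : 1;  // row 3 goes left since 0.5 < 1.0
}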
@@ -317,28 +335,47 @@ class RowPartitioner {
    CHECK_EQ(nidx.size(), right_nidx.size());
    CHECK_EQ(nidx.size(), op_data.size());
    this->n_nodes_ += (left_nidx.size() + right_nidx.size());
-
-    auto h_batch_info = pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
+    common::Span<PerNodeData<OpDataT>> h_batch_info =
+        pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
    dh::TemporaryArray<PerNodeData<OpDataT>> d_batch_info(nidx.size());

-    std::size_t total_rows = 0;
-    for (size_t i = 0; i < nidx.size(); i++) {
-      h_batch_info[i] = {ridx_segments_.at(nidx.at(i)).segment, op_data.at(i)};
-      total_rows += ridx_segments_.at(nidx.at(i)).segment.Size();
+    for (std::size_t i = 0; i < nidx.size(); i++) {
+      h_batch_info[i] = {ridx_segments_.at(nidx[i]).segment, op_data[i]};
    }
    dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
-                                  h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
-                                  cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
-
+                                  h_batch_info.size_bytes(), cudaMemcpyDefault,
+                                  ctx->CUDACtx()->Stream()));
    // Temporary arrays
    auto h_counts = pinned_.GetSpan<RowIndexT>(nidx.size());
    // Must initialize with 0 as 0 count is not written in the kernel.
    dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0);

-    // Partition the rows according to the operator
-    SortPositionBatch<UpdatePositionOpT, OpDataT>(ctx, dh::ToSpan(d_batch_info), dh::ToSpan(ridx_),
-                                                  dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
-                                                  total_rows, op, &tmp_);
+    // Process a sub-batch
+    auto sub_batch_impl = [ctx, op, this](common::Span<bst_node_t const> nidx,
+                                          common::Span<PerNodeData<OpDataT>> d_batch_info,
+                                          common::Span<RowIndexT> d_counts) {
+      std::size_t total_rows = 0;
+      for (bst_node_t i : nidx) {
+        total_rows += this->ridx_segments_[i].segment.Size();
+      }
+
+      // Partition the rows according to the operator
+      SortPositionBatch<UpdatePositionOpT, OpDataT>(ctx, d_batch_info, dh::ToSpan(this->ridx_),
+                                                    dh::ToSpan(this->ridx_tmp_), d_counts,
+                                                    total_rows, op, &this->tmp_);
+    };
+
+    // Divide inputs into sub-batches.
+    for (std::size_t batch_begin = 0, n = nidx.size(); batch_begin < n;
+         batch_begin += cuda_impl::kMaxUpdatePositionBatchSize) {
+      auto constexpr kMax = static_cast<decltype(n)>(cuda_impl::kMaxUpdatePositionBatchSize);
+      auto batch_size = std::min(kMax, n - batch_begin);
+      auto nidx_batch = common::Span{nidx}.subspan(batch_begin, batch_size);
+      auto d_info_batch = dh::ToSpan(d_batch_info).subspan(batch_begin, batch_size);
+      auto d_counts_batch = dh::ToSpan(d_counts).subspan(batch_begin, batch_size);
+      sub_batch_impl(nidx_batch, d_info_batch, d_counts_batch);
+    }
+
    dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
                                  cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
    // TODO(Rory): this synchronisation hurts performance a lot
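
To make the sub-batching arithmetic above concrete, here is a host-only sketch of how a list of n nodes is cut into chunks of at most kMaxUpdatePositionBatchSize. ForEachSubBatch and kMaxBatch are illustrative stand-ins; the real code passes subspans of the node, per-node-data, and count buffers to sub_batch_impl.

// Host-only sketch of the sub-batch splitting loop, under the assumptions stated above.
#include <algorithm>
#include <cstddef>
#include <iostream>

constexpr std::size_t kMaxBatch = 32;  // mirrors cuda_impl::kMaxUpdatePositionBatchSize

template <typename Fn>
void ForEachSubBatch(std::size_t n, Fn&& fn) {
  for (std::size_t batch_begin = 0; batch_begin < n; batch_begin += kMaxBatch) {
    std::size_t batch_size = std::min(kMaxBatch, n - batch_begin);
    fn(batch_begin, batch_size);  // the real code takes subspans of this [begin, begin + size) range
  }
}

int main() {
  // 70 nodes are processed as chunks of 32, 32 and 6.
  ForEachSubBatch(70, [](std::size_t begin, std::size_t size) {
    std::cout << "[" << begin << ", " << begin + size << ")\n";
  });
  return 0;
}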