 /**
- * Copyright 2017-2024, XGBoost contributors
+ * Copyright 2017-2025, XGBoost contributors
  */
 #pragma once
-#include <thrust/execution_policy.h>
 #include <thrust/iterator/counting_iterator.h>           // for make_counting_iterator
 #include <thrust/iterator/transform_output_iterator.h>   // for make_transform_output_iterator
 
-#include <algorithm>  // for max
-#include <cstddef>    // for size_t
-#include <cstdint>    // for int32_t, uint32_t
-#include <vector>     // for vector
+#include <algorithm>        // for max
+#include <cstddef>          // for size_t
+#include <cstdint>          // for int32_t, uint32_t
+#include <cuda/functional>  // for proclaim_return_type
+#include <vector>           // for vector
 
 #include "../../common/cuda_context.cuh"    // for CUDAContext
 #include "../../common/device_helpers.cuh"  // for MakeTransformIterator
@@ -21,7 +21,7 @@ namespace xgboost::tree {
 namespace cuda_impl {
 using RowIndexT = std::uint32_t;
 // TODO(Rory): Can be larger. To be tuned alongside other batch operations.
-static const std::int32_t kMaxUpdatePositionBatchSize = 32;
+inline constexpr std::int32_t kMaxUpdatePositionBatchSize = 32;
 }  // namespace cuda_impl
 
 /**
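
Making the batch-size constant `inline constexpr` instead of `static const` removes the per-translation-unit copies that a namespace-scope `static const` object produces when defined in a header. A minimal standalone sketch of the difference (not from this file):

```cpp
#include <cstdint>

// In a header: every translation unit that includes this gets its own object
// with internal linkage, so &kOldStyle differs across translation units.
static const std::int32_t kOldStyle = 32;

// C++17: a single entity program-wide, still usable in constant expressions.
inline constexpr std::int32_t kNewStyle = 32;

int buffer[kNewStyle];  // OK: kNewStyle is a compile-time constant
```
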
@@ -37,7 +37,7 @@ struct Segment {
   Segment(cuda_impl::RowIndexT begin, cuda_impl::RowIndexT end) : begin(begin), end(end) {
     CHECK_GE(end, begin);
   }
-  __host__ __device__ bst_idx_t Size() const { return end - begin; }
+  [[nodiscard]] XGBOOST_DEVICE bst_idx_t Size() const { return end - begin; }
 };
 
 template <typename OpDataT>
@@ -46,28 +46,42 @@ struct PerNodeData {
   OpDataT data;
 };
 
-template <typename BatchIterT>
-XGBOOST_DEV_INLINE void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
-                                    int* batch_idx, std::size_t* item_idx) {
+/**
+ * @param global_thread_idx In practice, the row index within the total number of rows for
+ *                          this node batch.
+ * @param batch_idx The nidx within this node batch (not the actual node index in a tree).
+ * @param item_idx The resulting global row index (without accounting for base_rowid). This maps
+ *                 the row index within the node batch back to the global row index.
+ */
+template <typename T>
+XGBOOST_DEV_INLINE void AssignBatch(dh::LDGIterator<T> const& batch_info_iter,
+                                    std::size_t global_thread_idx, int* batch_idx,
+                                    std::size_t* item_idx) {
   cuda_impl::RowIndexT sum = 0;
-  for (int i = 0; i < cuda_impl::kMaxUpdatePositionBatchSize; i++) {
-    if (sum + batch_info[i].segment.Size() > global_thread_idx) {
+  // Search for the nidx in the batch and the corresponding global row index, exit once found.
+  for (std::int32_t i = 0; i < cuda_impl::kMaxUpdatePositionBatchSize; i++) {
+    if (sum + batch_info_iter[i].segment.Size() > global_thread_idx) {
       *batch_idx = i;
-      *item_idx = (global_thread_idx - sum) + batch_info[i].segment.begin;
+      // The beginning of the segment plus the offset into that segment.
+      *item_idx = (global_thread_idx - sum) + batch_info_iter[i].segment.begin;
       break;
     }
-    sum += batch_info[i].segment.Size();
+    sum += batch_info_iter[i].segment.Size();
   }
 }
 
+/**
+ * @param total_rows The total number of rows for this batch of nodes.
+ */
 template <int kBlockSize, typename OpDataT>
 __global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
-    dh::LDGIterator<PerNodeData<OpDataT>> batch_info, common::Span<cuda_impl::RowIndexT> d_ridx,
-    const common::Span<const cuda_impl::RowIndexT> ridx_tmp, bst_idx_t total_rows) {
+    dh::LDGIterator<PerNodeData<OpDataT>> batch_info_iter,
+    common::Span<cuda_impl::RowIndexT> d_ridx,
+    common::Span<cuda_impl::RowIndexT const> const ridx_tmp, bst_idx_t total_rows) {
   for (auto idx : dh::GridStrideRange<std::size_t>(0, total_rows)) {
-    int batch_idx;
-    std::size_t item_idx;
-    AssignBatch(batch_info, idx, &batch_idx, &item_idx);
+    std::int32_t batch_idx;  // unused
+    std::size_t item_idx = std::numeric_limits<std::size_t>::max();
+    AssignBatch(batch_info_iter, idx, &batch_idx, &item_idx);
     d_ridx[item_idx] = ridx_tmp[item_idx];
   }
 }
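
To see what `AssignBatch` computes, here is a host-only sketch of the same prefix-sum search with concrete numbers; the `Segment` below is a simplified stand-in for the real type, and `dh::LDGIterator` is replaced by a plain vector:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-in for the Segment type above.
struct Segment {
  std::uint32_t begin, end;
  std::uint32_t Size() const { return end - begin; }
};

// Maps a flat index over the batch's total rows to (node-in-batch, global row
// index) by walking the running sum of segment sizes, as AssignBatch does.
void AssignBatch(std::vector<Segment> const& segments, std::size_t global_thread_idx,
                 int* batch_idx, std::size_t* item_idx) {
  std::uint32_t sum = 0;
  for (int i = 0; i < static_cast<int>(segments.size()); i++) {
    if (sum + segments[i].Size() > global_thread_idx) {
      *batch_idx = i;
      *item_idx = (global_thread_idx - sum) + segments[i].begin;
      return;
    }
    sum += segments[i].Size();
  }
}

int main() {
  // Two nodes in the batch: rows [0, 4) and rows [10, 13).
  std::vector<Segment> segments{{0, 4}, {10, 13}};
  // Flat index 5 is the second row of the second segment: node 1, row 11.
  int batch_idx = -1;
  std::size_t item_idx = 0;
  AssignBatch(segments, 5, &batch_idx, &item_idx);
  std::cout << batch_idx << " " << item_idx << "\n";  // prints "1 11"
}
```
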
@@ -141,18 +155,22 @@ void SortPositionBatch(Context const* ctx, common::Span<const PerNodeData<OpData
   auto discard_write_iterator =
       thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
   auto counting = thrust::make_counting_iterator(0llu);
-  auto input_iterator =
-      dh::MakeTransformIterator<IndexFlagTuple>(counting, [=] __device__(std::size_t idx) {
-        int nidx_in_batch;
+  auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
+      counting, cuda::proclaim_return_type<IndexFlagTuple>([=] __device__(std::size_t idx) {
+        std::int32_t nidx_in_batch;
         std::size_t item_idx;
         AssignBatch(batch_info_itr, idx, &nidx_in_batch, &item_idx);
         auto go_left = op(ridx[item_idx], nidx_in_batch, batch_info_itr[nidx_in_batch].data);
         return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), go_left, nidx_in_batch,
                               go_left};
-      });
-  // Avoid using int as the offset type
+      }));
+  // Reach down to the dispatch function to avoid using int as the offset type.
   std::size_t n_bytes = 0;
   if (tmp->empty()) {
+    // The size of the temporary storage is calculated based on the total number of rows. Since
+    // the root node has all the rows, subsequent allocations must be smaller than the one for
+    // the root node. As a result, we can calculate this once and reuse it throughout the
+    // iteration.
     auto ret =
         cub::DispatchScan<decltype(input_iterator), decltype(discard_write_iterator), IndexFlagOp,
                           cub::NullType, std::int64_t>::Dispatch(nullptr, n_bytes, input_iterator,
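
The lambda above is now wrapped in `cuda::proclaim_return_type` (from libcu++'s `<cuda/functional>`, hence the new include) so the return type of the extended `__device__` lambda does not have to be deduced on the host. A standalone sketch of the same pattern, with a toy payload type standing in for `IndexFlagTuple`:

```cpp
// Compile with: nvcc --extended-lambda sketch.cu
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

#include <cuda/functional>  // for proclaim_return_type
#include <cstdio>

// Toy stand-in for IndexFlagTuple.
struct Pair {
  unsigned idx;
  bool flag;
};

int main() {
  thrust::counting_iterator<unsigned> counting(0u);
  thrust::device_vector<Pair> out(4);
  // The return type of an extended __device__ lambda may not be deducible on
  // the host side, so proclaim_return_type states it explicitly.
  thrust::transform(counting, counting + out.size(), out.begin(),
                    cuda::proclaim_return_type<Pair>(
                        [] __device__(unsigned i) { return Pair{i, i % 2 == 0}; }));
  Pair p = out[3];
  std::printf("%u %d\n", p.idx, static_cast<int>(p.flag));  // prints "3 0"
}
```
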
@@ -305,10 +323,10 @@ class RowPartitioner {
    * second. Returns true if this training instance goes on the left partition.
    */
   template <typename UpdatePositionOpT, typename OpDataT>
-  void UpdatePositionBatch(Context const* ctx, const std::vector<bst_node_t>& nidx,
-                           const std::vector<bst_node_t>& left_nidx,
-                           const std::vector<bst_node_t>& right_nidx,
-                           const std::vector<OpDataT>& op_data, UpdatePositionOpT op) {
+  void UpdatePositionBatch(Context const* ctx, std::vector<bst_node_t> const& nidx,
+                           std::vector<bst_node_t> const& left_nidx,
+                           std::vector<bst_node_t> const& right_nidx,
+                           std::vector<OpDataT> const& op_data, UpdatePositionOpT op) {
     if (nidx.empty()) {
       return;
     }
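
`UpdatePositionBatch` leaves the split decision to the `op` callback described in the doc comment above; it is invoked as `op(row_index, nidx_in_batch, op_data)` in the transform iterator earlier in this diff. A hedged sketch of the shape such a callable might take; the feature-value layout and threshold test are illustrative stand-ins, not the actual updater code:

```cpp
#include <cstdint>

// Illustrative only: the shape of an UpdatePositionOpT callable passed to
// UpdatePositionBatch, with OpDataT = float (the split value) in this sketch.
struct ExampleOp {
  float const* d_values;  // hypothetical: one dense feature value per row
  __device__ bool operator()(std::uint32_t ridx, std::int32_t /*nidx_in_batch*/,
                             float split_value) const {
    // Return true to send this row to the left child.
    return d_values[ridx] <= split_value;
  }
};
```
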
@@ -317,28 +335,47 @@ class RowPartitioner {
     CHECK_EQ(nidx.size(), right_nidx.size());
     CHECK_EQ(nidx.size(), op_data.size());
     this->n_nodes_ += (left_nidx.size() + right_nidx.size());
-
-    auto h_batch_info = pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
+    common::Span<PerNodeData<OpDataT>> h_batch_info =
+        pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
     dh::TemporaryArray<PerNodeData<OpDataT>> d_batch_info(nidx.size());
 
-    std::size_t total_rows = 0;
-    for (size_t i = 0; i < nidx.size(); i++) {
-      h_batch_info[i] = {ridx_segments_.at(nidx.at(i)).segment, op_data.at(i)};
-      total_rows += ridx_segments_.at(nidx.at(i)).segment.Size();
+    for (std::size_t i = 0; i < nidx.size(); i++) {
+      h_batch_info[i] = {ridx_segments_.at(nidx[i]).segment, op_data[i]};
     }
     dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
-                                  h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
-                                  cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
-
+                                  h_batch_info.size_bytes(), cudaMemcpyDefault,
+                                  ctx->CUDACtx()->Stream()));
     // Temporary arrays
     auto h_counts = pinned_.GetSpan<RowIndexT>(nidx.size());
     // Must initialize with 0 as 0 count is not written in the kernel.
     dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0);
 
-    // Partition the rows according to the operator
-    SortPositionBatch<UpdatePositionOpT, OpDataT>(ctx, dh::ToSpan(d_batch_info), dh::ToSpan(ridx_),
-                                                  dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
-                                                  total_rows, op, &tmp_);
+    // Process a sub-batch
+    auto sub_batch_impl = [ctx, op, this](common::Span<bst_node_t const> nidx,
+                                          common::Span<PerNodeData<OpDataT>> d_batch_info,
+                                          common::Span<RowIndexT> d_counts) {
+      std::size_t total_rows = 0;
+      for (bst_node_t i : nidx) {
+        total_rows += this->ridx_segments_[i].segment.Size();
+      }
+
+      // Partition the rows according to the operator
+      SortPositionBatch<UpdatePositionOpT, OpDataT>(ctx, d_batch_info, dh::ToSpan(this->ridx_),
+                                                    dh::ToSpan(this->ridx_tmp_), d_counts,
+                                                    total_rows, op, &this->tmp_);
+    };
+
+    // Divide inputs into sub-batches.
+    for (std::size_t batch_begin = 0, n = nidx.size(); batch_begin < n;
+         batch_begin += cuda_impl::kMaxUpdatePositionBatchSize) {
+      auto constexpr kMax = static_cast<decltype(n)>(cuda_impl::kMaxUpdatePositionBatchSize);
+      auto batch_size = std::min(kMax, n - batch_begin);
+      auto nidx_batch = common::Span{nidx}.subspan(batch_begin, batch_size);
+      auto d_info_batch = dh::ToSpan(d_batch_info).subspan(batch_begin, batch_size);
+      auto d_counts_batch = dh::ToSpan(d_counts).subspan(batch_begin, batch_size);
+      sub_batch_impl(nidx_batch, d_info_batch, d_counts_batch);
+    }
+
     dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
                                   cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
     // TODO(Rory): this synchronisation hurts performance a lot
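
The new division loop caps each `SortPositionBatch` call at `cuda_impl::kMaxUpdatePositionBatchSize` nodes, which keeps `AssignBatch`'s fixed-bound search valid no matter how many nodes are updated in one call. A host-only sketch of the same chunking arithmetic:

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

// Host-only sketch of the sub-batch division used by UpdatePositionBatch:
// split n items into chunks of at most kMaxUpdatePositionBatchSize.
constexpr std::size_t kMaxUpdatePositionBatchSize = 32;

int main() {
  std::size_t n = 70;  // e.g. 70 nodes to update in a single call
  for (std::size_t batch_begin = 0; batch_begin < n;
       batch_begin += kMaxUpdatePositionBatchSize) {
    std::size_t batch_size = std::min(kMaxUpdatePositionBatchSize, n - batch_begin);
    // Each [batch_begin, batch_begin + batch_size) range becomes one
    // SortPositionBatch launch in the real code.
    std::cout << "[" << batch_begin << ", " << batch_begin + batch_size << ")\n";
  }
  // Prints: [0, 32)  [32, 64)  [64, 70)
}
```
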