Commit 4486040

[backport][EM] Optimization for deep trees. (dmlc#11387) (dmlc#11444)
- Decouple the row partition batch size from the driver batch size. This will allow us to process more nodes for each data batch.
- Pick a heuristic to use ATS instead of data copy to handle cases where we have a large number of small nodes.
- Make sure a new page that happens to be the last is placed on the host.
1 parent 198d4a0 commit 4486040

File tree

3 files changed (+111, -52 lines):

src/data/ellpack_page_source.cu
src/tree/gpu_hist/row_partitioner.cuh
src/tree/updater_gpu_hist.cu

src/data/ellpack_page_source.cu

Lines changed: 6 additions & 5 deletions
@@ -190,11 +190,12 @@ class EllpackHostCacheStreamImpl {
       auto& new_impl = this->cache_->pages.back();
       auto offset = new_impl->Copy(&ctx, impl, this->cache_->offsets.back());
       this->cache_->offsets.back() += offset;
-      // No need to copy if it's already in device.
-      if (last_page && !this->cache_->on_device.back()) {
-        auto commited = commit_host_page(this->cache_->pages.back().get());
-        this->cache_->pages.back() = std::move(commited);
-      }
+    }
+
+    // No need to copy if it's already in device.
+    if (last_page && !this->cache_->on_device.back()) {
+      auto commited = commit_host_page(this->cache_->pages.back().get());
+      this->cache_->pages.back() = std::move(commited);
     }
 
     return new_page;
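
The hunk above moves the host-commit step out of the copy branch so that it also runs when the last page is a freshly appended one. A minimal standalone sketch of the reordered control flow, using hypothetical stand-in types rather than the XGBoost cache API, could look like:

```cpp
#include <cstdio>
#include <memory>
#include <vector>

// Hypothetical stand-ins for the cache and page types used in this sketch.
struct Page {
  bool on_host{false};
};

struct Cache {
  std::vector<std::unique_ptr<Page>> pages;
  std::vector<bool> on_device;  // whether each cached page should stay on device
};

// Hypothetical commit step. Previously it ran only inside the "copy into an
// existing page" branch, so a brand-new page that happened to be the last one
// was never moved to the host; running it after the branch covers both cases.
std::unique_ptr<Page> CommitHostPage(Page const* page) {
  auto committed = std::make_unique<Page>(*page);
  committed->on_host = true;
  return committed;
}

void Write(Cache* cache, bool new_page, bool last_page) {
  if (new_page) {
    cache->pages.push_back(std::make_unique<Page>());
    cache->on_device.push_back(false);
  } else {
    // Copy into the existing last page (elided).
  }

  // Placed after the branch: the last page ends up on the host regardless of
  // whether it was just created or appended to, unless it is kept on device.
  if (last_page && !cache->on_device.back()) {
    cache->pages.back() = CommitHostPage(cache->pages.back().get());
  }
}

int main() {
  Cache cache;
  Write(&cache, /*new_page=*/true, /*last_page=*/true);
  std::printf("last page on host: %d\n", static_cast<int>(cache.pages.back()->on_host));
}
```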

src/tree/gpu_hist/row_partitioner.cuh

Lines changed: 79 additions & 42 deletions
@@ -1,15 +1,15 @@
 /**
- * Copyright 2017-2024, XGBoost contributors
+ * Copyright 2017-2025, XGBoost contributors
  */
 #pragma once
-#include <thrust/execution_policy.h>
 #include <thrust/iterator/counting_iterator.h>          // for make_counting_iterator
 #include <thrust/iterator/transform_output_iterator.h>  // for make_transform_output_iterator
 
-#include <algorithm>  // for max
-#include <cstddef>    // for size_t
-#include <cstdint>    // for int32_t, uint32_t
-#include <vector>     // for vector
+#include <algorithm>        // for max
+#include <cstddef>          // for size_t
+#include <cstdint>          // for int32_t, uint32_t
+#include <cuda/functional>  // for proclaim_return_type
+#include <vector>           // for vector
 
 #include "../../common/cuda_context.cuh"    // for CUDAContext
 #include "../../common/device_helpers.cuh"  // for MakeTransformIterator
@@ -21,7 +21,7 @@ namespace xgboost::tree {
 namespace cuda_impl {
 using RowIndexT = std::uint32_t;
 // TODO(Rory): Can be larger. To be tuned alongside other batch operations.
-static const std::int32_t kMaxUpdatePositionBatchSize = 32;
+inline constexpr std::int32_t kMaxUpdatePositionBatchSize = 32;
 }  // namespace cuda_impl
 
 /**
@@ -37,7 +37,7 @@ struct Segment {
   Segment(cuda_impl::RowIndexT begin, cuda_impl::RowIndexT end) : begin(begin), end(end) {
     CHECK_GE(end, begin);
   }
-  __host__ __device__ bst_idx_t Size() const { return end - begin; }
+  [[nodiscard]] XGBOOST_DEVICE bst_idx_t Size() const { return end - begin; }
 };
 
 template <typename OpDataT>
@@ -46,28 +46,42 @@ struct PerNodeData {
   OpDataT data;
 };
 
-template <typename BatchIterT>
-XGBOOST_DEV_INLINE void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
-                                    int* batch_idx, std::size_t* item_idx) {
+/**
+ * @param global_thread_idx In practice, the row index within the total number of rows for
+ *                          this node batch.
+ * @param batch_idx The nidx within this node batch (not the actual node index in a tree).
+ * @param item_idx The resulting global row index (without accounting for base_rowid). This maps the
+ *                 row index within the node batch back to the global row index.
+ */
+template <typename T>
+XGBOOST_DEV_INLINE void AssignBatch(dh::LDGIterator<T> const& batch_info_iter,
+                                    std::size_t global_thread_idx, int* batch_idx,
+                                    std::size_t* item_idx) {
   cuda_impl::RowIndexT sum = 0;
-  for (int i = 0; i < cuda_impl::kMaxUpdatePositionBatchSize; i++) {
-    if (sum + batch_info[i].segment.Size() > global_thread_idx) {
+  // Search for the nidx in batch and the corresponding global row index, exit once found.
+  for (std::int32_t i = 0; i < cuda_impl::kMaxUpdatePositionBatchSize; i++) {
+    if (sum + batch_info_iter[i].segment.Size() > global_thread_idx) {
       *batch_idx = i;
-      *item_idx = (global_thread_idx - sum) + batch_info[i].segment.begin;
+      // the beginning of the segment plus the offset into that segment
+      *item_idx = (global_thread_idx - sum) + batch_info_iter[i].segment.begin;
       break;
     }
-    sum += batch_info[i].segment.Size();
+    sum += batch_info_iter[i].segment.Size();
  }
 }
 
+/**
+ * @param total_rows The total number of rows for this batch of nodes.
+ */
 template <int kBlockSize, typename OpDataT>
 __global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
-    dh::LDGIterator<PerNodeData<OpDataT>> batch_info, common::Span<cuda_impl::RowIndexT> d_ridx,
-    const common::Span<const cuda_impl::RowIndexT> ridx_tmp, bst_idx_t total_rows) {
+    dh::LDGIterator<PerNodeData<OpDataT>> batch_info_iter,
+    common::Span<cuda_impl::RowIndexT> d_ridx,
+    common::Span<cuda_impl::RowIndexT const> const ridx_tmp, bst_idx_t total_rows) {
   for (auto idx : dh::GridStrideRange<std::size_t>(0, total_rows)) {
-    int batch_idx;
-    std::size_t item_idx;
-    AssignBatch(batch_info, idx, &batch_idx, &item_idx);
+    std::int32_t batch_idx;  // unused
+    std::size_t item_idx = std::numeric_limits<std::size_t>::max();
+    AssignBatch(batch_info_iter, idx, &batch_idx, &item_idx);
     d_ridx[item_idx] = ridx_tmp[item_idx];
   }
 }
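
AssignBatch treats the concatenated per-node segments as one flat index space and does a linear scan over at most kMaxUpdatePositionBatchSize of them to recover which node a thread belongs to and which global row it should touch. A small host-side sketch of the same mapping, with simplified types rather than the device code, could look like:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in for PerNodeData<OpDataT>::segment: a [begin, end) row range.
struct Segment {
  std::uint32_t begin{0};
  std::uint32_t end{0};
  std::uint32_t Size() const { return end - begin; }
};

// Maps a flat index over the concatenated segments to (segment index, global row index),
// mirroring the linear scan in AssignBatch.
void AssignBatchHost(std::vector<Segment> const& segments, std::size_t flat_idx,
                     std::int32_t* batch_idx, std::size_t* item_idx) {
  std::uint32_t sum = 0;
  for (std::size_t i = 0; i < segments.size(); ++i) {
    if (sum + segments[i].Size() > flat_idx) {
      *batch_idx = static_cast<std::int32_t>(i);
      // Offset into this segment plus the segment's beginning gives the global row.
      *item_idx = (flat_idx - sum) + segments[i].begin;
      return;
    }
    sum += segments[i].Size();
  }
  assert(false && "flat_idx out of range");
}

int main() {
  // Two nodes: rows [0, 3) and rows [10, 12) in the global row index space.
  std::vector<Segment> segments{{0, 3}, {10, 12}};
  std::int32_t batch_idx = -1;
  std::size_t item_idx = 0;
  AssignBatchHost(segments, /*flat_idx=*/4, &batch_idx, &item_idx);
  assert(batch_idx == 1 && item_idx == 11);  // 4th flat row is the 2nd row of node 1
}
```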
@@ -141,18 +155,22 @@ void SortPositionBatch(Context const* ctx, common::Span<const PerNodeData<OpData
   auto discard_write_iterator =
       thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
   auto counting = thrust::make_counting_iterator(0llu);
-  auto input_iterator =
-      dh::MakeTransformIterator<IndexFlagTuple>(counting, [=] __device__(std::size_t idx) {
-        int nidx_in_batch;
+  auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
+      counting, cuda::proclaim_return_type<IndexFlagTuple>([=] __device__(std::size_t idx) {
+        std::int32_t nidx_in_batch;
         std::size_t item_idx;
         AssignBatch(batch_info_itr, idx, &nidx_in_batch, &item_idx);
         auto go_left = op(ridx[item_idx], nidx_in_batch, batch_info_itr[nidx_in_batch].data);
         return IndexFlagTuple{static_cast<cuda_impl::RowIndexT>(item_idx), go_left, nidx_in_batch,
                               go_left};
-      });
-  // Avoid using int as the offset type
+      }));
+  // Reach down to the dispatch function to avoid using int as the offset type.
   std::size_t n_bytes = 0;
   if (tmp->empty()) {
+    // The size of temporary storage is calculated based on the total number of
+    // rows. Since the root node has all the rows, subsequent allocations must be smaller
+    // than the root node. As a result, we can calculate this once and reuse it throughout
+    // the iteration.
     auto ret =
         cub::DispatchScan<decltype(input_iterator), decltype(discard_write_iterator), IndexFlagOp,
                           cub::NullType, std::int64_t>::Dispatch(nullptr, n_bytes, input_iterator,
@@ -305,10 +323,10 @@ class RowPartitioner {
    * second. Returns true if this training instance goes on the left partition.
    */
   template <typename UpdatePositionOpT, typename OpDataT>
-  void UpdatePositionBatch(Context const* ctx, const std::vector<bst_node_t>& nidx,
-                           const std::vector<bst_node_t>& left_nidx,
-                           const std::vector<bst_node_t>& right_nidx,
-                           const std::vector<OpDataT>& op_data, UpdatePositionOpT op) {
+  void UpdatePositionBatch(Context const* ctx, std::vector<bst_node_t> const& nidx,
+                           std::vector<bst_node_t> const& left_nidx,
+                           std::vector<bst_node_t> const& right_nidx,
+                           std::vector<OpDataT> const& op_data, UpdatePositionOpT op) {
     if (nidx.empty()) {
       return;
     }
@@ -317,28 +335,47 @@ class RowPartitioner {
     CHECK_EQ(nidx.size(), right_nidx.size());
     CHECK_EQ(nidx.size(), op_data.size());
     this->n_nodes_ += (left_nidx.size() + right_nidx.size());
-
-    auto h_batch_info = pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
+    common::Span<PerNodeData<OpDataT>> h_batch_info =
+        pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
     dh::TemporaryArray<PerNodeData<OpDataT>> d_batch_info(nidx.size());
 
-    std::size_t total_rows = 0;
-    for (size_t i = 0; i < nidx.size(); i++) {
-      h_batch_info[i] = {ridx_segments_.at(nidx.at(i)).segment, op_data.at(i)};
-      total_rows += ridx_segments_.at(nidx.at(i)).segment.Size();
+    for (std::size_t i = 0; i < nidx.size(); i++) {
+      h_batch_info[i] = {ridx_segments_.at(nidx[i]).segment, op_data[i]};
     }
     dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
-                                  h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
-                                  cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
-
+                                  h_batch_info.size_bytes(), cudaMemcpyDefault,
+                                  ctx->CUDACtx()->Stream()));
     // Temporary arrays
     auto h_counts = pinned_.GetSpan<RowIndexT>(nidx.size());
     // Must initialize with 0 as 0 count is not written in the kernel.
     dh::TemporaryArray<RowIndexT> d_counts(nidx.size(), 0);
 
-    // Partition the rows according to the operator
-    SortPositionBatch<UpdatePositionOpT, OpDataT>(ctx, dh::ToSpan(d_batch_info), dh::ToSpan(ridx_),
-                                                  dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
-                                                  total_rows, op, &tmp_);
+    // Process a sub-batch
+    auto sub_batch_impl = [ctx, op, this](common::Span<bst_node_t const> nidx,
+                                          common::Span<PerNodeData<OpDataT>> d_batch_info,
+                                          common::Span<RowIndexT> d_counts) {
+      std::size_t total_rows = 0;
+      for (bst_node_t i : nidx) {
+        total_rows += this->ridx_segments_[i].segment.Size();
+      }
+
+      // Partition the rows according to the operator
+      SortPositionBatch<UpdatePositionOpT, OpDataT>(ctx, d_batch_info, dh::ToSpan(this->ridx_),
+                                                    dh::ToSpan(this->ridx_tmp_), d_counts,
+                                                    total_rows, op, &this->tmp_);
+    };
+
+    // Divide inputs into sub-batches.
+    for (std::size_t batch_begin = 0, n = nidx.size(); batch_begin < n;
+         batch_begin += cuda_impl::kMaxUpdatePositionBatchSize) {
+      auto constexpr kMax = static_cast<decltype(n)>(cuda_impl::kMaxUpdatePositionBatchSize);
+      auto batch_size = std::min(kMax, n - batch_begin);
+      auto nidx_batch = common::Span{nidx}.subspan(batch_begin, batch_size);
+      auto d_info_batch = dh::ToSpan(d_batch_info).subspan(batch_begin, batch_size);
+      auto d_counts_batch = dh::ToSpan(d_counts).subspan(batch_begin, batch_size);
+      sub_batch_impl(nidx_batch, d_info_batch, d_counts_batch);
+    }
+
     dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
                                   cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
     // TODO(Rory): this synchronisation hurts performance a lot
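
The rewritten UpdatePositionBatch no longer assumes the whole node batch fits into one kernel launch: it carves the input into sub-batches of at most kMaxUpdatePositionBatchSize nodes and partitions each one separately, which is what allows the driver to hand over many more nodes per data batch. A minimal sketch of that slicing pattern, using plain C++20 spans and a hypothetical ProcessSubBatch in place of SortPositionBatch, could look like:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <span>
#include <vector>

constexpr std::size_t kMaxUpdatePositionBatchSize = 32;  // per-launch node limit

// Hypothetical stand-in for SortPositionBatch: partitions one sub-batch of nodes.
void ProcessSubBatch(std::span<int const> nidx) {
  std::printf("processing %zu nodes starting at nidx %d\n", nidx.size(), nidx.front());
}

// Splits an arbitrarily large node batch into fixed-size sub-batches, mirroring the
// loop added to UpdatePositionBatch.
void UpdatePositionBatch(std::vector<int> const& nidx) {
  for (std::size_t batch_begin = 0, n = nidx.size(); batch_begin < n;
       batch_begin += kMaxUpdatePositionBatchSize) {
    auto batch_size = std::min(kMaxUpdatePositionBatchSize, n - batch_begin);
    ProcessSubBatch(std::span{nidx}.subspan(batch_begin, batch_size));
  }
}

int main() {
  std::vector<int> nidx(100);
  for (std::size_t i = 0; i < nidx.size(); ++i) { nidx[i] = static_cast<int>(i); }
  UpdatePositionBatch(nidx);  // processes 32 + 32 + 32 + 4 nodes
}
```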

src/tree/updater_gpu_hist.cu

Lines changed: 26 additions & 5 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2024, XGBoost contributors
+ * Copyright 2017-2025, XGBoost contributors
  */
 #include <thrust/functional.h>  // for plus
 #include <thrust/transform.h>   // for transform
@@ -12,7 +12,6 @@
 #include <vector>  // for vector
 
 #include "../collective/aggregator.h"
-#include "../collective/broadcast.h"  // for Broadcast
 #include "../common/categorical.h"     // for KCatBitField
 #include "../common/cuda_context.cuh"  // for CUDAContext
 #include "../common/cuda_rt_utils.h"   // for CheckComputeCapability
@@ -53,6 +52,12 @@ using cuda_impl::ApproxBatch;
 using cuda_impl::HistBatch;
 using xgboost::cuda_impl::StaticBatch;
 
+namespace {
+// Use a large number to handle external memory with deep trees.
+inline constexpr std::size_t kMaxNodeBatchSize = 1024;
+inline constexpr std::size_t kNeedCopyThreshold = 4;
+}  // anonymous namespace
+
 // Extra data for each node that is passed to the update position function
 struct NodeSplitData {
   RegTree::Node split_node;
@@ -452,6 +457,23 @@ struct GPUHistMakerDevice {
     }
   };
 
+  // Heuristic to avoid copying the data batch.
+  [[nodiscard]] bool NeedCopy(DMatrix* p_fmat,
+                              std::vector<GPUExpandEntry> const& candidates) const {
+    if (p_fmat->SingleColBlock()) {
+      return true;  // use default if it's in-core
+    }
+    bst_idx_t n_total_samples = p_fmat->Info().num_row_;
+    bst_idx_t n_samples = 0;
+    for (auto const& c : candidates) {
+      for (auto const& part : this->partitioners_) {
+        n_samples += part->GetRows(c.nid).size();
+      }
+    }
+    // avoid copy if the kernel is small.
+    return n_samples * kNeedCopyThreshold > n_total_samples;
+  }
+
   // Update position and build histogram.
   void PartitionAndBuildHist(DMatrix* p_fmat, std::vector<GPUExpandEntry> const& expand_set,
                              std::vector<GPUExpandEntry> const& candidates, RegTree const* p_tree) {
@@ -474,7 +496,7 @@ struct GPUHistMakerDevice {
     std::vector<bst_node_t> build_nidx(candidates.size());
     std::vector<bst_node_t> subtraction_nidx(candidates.size());
    AssignNodes(p_tree, this->quantiser.get(), candidates, build_nidx, subtraction_nidx);
-    auto prefetch_copy = !build_nidx.empty();
+    auto prefetch_copy = !build_nidx.empty() && this->NeedCopy(p_fmat, candidates);
 
     this->histogram_.AllocateHistograms(ctx_, build_nidx, subtraction_nidx);
 
@@ -711,8 +733,7 @@ struct GPUHistMakerDevice {
 
   void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat, ObjInfo const* task,
                   RegTree* p_tree, HostDeviceVector<bst_node_t>* p_out_position) {
-    // Process maximum 32 nodes at a time
-    Driver<GPUExpandEntry> driver(param, 32);
+    Driver<GPUExpandEntry> driver{param, kMaxNodeBatchSize};
 
     p_fmat = this->Reset(gpair_all, p_fmat);
     driver.Push({this->InitRoot(p_fmat, p_tree)});
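
NeedCopy implements the ATS-versus-copy heuristic from the commit message: when the rows touched by the current candidates are only a small fraction of the dataset, prefetch-copying the whole data batch does not pay off and reading it in place (ATS/unified memory) is preferred. A rough standalone sketch of the decision rule, with a hypothetical helper that only mirrors the threshold check, could look like:

```cpp
#include <cassert>
#include <cstdint>

// Mirrors kNeedCopyThreshold in updater_gpu_hist.cu: copy only if the candidate
// nodes together cover at least 1/4 of all rows.
constexpr std::uint64_t kNeedCopyThreshold = 4;

// Hypothetical helper: decide whether to prefetch-copy the data batch (true) or
// read it in place through ATS/unified memory (false).
bool NeedCopy(bool single_block, std::uint64_t n_samples_in_candidates,
              std::uint64_t n_total_samples) {
  if (single_block) {
    return true;  // in-core data keeps the default copy path
  }
  // Small kernels (few rows touched) do not amortize the cost of the copy.
  return n_samples_in_candidates * kNeedCopyThreshold > n_total_samples;
}

int main() {
  assert(NeedCopy(false, /*n_samples_in_candidates=*/300'000, /*n_total_samples=*/1'000'000));
  assert(!NeedCopy(false, /*n_samples_in_candidates=*/200'000, /*n_total_samples=*/1'000'000));
}
```

With the driver batch raised to kMaxNodeBatchSize = 1024, deep trees with many small frontier nodes are exactly the case where this heuristic skips the copy.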
