Skip to content

Commit 582ea10

Browse files
authored
[EM] Enable prediction cache for GPU. (dmlc#10707)
- Use `UpdatePosition` for all nodes and skip `FinalizePosition` when external memory is used. - Create `encode/decode` for node position, this is just as a refactor. - Reuse code between update position and finalization.
1 parent 0def8e0 commit 582ea10

20 files changed

+376
-325
lines changed

src/common/categorical.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
/**
2-
* Copyright 2020-2023, XGBoost Contributors
2+
* Copyright 2020-2024, XGBoost Contributors
33
* \file categorical.h
44
*/
55
#ifndef XGBOOST_COMMON_CATEGORICAL_H_
66
#define XGBOOST_COMMON_CATEGORICAL_H_
77

8-
#include <limits>
9-
108
#include "bitfield.h"
119
#include "xgboost/base.h"
1210
#include "xgboost/data.h"
1311
#include "xgboost/span.h"
12+
#include "xgboost/tree_model.h"
1413

15-
namespace xgboost {
16-
namespace common {
17-
14+
namespace xgboost::common {
1815
using CatBitField = LBitField32;
1916
using KCatBitField = CLBitField32;
2017

@@ -94,7 +91,12 @@ XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot
9491
struct IsCatOp {
9592
XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; }
9693
};
97-
} // namespace common
98-
} // namespace xgboost
94+
95+
inline auto GetNodeCats(common::Span<CatBitField::value_type const> categories,
96+
RegTree::CategoricalSplitMatrix::Segment seg) {
97+
KCatBitField node_cats{categories.subspan(seg.beg, seg.size)};
98+
return node_cats;
99+
}
100+
} // namespace xgboost::common
99101

100102
#endif // XGBOOST_COMMON_CATEGORICAL_H_

src/common/device_helpers.cuh

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,9 @@
1616
#include <cstddef> // for size_t
1717
#include <cub/cub.cuh>
1818
#include <cub/util_type.cuh> // for UnitWord
19-
#include <sstream>
20-
#include <string>
2119
#include <tuple>
2220
#include <vector>
2321

24-
#include "../collective/communicator-inl.h"
2522
#include "common.h"
2623
#include "device_vector.cuh"
2724
#include "xgboost/host_device_vector.h"
@@ -375,19 +372,24 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
375372
cudaMemcpyDeviceToHost));
376373
}
377374

378-
template <class HContainer, class DContainer>
379-
void CopyToD(HContainer const &h, DContainer *d) {
380-
if (h.empty()) {
381-
d->clear();
375+
template <class Src, class Dst>
376+
void CopyTo(Src const &src, Dst *dst) {
377+
if (src.empty()) {
378+
dst->clear();
382379
return;
383380
}
384-
d->resize(h.size());
385-
using HVT = std::remove_cv_t<typename HContainer::value_type>;
386-
using DVT = std::remove_cv_t<typename DContainer::value_type>;
387-
static_assert(std::is_same<HVT, DVT>::value,
381+
dst->resize(src.size());
382+
using SVT = std::remove_cv_t<typename Src::value_type>;
383+
using DVT = std::remove_cv_t<typename Dst::value_type>;
384+
static_assert(std::is_same<SVT, DVT>::value,
388385
"Host and device containers must have same value type.");
389-
dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT),
390-
cudaMemcpyHostToDevice));
386+
dh::safe_cuda(cudaMemcpyAsync(thrust::raw_pointer_cast(dst->data()), src.data(),
387+
src.size() * sizeof(SVT), cudaMemcpyDefault));
388+
}
389+
390+
template <class HContainer, class DContainer>
391+
void CopyToD(HContainer const &h, DContainer *d) {
392+
CopyTo(h, d);
391393
}
392394

393395
// Keep track of pinned memory allocation

src/common/device_vector.cuh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ class DeviceUVector {
307307

308308
public:
309309
DeviceUVector() = default;
310+
explicit DeviceUVector(std::size_t n) { this->resize(n); }
310311
DeviceUVector(DeviceUVector const &that) = delete;
311312
DeviceUVector &operator=(DeviceUVector const &that) = delete;
312313
DeviceUVector(DeviceUVector &&that) = default;
@@ -330,7 +331,17 @@ class DeviceUVector {
330331
data_.resize(n, v);
331332
#endif
332333
}
334+
335+
void clear() { // NOLINT
336+
#if defined(XGBOOST_USE_RMM)
337+
this->data_.resize(0, rmm::cuda_stream_per_thread);
338+
#else
339+
this->data_.clear();
340+
#endif // defined(XGBOOST_USE_RMM)
341+
}
342+
333343
[[nodiscard]] std::size_t size() const { return data_.size(); } // NOLINT
344+
[[nodiscard]] bool empty() const { return this->size() == 0; } // NOLINT
334345

335346
[[nodiscard]] auto begin() { return data_.begin(); } // NOLINT
336347
[[nodiscard]] auto end() { return data_.end(); } // NOLINT

src/common/partition_builder.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "column_matrix.h"
2121
#include "xgboost/context.h"
2222
#include "xgboost/tree_model.h"
23+
#include "../tree/sample_position.h" // for SamplePosition
2324

2425
namespace xgboost::common {
2526
// The builder is required for samples partition to left and rights children for set of nodes
@@ -364,13 +365,14 @@ class PartitionBuilder {
364365
}
365366

366367
// Copy row partitions into global cache for reuse in objective
367-
template <typename Sampledp>
368+
template <typename Invalidp>
368369
void LeafPartition(Context const* ctx, RegTree const& tree, RowSetCollection const& row_set,
369-
std::vector<bst_node_t>* p_position, Sampledp sampledp) const {
370+
std::vector<bst_node_t>* p_position, Invalidp invalidp) const {
370371
auto& h_pos = *p_position;
371372
h_pos.resize(row_set.Data()->size(), std::numeric_limits<bst_node_t>::max());
372373

373374
auto p_begin = row_set.Data()->data();
375+
// For each node, walk through all the samples that fall in this node.
374376
ParallelFor(row_set.Size(), ctx->Threads(), [&](size_t i) {
375377
auto const& node = row_set[i];
376378
if (node.node_id < 0) {
@@ -381,7 +383,7 @@ class PartitionBuilder {
381383
size_t ptr_offset = node.end() - p_begin;
382384
CHECK_LE(ptr_offset, row_set.Data()->size()) << node.node_id;
383385
for (auto idx = node.begin(); idx != node.end(); ++idx) {
384-
h_pos[*idx] = sampledp(*idx) ? ~node.node_id : node.node_id;
386+
h_pos[*idx] = tree::SamplePosition::Encode(node.node_id, !invalidp(*idx));
385387
}
386388
}
387389
});

src/common/quantile.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "../collective/allgather.h"
1616
#include "../collective/allreduce.h"
17+
#include "../collective/communicator-inl.h" // for GetWorldSize, GetRank
1718
#include "categorical.h"
1819
#include "common.h"
1920
#include "device_helpers.cuh"

src/data/ellpack_page.cuh

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
#include <thrust/binary_search.h>
88

9+
#include <limits> // for numeric_limits
10+
911
#include "../common/categorical.h"
1012
#include "../common/compressed_iterator.h"
1113
#include "../common/device_helpers.cuh"
@@ -21,22 +23,26 @@ namespace xgboost {
2123
* Does not own underlying memory and may be trivially copied into kernels.
2224
*/
2325
struct EllpackDeviceAccessor {
24-
/*! \brief Whether or not if the matrix is dense. */
26+
/** @brief Whether or not if the matrix is dense. */
2527
bool is_dense;
26-
/*! \brief Row length for ELLPACK, equal to number of features. */
28+
/** @brief Row length for ELLPACK, equal to number of features when the data is dense. */
2729
bst_idx_t row_stride;
28-
bst_idx_t base_rowid{0};
29-
bst_idx_t n_rows{0};
30+
/** @brief Starting index of the rows. Used for external memory. */
31+
bst_idx_t base_rowid;
32+
/** @brief Number of rows in this batch. */
33+
bst_idx_t n_rows;
34+
/** @brief Acessor for the gradient index. */
3035
common::CompressedIterator<std::uint32_t> gidx_iter;
31-
/*! \brief Minimum value for each feature. Size equals to number of features. */
36+
/** @brief Minimum value for each feature. Size equals to number of features. */
3237
common::Span<const float> min_fvalue;
33-
/*! \brief Histogram cut pointers. Size equals to (number of features + 1). */
38+
/** @brief Histogram cut pointers. Size equals to (number of features + 1). */
3439
common::Span<const std::uint32_t> feature_segments;
35-
/*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */
40+
/** @brief Histogram cut values. Size equals to (bins per feature * number of features). */
3641
common::Span<const float> gidx_fvalue_map;
37-
42+
/** @brief Type of each feature, categorical or numerical. */
3843
common::Span<const FeatureType> feature_types;
3944

45+
EllpackDeviceAccessor() = delete;
4046
EllpackDeviceAccessor(DeviceOrd device, std::shared_ptr<const common::HistogramCuts> cuts,
4147
bool is_dense, size_t row_stride, size_t base_rowid, size_t n_rows,
4248
common::CompressedIterator<uint32_t> gidx_iter,
@@ -108,10 +114,10 @@ struct EllpackDeviceAccessor {
108114
return idx;
109115
}
110116

111-
[[nodiscard]] __device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
117+
[[nodiscard]] __device__ float GetFvalue(size_t ridx, size_t fidx) const {
112118
auto gidx = GetBinIndex(ridx, fidx);
113119
if (gidx == -1) {
114-
return nan("");
120+
return std::numeric_limits<float>::quiet_NaN();
115121
}
116122
return gidx_fvalue_map[gidx];
117123
}

src/objective/adaptive.cc

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@
33
*/
44
#include "adaptive.h"
55

6-
#include <algorithm> // std::transform,std::find_if,std::copy,std::unique
7-
#include <cmath> // std::isnan
8-
#include <cstddef> // std::size_t
9-
#include <iterator> // std::distance
10-
#include <vector> // std::vector
6+
#include <algorithm> // std::transform,std::find_if,std::copy,std::unique
7+
#include <cmath> // std::isnan
8+
#include <cstddef> // std::size_t
9+
#include <iterator> // std::distance
10+
#include <vector> // std::vector
1111

1212
#include "../common/algorithm.h" // ArgSort
13-
#include "../common/common.h" // AssertGPUSupport
1413
#include "../common/numeric.h" // RunLengthEncode
1514
#include "../common/stats.h" // Quantile,WeightedQuantile
1615
#include "../common/threading_utils.h" // ParallelFor
1716
#include "../common/transform_iterator.h" // MakeIndexTransformIter
17+
#include "../tree/sample_position.h" // for SamplePosition
1818
#include "xgboost/base.h" // bst_node_t
1919
#include "xgboost/context.h" // Context
2020
#include "xgboost/data.h" // MetaInfo
@@ -23,6 +23,10 @@
2323
#include "xgboost/span.h" // Span
2424
#include "xgboost/tree_model.h" // RegTree
2525

26+
#if !defined(XGBOOST_USE_CUDA)
27+
#include "../common/common.h" // AssertGPUSupport
28+
#endif // !defined(XGBOOST_USE_CUDA)
29+
2630
namespace xgboost::obj::detail {
2731
void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
2832
std::vector<bst_node_t> const& position, std::vector<size_t>* p_nptr,
@@ -37,9 +41,10 @@ void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
3741
sorted_pos[i] = position[ridx[i]];
3842
}
3943
// find the first non-sampled row
40-
size_t begin_pos =
41-
std::distance(sorted_pos.cbegin(), std::find_if(sorted_pos.cbegin(), sorted_pos.cend(),
42-
[](bst_node_t nidx) { return nidx >= 0; }));
44+
size_t begin_pos = std::distance(
45+
sorted_pos.cbegin(),
46+
std::find_if(sorted_pos.cbegin(), sorted_pos.cend(),
47+
[](bst_node_t nidx) { return tree::SamplePosition::IsValid(nidx); }));
4348
CHECK_LE(begin_pos, sorted_pos.size());
4449

4550
std::vector<bst_node_t> leaf;

src/objective/adaptive.cu

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
*/
44
#include <thrust/sort.h>
55

6-
#include <cstdint> // std::int32_t
7-
#include <cub/cub.cuh> // NOLINT
6+
#include <cstdint> // std::int32_t
7+
#include <cub/cub.cuh> // NOLINT
88

99
#include "../collective/aggregator.h"
1010
#include "../common/cuda_context.cuh" // CUDAContext
1111
#include "../common/device_helpers.cuh"
1212
#include "../common/stats.cuh"
13+
#include "../tree/sample_position.h" // for SamplePosition
1314
#include "adaptive.h"
1415
#include "xgboost/context.h"
1516

@@ -30,10 +31,12 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
3031
// sort row index according to node index
3132
thrust::stable_sort_by_key(cuctx->TP(), sorted_position.begin(),
3233
sorted_position.begin() + n_samples, p_ridx->begin());
33-
size_t beg_pos =
34-
thrust::find_if(cuctx->CTP(), sorted_position.cbegin(), sorted_position.cend(),
35-
[] XGBOOST_DEVICE(bst_node_t nidx) { return nidx >= 0; }) -
36-
sorted_position.cbegin();
34+
// Find the first one that's not sampled (nidx not been negated).
35+
size_t beg_pos = thrust::find_if(cuctx->CTP(), sorted_position.cbegin(), sorted_position.cend(),
36+
[] XGBOOST_DEVICE(bst_node_t nidx) {
37+
return tree::SamplePosition::IsValid(nidx);
38+
}) -
39+
sorted_position.cbegin();
3740
if (beg_pos == sorted_position.size()) {
3841
auto& leaf = p_nidx->HostVector();
3942
tree.WalkTree([&](bst_node_t nidx) {

src/tree/gpu_hist/evaluate_splits.cu

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
/**
22
* Copyright 2020-2024, XGBoost Contributors
33
*/
4-
#include <algorithm> // std::max
5-
#include <vector>
6-
#include <limits>
4+
#include <algorithm> // for :max
5+
#include <limits> // for numeric_limits
76

87
#include "../../collective/allgather.h"
8+
#include "../../collective/communicator-inl.h" // for GetWorldSize, GetRank
99
#include "../../common/categorical.h"
10-
#include "../../data/ellpack_page.cuh"
1110
#include "evaluate_splits.cuh"
1211
#include "expand_entry.cuh"
1312

src/tree/gpu_hist/row_partitioner.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ void RowPartitioner::Reset(Context const* ctx, bst_idx_t n_samples, bst_idx_t ba
1515
ridx_.resize(n_samples);
1616
ridx_tmp_.resize(n_samples);
1717
tmp_.clear();
18+
n_nodes_ = 1; // Root
1819

1920
CHECK_LE(n_samples, std::numeric_limits<cuda_impl::RowIndexT>::max());
2021
ridx_segments_.emplace_back(

0 commit comments

Comments
 (0)