Skip to content

Commit bde1265

Browse files
authored
[EM] Return a full DMatrix instead of an Ellpack from the GPU sampler. (dmlc#10753)
1 parent d6ebcfb commit bde1265

20 files changed

+527
-216
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ java/xgboost4j-demo/data/
6363
java/xgboost4j-demo/tmp/
6464
java/xgboost4j-demo/model/
6565
nb-configuration*
66+
6667
# Eclipse
6768
.project
6869
.cproject
@@ -154,3 +155,6 @@ model*.json
154155
*.rds
155156
Rplots.pdf
156157
*.zip
158+
159+
# nsys
160+
*.nsys-rep

include/xgboost/data.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,15 @@ class MetaInfo {
110110
* @brief Validate all metainfo.
111111
*/
112112
void Validate(DeviceOrd device) const;
113-
114-
MetaInfo Slice(common::Span<int32_t const> ridxs) const;
113+
/**
114+
* @brief Slice the meta info.
115+
*
116+
* The device of ridxs is specified by the ctx object.
117+
*
118+
* @param ridxs Index of selected rows.
119+
* @param nnz The number of non-missing values.
120+
*/
121+
MetaInfo Slice(Context const* ctx, common::Span<bst_idx_t const> ridxs, bst_idx_t nnz) const;
115122

116123
MetaInfo Copy() const;
117124
/**

src/common/device_helpers.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,11 @@ xgboost::common::Span<T> ToSpan(DeviceUVector<T> &vec) {
508508
return {vec.data(), vec.size()};
509509
}
510510

511+
template <typename T>
512+
xgboost::common::Span<std::add_const_t<T>> ToSpan(DeviceUVector<T> const &vec) {
513+
return {vec.data(), vec.size()};
514+
}
515+
511516
// thrust begin, similiar to std::begin
512517
template <typename T>
513518
thrust::device_ptr<T> tbegin(xgboost::HostDeviceVector<T>& vector) { // NOLINT

src/common/linalg_op.cuh

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ struct IterOp {
7676
// returns a thrust iterator for a tensor view.
7777
template <typename T, std::int32_t kDim>
7878
auto tcbegin(TensorView<T, kDim> v) { // NOLINT
79-
return dh::MakeTransformIterator<T>(
79+
return thrust::make_transform_iterator(
8080
thrust::make_counting_iterator(0ul),
8181
detail::IterOp<std::add_const_t<std::remove_const_t<T>>, kDim>{v});
8282
}
@@ -85,5 +85,16 @@ template <typename T, std::int32_t kDim>
8585
auto tcend(TensorView<T, kDim> v) { // NOLINT
8686
return tcbegin(v) + v.Size();
8787
}
88+
89+
template <typename T, std::int32_t kDim>
90+
auto tbegin(TensorView<T, kDim> v) { // NOLINT
91+
return thrust::make_transform_iterator(thrust::make_counting_iterator(0ul),
92+
detail::IterOp<std::remove_const_t<T>, kDim>{v});
93+
}
94+
95+
template <typename T, std::int32_t kDim>
96+
auto tend(TensorView<T, kDim> v) { // NOLINT
97+
return tbegin(v) + v.Size();
98+
}
8899
} // namespace xgboost::linalg
89100
#endif // XGBOOST_COMMON_LINALG_OP_CUH_

src/data/data.cc

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,10 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
351351
this->has_categorical_ = LoadFeatureType(feature_type_names, &feature_types.HostVector());
352352
}
353353

354+
namespace {
354355
template <typename T>
355-
std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, size_t stride = 1) {
356+
std::vector<T> Gather(const std::vector<T>& in, common::Span<bst_idx_t const> ridxs,
357+
size_t stride = 1) {
356358
if (in.empty()) {
357359
return {};
358360
}
@@ -361,16 +363,56 @@ std::vector<T> Gather(const std::vector<T> &in, common::Span<int const> ridxs, s
361363
for (auto i = 0ull; i < size; i++) {
362364
auto ridx = ridxs[i];
363365
for (size_t j = 0; j < stride; ++j) {
364-
out[i * stride +j] = in[ridx * stride + j];
366+
out[i * stride + j] = in[ridx * stride + j];
365367
}
366368
}
367369
return out;
368370
}
371+
} // namespace
372+
373+
namespace cuda_impl {
374+
void SliceMetaInfo(Context const* ctx, MetaInfo const& info, common::Span<bst_idx_t const> ridx,
375+
MetaInfo* p_out);
376+
#if !defined(XGBOOST_USE_CUDA)
377+
void SliceMetaInfo(Context const*, MetaInfo const&, common::Span<bst_idx_t const>, MetaInfo*) {
378+
common::AssertGPUSupport();
379+
}
380+
#endif
381+
} // namespace cuda_impl
369382

370-
MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
383+
MetaInfo MetaInfo::Slice(Context const* ctx, common::Span<bst_idx_t const> ridxs,
384+
bst_idx_t nnz) const {
385+
/**
386+
* Shape
387+
*/
371388
MetaInfo out;
372389
out.num_row_ = ridxs.size();
373390
out.num_col_ = this->num_col_;
391+
out.num_nonzero_ = nnz;
392+
393+
/**
394+
* Feature Info
395+
*/
396+
out.feature_weights.SetDevice(ctx->Device());
397+
out.feature_weights.Resize(this->feature_weights.Size());
398+
out.feature_weights.Copy(this->feature_weights);
399+
400+
out.feature_names = this->feature_names;
401+
402+
out.feature_types.SetDevice(ctx->Device());
403+
out.feature_types.Resize(this->feature_types.Size());
404+
out.feature_types.Copy(this->feature_types);
405+
406+
out.feature_type_names = this->feature_type_names;
407+
408+
/**
409+
* Sample Info
410+
*/
411+
if (ctx->IsCUDA()) {
412+
cuda_impl::SliceMetaInfo(ctx, *this, ridxs, &out);
413+
return out;
414+
}
415+
374416
// Groups is maintained by a higher level Python function. We should aim at deprecating
375417
// the slice function.
376418
if (this->labels.Size() != this->num_row_) {
@@ -386,13 +428,11 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
386428
});
387429
}
388430

389-
out.labels_upper_bound_.HostVector() =
390-
Gather(this->labels_upper_bound_.HostVector(), ridxs);
391-
out.labels_lower_bound_.HostVector() =
392-
Gather(this->labels_lower_bound_.HostVector(), ridxs);
431+
out.labels_upper_bound_.HostVector() = Gather(this->labels_upper_bound_.HostVector(), ridxs);
432+
out.labels_lower_bound_.HostVector() = Gather(this->labels_lower_bound_.HostVector(), ridxs);
393433
// weights
394434
if (this->weights_.Size() + 1 == this->group_ptr_.size()) {
395-
auto& h_weights = out.weights_.HostVector();
435+
auto& h_weights = out.weights_.HostVector();
396436
// Assuming all groups are available.
397437
out.weights_.HostVector() = h_weights;
398438
} else {
@@ -414,14 +454,6 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
414454
});
415455
}
416456

417-
out.feature_weights.Resize(this->feature_weights.Size());
418-
out.feature_weights.Copy(this->feature_weights);
419-
420-
out.feature_names = this->feature_names;
421-
out.feature_types.Resize(this->feature_types.Size());
422-
out.feature_types.Copy(this->feature_types);
423-
out.feature_type_names = this->feature_type_names;
424-
425457
return out;
426458
}
427459

src/data/data.cu

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
/**
2-
* Copyright 2019-2022 by XGBoost Contributors
2+
* Copyright 2019-2024, XGBoost Contributors
33
*
44
* \file data.cu
55
* \brief Handles setting metainfo from array interface.
66
*/
7+
#include <thrust/gather.h> // for gather
8+
79
#include "../common/cuda_context.cuh"
810
#include "../common/device_helpers.cuh"
911
#include "../common/linalg_op.cuh"
@@ -169,6 +171,62 @@ void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json array) {
169171
}
170172
}
171173

174+
namespace {
175+
void Gather(Context const* ctx, linalg::MatrixView<float const> in,
176+
common::Span<bst_idx_t const> ridx, linalg::Matrix<float>* p_out) {
177+
if (in.Empty()) {
178+
return;
179+
}
180+
auto& out = *p_out;
181+
out.Reshape(ridx.size(), in.Shape(1));
182+
auto d_out = out.View(ctx->Device());
183+
184+
auto cuctx = ctx->CUDACtx();
185+
auto map_it = thrust::make_transform_iterator(thrust::make_counting_iterator(0ull),
186+
[=] XGBOOST_DEVICE(bst_idx_t i) {
187+
auto [r, c] = linalg::UnravelIndex(i, in.Shape());
188+
return (ridx[r] * in.Shape(1)) + c;
189+
});
190+
CHECK_NE(in.Shape(1), 0);
191+
thrust::gather(cuctx->TP(), map_it, map_it + out.Size(), linalg::tcbegin(in),
192+
linalg::tbegin(d_out));
193+
}
194+
195+
template <typename T>
196+
void Gather(Context const* ctx, HostDeviceVector<T> const& in, common::Span<bst_idx_t const> ridx,
197+
HostDeviceVector<T>* p_out) {
198+
if (in.Empty()) {
199+
return;
200+
}
201+
in.SetDevice(ctx->Device());
202+
203+
auto& out = *p_out;
204+
out.SetDevice(ctx->Device());
205+
out.Resize(ridx.size());
206+
auto d_out = out.DeviceSpan();
207+
208+
auto cuctx = ctx->CUDACtx();
209+
auto d_in = in.ConstDeviceSpan();
210+
thrust::gather(cuctx->TP(), dh::tcbegin(ridx), dh::tcend(ridx), dh::tcbegin(d_in),
211+
dh::tbegin(d_out));
212+
}
213+
} // anonymous namespace
214+
215+
namespace cuda_impl {
216+
void SliceMetaInfo(Context const* ctx, MetaInfo const& info, common::Span<bst_idx_t const> ridx,
217+
MetaInfo* p_out) {
218+
auto& out = *p_out;
219+
220+
Gather(ctx, info.labels.View(ctx->Device()), ridx, &p_out->labels);
221+
Gather(ctx, info.base_margin_.View(ctx->Device()), ridx, &p_out->base_margin_);
222+
223+
Gather(ctx, info.labels_lower_bound_, ridx, &out.labels_lower_bound_);
224+
Gather(ctx, info.labels_upper_bound_, ridx, &out.labels_upper_bound_);
225+
226+
Gather(ctx, info.weights_, ridx, &out.weights_);
227+
}
228+
} // namespace cuda_impl
229+
172230
template <typename AdapterT>
173231
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
174232
const std::string& cache_prefix, DataSplitMode data_split_mode) {

src/data/ellpack_page.cu

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
/**
22
* Copyright 2019-2024, XGBoost contributors
33
*/
4+
#include <cuda/functional> // for proclaim_return_type
45
#include <thrust/iterator/discard_iterator.h>
56
#include <thrust/iterator/transform_output_iterator.h>
67

7-
#include <algorithm> // for copy
8-
#include <utility> // for move
9-
#include <vector> // for vector
8+
#include <algorithm> // for copy
9+
#include <utility> // for move
10+
#include <vector> // for vector
1011

1112
#include "../common/categorical.h"
1213
#include "../common/cuda_context.cuh"
@@ -576,4 +577,17 @@ EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor(
576577
common::CompressedIterator<uint32_t>(h_gidx_buffer->data(), NumSymbols()),
577578
feature_types};
578579
}
580+
581+
[[nodiscard]] bst_idx_t EllpackPageImpl::NumNonMissing(
582+
Context const* ctx, common::Span<FeatureType const> feature_types) const {
583+
auto d_acc = this->GetDeviceAccessor(ctx->Device(), feature_types);
584+
using T = typename decltype(d_acc.gidx_iter)::value_type;
585+
auto it = thrust::make_transform_iterator(
586+
thrust::make_counting_iterator(0ull),
587+
cuda::proclaim_return_type<T>([=] __device__(std::size_t i) { return d_acc.gidx_iter[i]; }));
588+
auto nnz = thrust::count_if(ctx->CUDACtx()->CTP(), it, it + d_acc.row_stride * d_acc.n_rows,
589+
cuda::proclaim_return_type<bool>(
590+
[=] __device__(T gidx) { return gidx != d_acc.NullValue(); }));
591+
return nnz;
592+
}
579593
} // namespace xgboost

src/data/ellpack_page.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,11 @@ class EllpackPageImpl {
236236
[[nodiscard]] EllpackDeviceAccessor GetHostAccessor(
237237
Context const* ctx, std::vector<common::CompressedByteT>* h_gidx_buffer,
238238
common::Span<FeatureType const> feature_types = {}) const;
239+
/**
240+
* @brief Calculate the number of non-missing values.
241+
*/
242+
[[nodiscard]] bst_idx_t NumNonMissing(Context const* ctx,
243+
common::Span<FeatureType const> feature_types) const;
239244

240245
private:
241246
/**

src/data/iterative_dmatrix.cu

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,17 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
101101
// Synchronise worker columns
102102
}
103103

104+
IterativeDMatrix::IterativeDMatrix(std::shared_ptr<EllpackPage> ellpack, MetaInfo const& info,
105+
BatchParam batch) {
106+
this->ellpack_ = ellpack;
107+
CHECK_EQ(this->Info().num_row_, 0);
108+
CHECK_EQ(this->Info().num_col_, 0);
109+
this->Info().Extend(info, true, true);
110+
this->Info().num_nonzero_ = info.num_nonzero_;
111+
CHECK_EQ(this->Info().num_row_, info.num_row_);
112+
this->batch_ = batch;
113+
}
114+
104115
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
105116
BatchParam const& param) {
106117
if (param.Initialized()) {

src/data/iterative_dmatrix.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ class IterativeDMatrix : public QuantileDMatrix {
4848
std::shared_ptr<DMatrix> ref, DataIterResetCallback *reset,
4949
XGDMatrixCallbackNext *next, float missing, int nthread,
5050
bst_bin_t max_bin);
51+
/**
52+
* @brief Directly construct a QDM from an existing one.
53+
*/
54+
IterativeDMatrix(std::shared_ptr<EllpackPage> ellpack, MetaInfo const &info, BatchParam batch);
55+
5156
~IterativeDMatrix() override = default;
5257

5358
bool EllpackExists() const override { return static_cast<bool>(ellpack_); }

0 commit comments

Comments (0)