Skip to content

Commit d414fdf

Browse files
authored
[EM] Add GPU version of the external memory QDM. (dmlc#10689)
1 parent 18b28d9 commit d414fdf

25 files changed

+536
-261
lines changed

include/xgboost/data.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -641,7 +641,7 @@ class DMatrix {
641641
typename XGDMatrixCallbackNext>
642642
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
643643
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
644-
std::int32_t nthread, bst_bin_t max_bin, std::string cache);
644+
std::int32_t nthread, bst_bin_t max_bin, std::string cache, bool on_host);
645645

646646
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
647647

src/common/device_helpers.cuh

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,13 @@ inline int32_t CurrentDevice() {
116116
return device;
117117
}
118118

119+
// Helper function to get a device from a potentially CPU context.
120+
inline auto GetDevice(xgboost::Context const *ctx) {
121+
auto d = (ctx->IsCUDA()) ? ctx->Device() : xgboost::DeviceOrd::CUDA(dh::CurrentDevice());
122+
CHECK(!d.IsCPU());
123+
return d;
124+
}
125+
119126
inline size_t TotalMemory(int device_idx) {
120127
size_t device_free = 0;
121128
size_t device_total = 0;

src/data/data.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -914,9 +914,9 @@ template <typename DataIterHandle, typename DMatrixHandle, typename DataIterRese
914914
typename XGDMatrixCallbackNext>
915915
DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
916916
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
917-
std::int32_t nthread, bst_bin_t max_bin, std::string cache) {
917+
std::int32_t nthread, bst_bin_t max_bin, std::string cache, bool on_host) {
918918
return new data::ExtMemQuantileDMatrix{
919-
iter, proxy, ref, reset, next, missing, nthread, std::move(cache), max_bin};
919+
iter, proxy, ref, reset, next, missing, nthread, std::move(cache), max_bin, on_host};
920920
}
921921

922922
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
@@ -935,7 +935,7 @@ template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCa
935935
template DMatrix*
936936
DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback, XGDMatrixCallbackNext>(
937937
DataIterHandle, DMatrixHandle, std::shared_ptr<DMatrix>, DataIterResetCallback*,
938-
XGDMatrixCallbackNext*, float, std::int32_t, bst_bin_t, std::string);
938+
XGDMatrixCallbackNext*, float, std::int32_t, bst_bin_t, std::string, bool);
939939

940940
template <typename AdapterT>
941941
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,

src/data/ellpack_page.cc

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,18 @@ bst_idx_t EllpackPage::Size() const {
4747
"EllpackPage is required";
4848
return impl_->Cuts();
4949
}
50+
51+
/*!
 * \brief Stub for builds without CUDA: raises a fatal log message because an
 *        EllpackPage cannot be used in such a build.
 * \return 0; present only so that the function has a return statement.
 */
[[nodiscard]] bst_idx_t EllpackPage::BaseRowId() const {
  LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
                "EllpackPage is required";
  return bst_idx_t{0};
}
56+
57+
/*!
 * \brief Stub for builds without CUDA: raises a fatal log message because an
 *        EllpackPage cannot be used in such a build.
 * \return false; present only so that the function has a return statement.
 */
[[nodiscard]] bool EllpackPage::IsDense() const {
  LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
                "EllpackPage is required";
  return false;
}
5062
} // namespace xgboost
5163

5264
#endif // XGBOOST_USE_CUDA

src/data/ellpack_page.cu

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,9 @@ void EllpackPage::SetBaseRowId(std::size_t row_id) { impl_->SetBaseRowId(row_id)
3939
return impl_->Cuts();
4040
}
4141

42+
/*! \return The `base_rowid` stored in the page implementation. */
[[nodiscard]] bst_idx_t EllpackPage::BaseRowId() const {
  auto const* impl = this->Impl();
  return impl->base_rowid;
}
43+
/*! \return The result of `IsDense()` on the page implementation. */
[[nodiscard]] bool EllpackPage::IsDense() const {
  auto const* impl = this->Impl();
  return impl->IsDense();
}
44+
4245
// Bin each input data entry, store the bin indices in compressed form.
4346
__global__ void CompressBinEllpackKernel(
4447
common::CompressedBufferWriter wr,
@@ -397,7 +400,7 @@ struct CopyPage {
397400
size_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bst_idx_t offset) {
398401
monitor_.Start(__func__);
399402
bst_idx_t num_elements = page->n_rows * page->row_stride;
400-
CHECK_EQ(row_stride, page->row_stride);
403+
CHECK_EQ(this->row_stride, page->row_stride);
401404
CHECK_EQ(NumSymbols(), page->NumSymbols());
402405
CHECK_GE(n_rows * row_stride, offset + num_elements);
403406
if (page == this) {

src/data/ellpack_page.cuh

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -203,6 +203,7 @@ class EllpackPageImpl {
203203
[[nodiscard]] std::shared_ptr<common::HistogramCuts const> CutsShared() const { return cuts_; }
204204
void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts) { cuts_ = cuts; }
205205

206+
[[nodiscard]] bool IsDense() const { return is_dense; }
206207
/** @return Estimation of memory cost of this page. */
207208
static size_t MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts&cuts) ;
208209

src/data/ellpack_page.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ class EllpackPage {
4242

4343
/*! \return Number of instances in the page. */
4444
[[nodiscard]] bst_idx_t Size() const;
45+
/*! \return The dense flag reported by the underlying page implementation. */
[[nodiscard]] bool IsDense() const;
4546

4647
/*! \brief Set the base row id for this page. */
4748
void SetBaseRowId(std::size_t row_id);
@@ -50,6 +51,7 @@ class EllpackPage {
5051
EllpackPageImpl* Impl() { return impl_.get(); }
5152

5253
[[nodiscard]] common::HistogramCuts const& Cuts() const;
54+
/*! \return The base row id stored in the underlying page implementation. */
[[nodiscard]] bst_idx_t BaseRowId() const;
5355

5456
private:
5557
std::unique_ptr<EllpackPageImpl> impl_;

src/data/ellpack_page_source.cu

Lines changed: 49 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,8 @@
1515
#include "ellpack_page.cuh" // for EllpackPageImpl
1616
#include "ellpack_page.h" // for EllpackPage
1717
#include "ellpack_page_source.h"
18-
#include "xgboost/base.h" // for bst_idx_t
18+
#include "proxy_dmatrix.cuh" // for Dispatch
19+
#include "xgboost/base.h" // for bst_idx_t
1920

2021
namespace xgboost::data {
2122
struct EllpackHostCache {
@@ -182,4 +183,51 @@ template void
182183
EllpackPageSourceImpl<EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
183184
template void
184185
EllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
186+
187+
/**
188+
* ExtEllpackPageSourceImpl
189+
*/
190+
template <typename F>
191+
void ExtEllpackPageSourceImpl<F>::Fetch() {
192+
dh::safe_cuda(cudaSetDevice(this->Device().ordinal));
193+
if (!this->ReadCache()) {
194+
auto iter = this->source_->Iter();
195+
CHECK_EQ(this->count_, iter);
196+
++(*this->source_);
197+
CHECK_GE(this->source_->Iter(), 1);
198+
cuda_impl::Dispatch(proxy_, [this](auto const& value) {
199+
proxy_->Info().feature_types.SetDevice(dh::GetDevice(this->ctx_));
200+
auto d_feature_types = proxy_->Info().feature_types.ConstDeviceSpan();
201+
auto n_samples = value.NumRows();
202+
203+
dh::device_vector<size_t> row_counts(n_samples + 1, 0);
204+
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
205+
cuda_impl::Dispatch(proxy_, [=](auto const& value) {
206+
return GetRowCounts(value, row_counts_span, dh::GetDevice(this->ctx_), this->missing_);
207+
});
208+
209+
this->page_.reset(new EllpackPage{});
210+
*this->page_->Impl() = EllpackPageImpl{this->ctx_,
211+
value,
212+
this->missing_,
213+
this->info_->IsDense(),
214+
row_counts_span,
215+
d_feature_types,
216+
this->ext_info_.row_stride,
217+
n_samples,
218+
this->GetCuts()};
219+
this->info_->Extend(proxy_->Info(), false, true);
220+
});
221+
this->page_->SetBaseRowId(this->ext_info_.base_rows.at(iter));
222+
this->WriteCache();
223+
}
224+
}
225+
226+
// Instantiation
227+
template void
228+
ExtEllpackPageSourceImpl<DefaultFormatStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
229+
template void
230+
ExtEllpackPageSourceImpl<EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
231+
template void
232+
ExtEllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>::Fetch();
185233
} // namespace xgboost::data

src/data/ellpack_page_source.h

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include <cstdint> // for int32_t
99
#include <memory> // for shared_ptr
1010
#include <utility> // for move
11+
#include <vector> // for vector
1112

1213
#include "../common/cuda_rt_utils.h" // for SupportsPageableMem
1314
#include "../common/hist_util.h" // for HistogramCuts
@@ -169,6 +170,51 @@ using EllpackPageHostSource =
169170
using EllpackPageSource =
170171
EllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
171172

173+
template <typename FormatCreatePolicy>
174+
class ExtEllpackPageSourceImpl : public ExtQantileSourceMixin<EllpackPage, FormatCreatePolicy> {
175+
using Super = ExtQantileSourceMixin<EllpackPage, FormatCreatePolicy>;
176+
177+
Context const* ctx_;
178+
BatchParam p_;
179+
DMatrixProxy* proxy_;
180+
MetaInfo* info_;
181+
ExternalDataInfo ext_info_;
182+
183+
std::vector<bst_idx_t> base_rows_;
184+
185+
public:
186+
ExtEllpackPageSourceImpl(
187+
Context const* ctx, float missing, MetaInfo* info, ExternalDataInfo ext_info,
188+
std::shared_ptr<Cache> cache, BatchParam param, std::shared_ptr<common::HistogramCuts> cuts,
189+
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> source,
190+
DMatrixProxy* proxy, std::vector<bst_idx_t> base_rows)
191+
: Super{missing,
192+
ctx->Threads(),
193+
static_cast<bst_feature_t>(info->num_col_),
194+
ext_info.n_batches,
195+
source,
196+
cache},
197+
ctx_{ctx},
198+
p_{std::move(param)},
199+
proxy_{proxy},
200+
info_{info},
201+
ext_info_{std::move(ext_info)},
202+
base_rows_{std::move(base_rows)} {
203+
this->SetCuts(std::move(cuts), ctx->Device());
204+
this->Fetch();
205+
}
206+
207+
void Fetch() final;
208+
};
209+
210+
// Cache to host
211+
using ExtEllpackPageHostSource =
212+
ExtEllpackPageSourceImpl<EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
213+
214+
// Cache to disk
215+
using ExtEllpackPageSource =
216+
ExtEllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>;
217+
172218
#if !defined(XGBOOST_USE_CUDA)
173219
template <typename F>
174220
inline void EllpackPageSourceImpl<F>::Fetch() {
@@ -177,6 +223,11 @@ inline void EllpackPageSourceImpl<F>::Fetch() {
177223
(void)(is_dense_);
178224
common::AssertGPUSupport();
179225
}
226+
227+
template <typename F>
228+
inline void ExtEllpackPageSourceImpl<F>::Fetch() {
229+
common::AssertGPUSupport();
230+
}
180231
#endif // !defined(XGBOOST_USE_CUDA)
181232
} // namespace xgboost::data
182233

src/data/extmem_quantile_dmatrix.cc

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,8 @@ ExtMemQuantileDMatrix::ExtMemQuantileDMatrix(DataIterHandle iter_handle, DMatrix
2424
DataIterResetCallback *reset,
2525
XGDMatrixCallbackNext *next, float missing,
2626
std::int32_t n_threads, std::string cache,
27-
bst_bin_t max_bin)
28-
: cache_prefix_{std::move(cache)} {
27+
bst_bin_t max_bin, bool on_host)
28+
: cache_prefix_{std::move(cache)}, on_host_{on_host} {
2929
auto iter = std::make_shared<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>>(
3030
iter_handle, reset, next);
3131
iter->Reset();
@@ -72,13 +72,7 @@ void ExtMemQuantileDMatrix::InitFromCPU(
7272
common::HistogramCuts cuts;
7373
ExternalDataInfo ext_info;
7474
cpu_impl::GetDataShape(ctx, proxy, *iter, missing, &ext_info);
75-
76-
// From here on Info() has the correct data shape
77-
this->Info().num_row_ = ext_info.accumulated_rows;
78-
this->Info().num_col_ = ext_info.n_features;
79-
this->Info().num_nonzero_ = ext_info.nnz;
80-
this->Info().SynchronizeNumberOfColumns(ctx);
81-
ext_info.Validate();
75+
ext_info.SetInfo(ctx, &this->info_);
8276

8377
/**
8478
* Generate quantiles
@@ -110,7 +104,7 @@ void ExtMemQuantileDMatrix::InitFromCPU(
110104
CHECK_EQ(n_total_samples, ext_info.accumulated_rows);
111105
}
112106

113-
BatchSet<GHistIndexMatrix> ExtMemQuantileDMatrix::GetGradientIndexImpl() {
107+
/**
 * @brief Build a batch set over the gradient index pages.
 * @return A BatchSet whose iterator wraps `ghist_index_source_`.
 */
[[nodiscard]] BatchSet<GHistIndexMatrix> ExtMemQuantileDMatrix::GetGradientIndexImpl() {
  auto begin_iter = BatchIterator<GHistIndexMatrix>{this->ghist_index_source_};
  return BatchSet<GHistIndexMatrix>{std::move(begin_iter)};
}
116110

@@ -148,5 +142,13 @@ BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackBatches(Context const *,
148142
this->ellpack_page_source_);
149143
return batch_set;
150144
}
145+
146+
/**
 * @brief Build a batch set over the Ellpack pages.
 *
 * common::AssertGPUSupport() is called first. The batch set is then created from
 * whichever alternative `ellpack_page_source_` currently holds.
 *
 * @return A BatchSet whose iterator wraps the active Ellpack page source.
 */
BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackPageImpl() {
  common::AssertGPUSupport();
  // The visitor only needs the held source pointer, so nothing is captured.
  return std::visit([](auto &&ptr) { return BatchSet{BatchIterator<EllpackPage>{ptr}}; },
                    this->ellpack_page_source_);
}
151153
#endif
152154
} // namespace xgboost::data

0 commit comments

Comments
 (0)