Skip to content

Commit 25966e4

Browse files
authored
[EM] Pass batch parameter into extmem format. (dmlc#10736)
- Allow customization for format reading.
- Customize the number of pre-fetch batches.
1 parent 074cad2 commit 25966e4

15 files changed

+143
-102
lines changed

include/xgboost/data.h

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -239,42 +239,52 @@ struct Entry {
239239
};
240240

241241
/**
242-
* \brief Parameters for constructing histogram index batches.
242+
* @brief Parameters for constructing histogram index batches.
243243
*/
244244
struct BatchParam {
245245
/**
246-
* \brief Maximum number of bins per feature for histograms.
246+
* @brief Maximum number of bins per feature for histograms.
247247
*/
248248
bst_bin_t max_bin{0};
249249
/**
250-
* \brief Hessian, used for sketching with future approx implementation.
250+
* @brief Hessian, used for sketching with future approx implementation.
251251
*/
252252
common::Span<float const> hess;
253253
/**
254-
* \brief Whether should we force DMatrix to regenerate the batch. Only used for
254+
* @brief Whether should we force DMatrix to regenerate the batch. Only used for
255255
* GHistIndex.
256256
*/
257257
bool regen{false};
258258
/**
259-
* \brief Forbid regenerating the gradient index. Used for internal validation.
259+
* @brief Forbid regenerating the gradient index. Used for internal validation.
260260
*/
261261
bool forbid_regen{false};
262262
/**
263-
* \brief Parameter used to generate column matrix for hist.
263+
* @brief Parameter used to generate column matrix for hist.
264264
*/
265265
double sparse_thresh{std::numeric_limits<double>::quiet_NaN()};
266+
/**
267+
* @brief Used for GPU external memory. Whether to copy the data into device.
268+
*
269+
* This affects only the current round of iteration.
270+
*/
271+
bool prefetch_copy{true};
272+
/**
273+
* @brief The number of batches to pre-fetch for external memory.
274+
*/
275+
std::int32_t n_prefetch_batches{3};
266276

267277
/**
268-
* \brief Exact or others that don't need histogram.
278+
* @brief Exact or others that don't need histogram.
269279
*/
270280
BatchParam() = default;
271281
/**
272-
* \brief Used by the hist tree method.
282+
* @brief Used by the hist tree method.
273283
*/
274284
BatchParam(bst_bin_t max_bin, double sparse_thresh)
275285
: max_bin{max_bin}, sparse_thresh{sparse_thresh} {}
276286
/**
277-
* \brief Used by the approx tree method.
287+
* @brief Used by the approx tree method.
278288
*
279289
* Get batch with sketch weighted by hessian. The batch will be regenerated if the
280290
* span is changed, so caller should keep the span for each iteration.
@@ -295,7 +305,7 @@ struct BatchParam {
295305
}
296306
[[nodiscard]] bool Initialized() const { return max_bin != 0; }
297307
/**
298-
* \brief Make a copy of self for DMatrix to describe how its existing index was generated.
308+
* @brief Make a copy of self for DMatrix to describe how its existing index was generated.
299309
*/
300310
[[nodiscard]] BatchParam MakeCache() const {
301311
auto p = *this;

src/data/ellpack_page_raw_format.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ template <typename T>
6060
RET_IF_NOT(fi->Read(&impl->is_dense));
6161
RET_IF_NOT(fi->Read(&impl->row_stride));
6262

63-
if (has_hmm_ats_) {
63+
if (has_hmm_ats_ && !this->param_.prefetch_copy) {
6464
RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer));
6565
} else {
6666
RET_IF_NOT(ReadDeviceVec(fi, &impl->gidx_buffer));
@@ -95,7 +95,7 @@ template <typename T>
9595
CHECK(this->cuts_->cut_values_.DeviceCanRead());
9696
impl->SetCuts(this->cuts_);
9797

98-
fi->Read(page);
98+
fi->Read(page, this->param_.prefetch_copy);
9999
dh::DefaultStream().Sync();
100100

101101
return true;

src/data/ellpack_page_raw_format.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@ class EllpackHostCacheStream;
2626
class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
2727
std::shared_ptr<common::HistogramCuts const> cuts_;
2828
DeviceOrd device_;
29+
BatchParam param_;
2930
// Supports CUDA HMM or ATS
3031
bool has_hmm_ats_{false};
3132

3233
public:
3334
explicit EllpackPageRawFormat(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device,
34-
bool has_hmm_ats)
35-
: cuts_{std::move(cuts)}, device_{device}, has_hmm_ats_{has_hmm_ats} {}
35+
BatchParam param, bool has_hmm_ats)
36+
: cuts_{std::move(cuts)},
37+
device_{device},
38+
param_{std::move(param)},
39+
has_hmm_ats_{has_hmm_ats} {}
3640
[[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override;
3741
[[nodiscard]] std::size_t Write(const EllpackPage& page,
3842
common::AlignedFileWriteStream* fo) override;

src/data/ellpack_page_source.cu

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,13 @@
1111

1212
#include "../common/common.h" // for safe_cuda
1313
#include "../common/ref_resource_view.cuh"
14-
#include "../common/cuda_pinned_allocator.h" // for pinned_allocator
1514
#include "../common/device_helpers.cuh" // for CUDAStreamView, DefaultStream
1615
#include "../common/resource.cuh" // for PrivateCudaMmapConstStream
1716
#include "ellpack_page.cuh" // for EllpackPageImpl
1817
#include "ellpack_page.h" // for EllpackPage
1918
#include "ellpack_page_source.h"
2019
#include "proxy_dmatrix.cuh" // for Dispatch
2120
#include "xgboost/base.h" // for bst_idx_t
22-
#include "../common/cuda_rt_utils.h" // for NvtxScopedRange
2321
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
2422

2523
namespace xgboost::data {
@@ -91,14 +89,20 @@ class EllpackHostCacheStreamImpl {
9189
ptr_ += 1;
9290
}
9391

94-
void Read(EllpackPage* out) const {
92+
void Read(EllpackPage* out, bool prefetch_copy) const {
9593
auto page = this->cache_->Get(ptr_);
9694

9795
auto impl = out->Impl();
98-
impl->gidx_buffer =
99-
common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(page->gidx_buffer.size());
100-
dh::safe_cuda(cudaMemcpyAsync(impl->gidx_buffer.data(), page->gidx_buffer.data(),
101-
page->gidx_buffer.size_bytes(), cudaMemcpyDefault));
96+
if (prefetch_copy) {
97+
impl->gidx_buffer =
98+
common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(page->gidx_buffer.size());
99+
dh::safe_cuda(cudaMemcpyAsync(impl->gidx_buffer.data(), page->gidx_buffer.data(),
100+
page->gidx_buffer.size_bytes(), cudaMemcpyDefault));
101+
} else {
102+
auto res = page->gidx_buffer.Resource();
103+
impl->gidx_buffer = common::RefResourceView<common::CompressedByteT>{
104+
res->DataAs<common::CompressedByteT>(), page->gidx_buffer.size(), res};
105+
}
102106

103107
impl->n_rows = page->Size();
104108
impl->is_dense = page->IsDense();
@@ -120,7 +124,9 @@ std::shared_ptr<EllpackHostCache> EllpackHostCacheStream::Share() { return p_imp
120124

121125
void EllpackHostCacheStream::Seek(bst_idx_t offset_bytes) { this->p_impl_->Seek(offset_bytes); }
122126

123-
void EllpackHostCacheStream::Read(EllpackPage* page) const { this->p_impl_->Read(page); }
127+
void EllpackHostCacheStream::Read(EllpackPage* page, bool prefetch_copy) const {
128+
this->p_impl_->Read(page, prefetch_copy);
129+
}
124130

125131
void EllpackHostCacheStream::Write(EllpackPage const& page) { this->p_impl_->Write(page); }
126132

@@ -162,8 +168,9 @@ EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateWriter(StringV
162168

163169
template std::unique_ptr<
164170
typename EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::ReaderT>
165-
EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateReader(
166-
StringView name, std::uint64_t offset, std::uint64_t length) const;
171+
EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy>::CreateReader(StringView name,
172+
bst_idx_t offset,
173+
bst_idx_t length) const;
167174

168175
/**
169176
* EllpackMmapStreamPolicy
@@ -233,6 +240,7 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {
233240
++(*this->source_);
234241
CHECK_GE(this->source_->Iter(), 1);
235242
cuda_impl::Dispatch(proxy_, [this](auto const& value) {
243+
CHECK(this->proxy_->Ctx()->IsCUDA()) << "All batches must use the same device type.";
236244
proxy_->Info().feature_types.SetDevice(dh::GetDevice(this->ctx_));
237245
auto d_feature_types = proxy_->Info().feature_types.ConstDeviceSpan();
238246
auto n_samples = value.NumRows();

src/data/ellpack_page_source.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class EllpackHostCacheStream {
5353

5454
void Seek(bst_idx_t offset_bytes);
5555

56-
void Read(EllpackPage* page) const;
56+
void Read(EllpackPage* page, bool prefetch_copy) const;
5757
void Write(EllpackPage const& page);
5858
};
5959

@@ -71,9 +71,9 @@ class EllpackFormatPolicy {
7171
// For testing with the HMM flag.
7272
explicit EllpackFormatPolicy(bool has_hmm) : has_hmm_{has_hmm} {}
7373

74-
[[nodiscard]] auto CreatePageFormat() const {
74+
[[nodiscard]] auto CreatePageFormat(BatchParam const& param) const {
7575
CHECK_EQ(cuts_->cut_values_.Device(), device_);
76-
std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{cuts_, device_, has_hmm_}};
76+
std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{cuts_, device_, param, has_hmm_}};
7777
return fmt;
7878
}
7979

src/data/extmem_quantile_dmatrix.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ void ExtMemQuantileDMatrix::InitFromCPU(
6666
Context const *ctx,
6767
std::shared_ptr<DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>> iter,
6868
DMatrixHandle proxy_handle, BatchParam const &p, float missing, std::shared_ptr<DMatrix> ref) {
69+
xgboost_NVTX_FN_RANGE();
70+
6971
auto proxy = MakeProxy(proxy_handle);
7072
CHECK(proxy);
7173

@@ -118,7 +120,7 @@ BatchSet<GHistIndexMatrix> ExtMemQuantileDMatrix::GetGradientIndex(Context const
118120
}
119121

120122
CHECK(this->ghist_index_source_);
121-
this->ghist_index_source_->Reset();
123+
this->ghist_index_source_->Reset(param);
122124

123125
if (!std::isnan(param.sparse_thresh) &&
124126
param.sparse_thresh != tree::TrainParam::DftSparseThreshold()) {

src/data/extmem_quantile_dmatrix.cu

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "proxy_dmatrix.h" // for DataIterProxy
1212
#include "xgboost/context.h" // for Context
1313
#include "xgboost/data.h" // for BatchParam
14+
#include "../common/cuda_rt_utils.h"
1415

1516
namespace xgboost::data {
1617
void ExtMemQuantileDMatrix::InitFromCUDA(
@@ -78,9 +79,9 @@ BatchSet<EllpackPage> ExtMemQuantileDMatrix::GetEllpackBatches(Context const *,
7879
}
7980

8081
std::visit(
81-
[this](auto &&ptr) {
82+
[this, param](auto &&ptr) {
8283
CHECK(ptr);
83-
ptr->Reset();
84+
ptr->Reset(param);
8485
},
8586
this->ellpack_page_source_);
8687

src/data/gradient_index_page_source.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ void ExtGradientIndexPageSource::Fetch() {
3737
CHECK_GE(source_->Iter(), 1);
3838
CHECK_NE(cuts_.Values().size(), 0);
3939
HostAdapterDispatch(proxy_, [this](auto const& value) {
40+
CHECK(this->proxy_->Ctx()->IsCPU()) << "All batches must use the same device type.";
4041
// This does three things:
4142
// - Generate CSR matrix for gradient index.
4243
// - Generate the column matrix for gradient index.

src/data/gradient_index_page_source.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class GHistIndexFormatPolicy {
3131
using FormatT = SparsePageFormat<GHistIndexMatrix>;
3232

3333
public:
34-
[[nodiscard]] auto CreatePageFormat() const {
34+
[[nodiscard]] auto CreatePageFormat(BatchParam const&) const {
3535
std::unique_ptr<FormatT> fmt{new GHistIndexRawFormat{cuts_}};
3636
return fmt;
3737
}

src/data/sparse_page_dmatrix.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ void SparsePageDMatrix::InitializeSparsePage(Context const *ctx) {
8282
// release the iterator and data.
8383
if (cache_info_.at(id)->written) {
8484
CHECK(sparse_page_source_);
85-
sparse_page_source_->Reset();
85+
sparse_page_source_->Reset({});
8686
return;
8787
}
8888

@@ -114,7 +114,7 @@ BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches(Context const *ctx) {
114114
std::make_shared<CSCPageSource>(this->missing_, ctx->Threads(), this->Info().num_col_,
115115
this->n_batches_, cache_info_.at(id), sparse_page_source_);
116116
} else {
117-
column_source_->Reset();
117+
column_source_->Reset({});
118118
}
119119
return BatchSet{BatchIterator<CSCPage>{this->column_source_}};
120120
}
@@ -129,7 +129,7 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches(Context const
129129
this->missing_, ctx->Threads(), this->Info().num_col_, this->n_batches_, cache_info_.at(id),
130130
sparse_page_source_);
131131
} else {
132-
sorted_column_source_->Reset();
132+
sorted_column_source_->Reset({});
133133
}
134134
return BatchSet{BatchIterator<SortedCSCPage>{this->sorted_column_source_}};
135135
}
@@ -161,7 +161,7 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(Context const *ct
161161
param, std::move(cuts), this->IsDense(), ft, sparse_page_source_));
162162
} else {
163163
CHECK(ghist_index_source_);
164-
ghist_index_source_->Reset();
164+
ghist_index_source_->Reset(param);
165165
}
166166
return BatchSet{BatchIterator<GHistIndexMatrix>{this->ghist_index_source_}};
167167
}

0 commit comments

Comments
 (0)