Skip to content

Commit 10f5f6d

Browse files
authored
[backport] Fix page concatenation for validation dataset. (dmlc#11338) (dmlc#11435)
1 parent ab1e531 commit 10f5f6d

File tree

3 files changed

+93
-16
lines changed

3 files changed

+93
-16
lines changed

src/data/ellpack_page_source.cu

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -104,23 +104,29 @@ class EllpackHostCacheStreamImpl {
104104

105105
this->cache_->sizes_orig.push_back(page.Impl()->MemCostBytes());
106106
auto orig_ptr = this->cache_->sizes_orig.size() - 1;
107+
CHECK_EQ(this->cache_->pages.size(), this->cache_->on_device.size());
107108

108109
CHECK_LT(orig_ptr, this->cache_->NumBatchesOrig());
109110
auto cache_idx = this->cache_->cache_mapping.at(orig_ptr);
110111
// Wrap up the previous page if this is a new page, or this is the last page.
111112
auto new_page = cache_idx == this->cache_->pages.size();
112-
113+
// Last page expected from the user.
113114
auto last_page = (orig_ptr + 1) == this->cache_->NumBatchesOrig();
114-
// No page concatenation is performed. If there's page concatenation, then the number
115-
// of pages in the cache must be smaller than the input number of pages.
116-
bool no_concat = this->cache_->NumBatchesOrig() == this->cache_->buffer_rows.size();
115+
116+
bool const no_concat = this->cache_->NoConcat();
117+
117118
// Whether the page should be cached in device. If true, then we don't need to make a
118119
// copy during write since the temporary page is already in device when page
119120
// concatenation is enabled.
120-
bool to_device = this->cache_->prefer_device &&
121-
this->cache_->NumDevicePages() < this->cache_->max_num_device_pages;
122-
123-
auto commit_page = [&ctx](EllpackPageImpl const* old_impl) {
121+
//
122+
// This applies only to a new cached page. If we are concatenating this page to an
123+
// existing cached page, then we should respect the existing flag obtained from the
124+
// first page of the cached page.
125+
bool to_device_if_new_page =
126+
this->cache_->prefer_device &&
127+
this->cache_->NumDevicePages() < this->cache_->max_num_device_pages;
128+
129+
auto commit_host_page = [](EllpackPageImpl const* old_impl) {
124130
CHECK_EQ(old_impl->gidx_buffer.Resource()->Type(), common::ResourceHandler::kCudaMalloc);
125131
auto new_impl = std::make_unique<EllpackPageImpl>();
126132
new_impl->CopyInfo(old_impl);
@@ -137,7 +143,7 @@ class EllpackHostCacheStreamImpl {
137143
auto new_impl = std::make_unique<EllpackPageImpl>();
138144
new_impl->CopyInfo(page.Impl());
139145

140-
if (to_device) {
146+
if (to_device_if_new_page) {
141147
// Copy to device
142148
new_impl->gidx_buffer = common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(
143149
page.Impl()->gidx_buffer.size());
@@ -151,15 +157,16 @@ class EllpackHostCacheStreamImpl {
151157

152158
this->cache_->offsets.push_back(new_impl->n_rows * new_impl->info.row_stride);
153159
this->cache_->pages.push_back(std::move(new_impl));
160+
this->cache_->on_device.push_back(to_device_if_new_page);
154161
return new_page;
155162
}
156163

157164
if (new_page) {
158165
// No need to copy if it's already in device.
159-
if (!this->cache_->pages.empty() && !to_device) {
166+
if (!this->cache_->pages.empty() && !this->cache_->on_device.back()) {
160167
// Need to wrap up the previous page.
161-
auto commited = commit_page(this->cache_->pages.back().get());
162-
// Replace the previous page with a new page.
168+
auto commited = commit_host_page(this->cache_->pages.back().get());
169+
// Replace the previous page (on device) with a new page on host.
163170
this->cache_->pages.back() = std::move(commited);
164171
}
165172
// Push a new page
@@ -174,16 +181,18 @@ class EllpackHostCacheStreamImpl {
174181
auto offset = new_impl->Copy(&ctx, impl, 0);
175182

176183
this->cache_->offsets.push_back(offset);
184+
177185
this->cache_->pages.push_back(std::move(new_impl));
186+
this->cache_->on_device.push_back(to_device_if_new_page);
178187
} else {
179188
CHECK(!this->cache_->pages.empty());
180189
CHECK_EQ(cache_idx, this->cache_->pages.size() - 1);
181190
auto& new_impl = this->cache_->pages.back();
182191
auto offset = new_impl->Copy(&ctx, impl, this->cache_->offsets.back());
183192
this->cache_->offsets.back() += offset;
184193
// No need to copy if it's already in device.
185-
if (last_page && !to_device) {
186-
auto commited = commit_page(this->cache_->pages.back().get());
194+
if (last_page && !this->cache_->on_device.back()) {
195+
auto commited = commit_host_page(this->cache_->pages.back().get());
187196
this->cache_->pages.back() = std::move(commited);
188197
}
189198
}

src/data/ellpack_page_source.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2019-2024, XGBoost Contributors
2+
* Copyright 2019-2025, XGBoost Contributors
33
*/
44

55
#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
@@ -47,6 +47,7 @@ struct EllpackCacheInfo {
4747
// This is a memory-based cache. It can be a mix of device memory and host memory.
4848
struct EllpackMemCache {
4949
std::vector<std::unique_ptr<EllpackPageImpl>> pages;
50+
std::vector<bool> on_device;
5051
std::vector<std::size_t> offsets;
5152
// Size of each batch before concatenation.
5253
std::vector<bst_idx_t> sizes_orig;
@@ -65,6 +66,9 @@ struct EllpackMemCache {
6566
[[nodiscard]] std::size_t SizeBytes() const;
6667

6768
[[nodiscard]] bool Empty() const { return this->SizeBytes() == 0; }
69+
// No page concatenation is performed. If there's page concatenation, then the number of
70+
// pages in the cache must be smaller than the input number of pages.
71+
[[nodiscard]] bool NoConcat() const { return this->NumBatchesOrig() == this->buffer_rows.size(); }
6872

6973
[[nodiscard]] bst_idx_t NumBatchesOrig() const { return cache_mapping.size(); }
7074
[[nodiscard]] EllpackPageImpl const* At(std::int32_t k) const;
@@ -187,6 +191,7 @@ class EllpackCacheStreamPolicy : public F<S> {
187191

188192
[[nodiscard]] std::unique_ptr<ReaderT> CreateReader(StringView name, bst_idx_t offset,
189193
bst_idx_t length) const;
194+
std::shared_ptr<EllpackMemCache const> Share() const { return p_cache_; }
190195
};
191196

192197
template <typename S, template <typename> typename F>

tests/cpp/data/test_ellpack_page_raw_format.cu

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
/**
2-
* Copyright 2021-2024, XGBoost contributors
2+
* Copyright 2021-2025, XGBoost contributors
33
*/
44
#include <gtest/gtest.h>
55
#include <xgboost/data.h>
66

7+
#include <numeric> // for partial_sum
8+
79
#include "../../../src/data/ellpack_page.cuh" // for EllpackPage, GetRowStride
810
#include "../../../src/data/ellpack_page_raw_format.h" // for EllpackPageRawFormat
911
#include "../../../src/data/ellpack_page_source.h" // for EllpackFormatStreamPolicy
@@ -157,4 +159,65 @@ TEST_P(TestEllpackPageRawFormat, HostIO) {
157159
}
158160

159161
INSTANTIATE_TEST_SUITE_P(EllpackPageRawFormat, TestEllpackPageRawFormat, ::testing::Bool());
162+
163+
TEST(EllpackPageRawFormat, DevicePageConcat) {
164+
auto ctx = MakeCUDACtx(0);
165+
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
166+
bst_idx_t n_features = 16, n_samples = 128;
167+
168+
auto test = [&](std::int32_t max_num_device_pages, std::int64_t min_cache_page_bytes) {
169+
EllpackCacheInfo cinfo{param, true, max_num_device_pages,
170+
std::numeric_limits<float>::quiet_NaN()};
171+
ExternalDataInfo ext_info;
172+
173+
ext_info.n_batches = 8;
174+
ext_info.row_stride = n_features;
175+
for (bst_idx_t i = 0; i < ext_info.n_batches; ++i) {
176+
ext_info.base_rowids.push_back(n_samples);
177+
}
178+
std::partial_sum(ext_info.base_rowids.cbegin(), ext_info.base_rowids.cend(),
179+
ext_info.base_rowids.begin());
180+
ext_info.accumulated_rows = n_samples * ext_info.n_batches;
181+
ext_info.nnz = ext_info.accumulated_rows * n_features;
182+
183+
auto p_fmat = RandomDataGenerator{n_samples, n_features, 0}.Seed(0).GenerateDMatrix();
184+
EllpackCacheStreamPolicy<EllpackPage, EllpackFormatPolicy> policy;
185+
186+
for (auto const &page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
187+
auto cuts = page.Impl()->CutsShared();
188+
CalcCacheMapping(&ctx, true, cuts, min_cache_page_bytes, ext_info, &cinfo);
189+
[&] {
190+
ASSERT_EQ(cinfo.buffer_rows.size(), 4ul);
191+
}();
192+
policy.SetCuts(page.Impl()->CutsShared(), ctx.Device(), std::move(cinfo));
193+
}
194+
195+
auto format = policy.CreatePageFormat(param);
196+
197+
// write multiple pages
198+
for (bst_idx_t i = 0; i < ext_info.n_batches; ++i) {
199+
for (auto const &page : p_fmat->GetBatches<EllpackPage>(&ctx, param)) {
200+
auto writer = policy.CreateWriter({}, i);
201+
[[maybe_unused]] auto n_bytes = format->Write(page, writer.get());
202+
}
203+
}
204+
// check correct concatenation.
205+
auto mem_cache = policy.Share();
206+
return mem_cache;
207+
};
208+
209+
{
210+
auto mem_cache = test(1, n_features * n_samples);
211+
ASSERT_EQ(mem_cache->on_device.size(), 4);
212+
ASSERT_TRUE(mem_cache->on_device[0]);
213+
ASSERT_EQ(mem_cache->NumDevicePages(), 1);
214+
}
215+
{
216+
auto mem_cache = test(2, n_features * n_samples);
217+
ASSERT_EQ(mem_cache->on_device.size(), 4);
218+
ASSERT_TRUE(mem_cache->on_device[0]);
219+
ASSERT_TRUE(mem_cache->on_device[1]);
220+
ASSERT_EQ(mem_cache->NumDevicePages(), 2);
221+
}
222+
}
160223
} // namespace xgboost::data

0 commit comments

Comments
 (0)