Commit 55aef8f

[EM] Avoid resizing host cache. (dmlc#10734)
* [EM] Avoid resizing host cache.
  - Add SAM allocator and resource.
  - Use page-based cache instead of stream-based cache.
1 parent dbfafd8 commit 55aef8f

16 files changed (+262 additions, -144 deletions)
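In short, the commit replaces the resizable, stream-serialized host cache with per-page buffers backed by registered ("SAM") host memory. The core of the new allocator is registering ordinary malloc'ed memory with the CUDA driver; a minimal standalone sketch of that pattern, with illustrative names and not taken from the diff, looks like this:

    #include <cstddef>   // std::size_t
    #include <cstdlib>   // std::malloc, std::free
    #include <cuda_runtime_api.h>

    // Allocate n_bytes on the ordinary heap, then register the range with the
    // CUDA driver so the device can access it like pinned host memory.
    void* AllocRegistered(std::size_t n_bytes) {
      void* ptr = std::malloc(n_bytes);
      if (ptr != nullptr) {
        cudaHostRegister(ptr, n_bytes, cudaHostRegisterDefault);  // error handling omitted
      }
      return ptr;
    }

    // Unregister the range before handing the memory back to the heap.
    void FreeRegistered(void* ptr) {
      if (ptr != nullptr) {
        cudaHostUnregister(ptr);  // error handling omitted
        std::free(ptr);
      }
    }

SamAllocPolicy in src/common/cuda_pinned_allocator.h below wraps this same malloc + cudaHostRegister / cudaHostUnregister + free pairing behind the standard allocator interface, so it can back std::vector and the new CudaPinnedResource.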

jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ class DataIteratorProxy {
   bool cache_on_host_{true};  // TODO(Bobby): Make this optional.

   template <typename T>
-  using Alloc = xgboost::common::cuda_impl::pinned_allocator<T>;
+  using Alloc = xgboost::common::cuda_impl::PinnedAllocator<T>;
   template <typename U>
   using HostVector = std::vector<U, Alloc<U>>;

src/common/cuda_pinned_allocator.h

Lines changed: 48 additions & 13 deletions
@@ -21,7 +21,6 @@ namespace xgboost::common::cuda_impl {
 // that Thrust used to provide.
 //
 // \see https://en.cppreference.com/w/cpp/memory/allocator
-
 template <typename T>
 struct PinnedAllocPolicy {
   using pointer = T*;  // NOLINT: The type returned by address() / allocate()
@@ -33,7 +32,7 @@ struct PinnedAllocPolicy {
     return std::numeric_limits<size_type>::max() / sizeof(value_type);
   }

-  pointer allocate(size_type cnt, const_pointer = nullptr) {  // NOLINT
+  [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const {  // NOLINT
     if (cnt > this->max_size()) {
       throw std::bad_alloc{};
     }  // end if
@@ -57,7 +56,7 @@ struct ManagedAllocPolicy {
     return std::numeric_limits<size_type>::max() / sizeof(value_type);
   }

-  pointer allocate(size_type cnt, const_pointer = nullptr) {  // NOLINT
+  [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const {  // NOLINT
     if (cnt > this->max_size()) {
       throw std::bad_alloc{};
     }  // end if
@@ -70,16 +69,49 @@ struct ManagedAllocPolicy {
   void deallocate(pointer p, size_type) { dh::safe_cuda(cudaFree(p)); }  // NOLINT
 };

+// This is actually a pinned memory allocator in disguise. We utilize HMM or ATS for
+// efficient tracked memory allocation.
+template <typename T>
+struct SamAllocPolicy {
+  using pointer = T*;              // NOLINT: The type returned by address() / allocate()
+  using const_pointer = const T*;  // NOLINT: The type returned by address()
+  using size_type = std::size_t;   // NOLINT: The type used for the size of the allocation
+  using value_type = T;            // NOLINT: The type of the elements in the allocator
+
+  size_type max_size() const {  // NOLINT
+    return std::numeric_limits<size_type>::max() / sizeof(value_type);
+  }
+
+  [[nodiscard]] pointer allocate(size_type cnt, const_pointer = nullptr) const {  // NOLINT
+    if (cnt > this->max_size()) {
+      throw std::bad_alloc{};
+    }  // end if
+
+    size_type n_bytes = cnt * sizeof(value_type);
+    pointer result = reinterpret_cast<pointer>(std::malloc(n_bytes));
+    if (!result) {
+      throw std::bad_alloc{};
+    }
+    dh::safe_cuda(cudaHostRegister(result, n_bytes, cudaHostRegisterDefault));
+    return result;
+  }
+
+  void deallocate(pointer p, size_type) {  // NOLINT
+    dh::safe_cuda(cudaHostUnregister(p));
+    std::free(p);
+  }
+};
+
 template <typename T, template <typename> typename Policy>
-class CudaHostAllocatorImpl : public Policy<T> {  // NOLINT
+class CudaHostAllocatorImpl : public Policy<T> {
  public:
-  using value_type = typename Policy<T>::value_type;         // NOLINT
-  using pointer = typename Policy<T>::pointer;               // NOLINT
-  using const_pointer = typename Policy<T>::const_pointer;   // NOLINT
-  using size_type = typename Policy<T>::size_type;           // NOLINT
+  using typename Policy<T>::value_type;
+  using typename Policy<T>::pointer;
+  using typename Policy<T>::const_pointer;
+  using typename Policy<T>::size_type;

-  using reference = T&;              // NOLINT: The parameter type for address()
-  using const_reference = const T&;  // NOLINT: The parameter type for address()
+  using reference = value_type&;              // NOLINT: The parameter type for address()
+  using const_reference = const value_type&;  // NOLINT: The parameter type for address()

   using difference_type = std::ptrdiff_t;  // NOLINT: The type of the distance between two pointers

@@ -101,14 +133,17 @@ class CudaHostAllocatorImpl : public Policy<T> {  // NOLINT
   pointer address(reference r) { return &r; }              // NOLINT
   const_pointer address(const_reference r) { return &r; }  // NOLINT

-  bool operator==(CudaHostAllocatorImpl const& x) const { return true; }
+  bool operator==(CudaHostAllocatorImpl const&) const { return true; }

   bool operator!=(CudaHostAllocatorImpl const& x) const { return !operator==(x); }
 };

 template <typename T>
-using pinned_allocator = CudaHostAllocatorImpl<T, PinnedAllocPolicy>;  // NOLINT
+using PinnedAllocator = CudaHostAllocatorImpl<T, PinnedAllocPolicy>;  // NOLINT
+
+template <typename T>
+using ManagedAllocator = CudaHostAllocatorImpl<T, ManagedAllocPolicy>;  // NOLINT

 template <typename T>
-using managed_allocator = CudaHostAllocatorImpl<T, ManagedAllocPolicy>;  // NOLINT
+using SamAllocator = CudaHostAllocatorImpl<T, SamAllocPolicy>;
 }  // namespace xgboost::common::cuda_impl
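The renamed aliases keep the standard allocator interface, so they drop straight into allocator-aware containers. A hedged usage sketch follows; the MakeHostCache helper is illustrative, but the vector type mirrors the CudaPinnedResource::storage_ member added in src/common/resource.cuh below:

    #include <cstddef>
    #include <vector>

    #include "cuda_pinned_allocator.h"  // for SamAllocator

    // Host-side byte buffer whose storage is malloc'ed and then registered with
    // the CUDA driver through SamAllocPolicy::allocate.
    using SamBytes = std::vector<std::byte, xgboost::common::cuda_impl::SamAllocator<std::byte>>;

    SamBytes MakeHostCache(std::size_t n_bytes) {
      SamBytes buf;
      buf.resize(n_bytes);  // allocation and cudaHostRegister happen here
      return buf;
    }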

src/common/io.h

Lines changed: 3 additions & 0 deletions
@@ -286,6 +286,7 @@ class ResourceHandler {
     kMmap = 1,
     kCudaMalloc = 2,
     kCudaMmap = 3,
+    kCudaHostCache = 4,
   };

  private:
@@ -310,6 +311,8 @@ class ResourceHandler {
       return "CudaMalloc";
     case kCudaMmap:
       return "CudaMmap";
+    case kCudaHostCache:
+      return "CudaHostCache";
   }
   LOG(FATAL) << "Unreachable.";
   return {};

src/common/ref_resource_view.cuh

Lines changed: 9 additions & 3 deletions
@@ -16,8 +16,7 @@ namespace xgboost::common {
  * @brief Make a fixed size `RefResourceView` with cudaMalloc resource.
  */
 template <typename T>
-[[nodiscard]] RefResourceView<T> MakeFixedVecWithCudaMalloc(Context const*,
-                                                            std::size_t n_elements) {
+[[nodiscard]] RefResourceView<T> MakeFixedVecWithCudaMalloc(std::size_t n_elements) {
   auto resource = std::make_shared<common::CudaMallocResource>(n_elements * sizeof(T));
   auto ref = RefResourceView{resource->DataAs<T>(), n_elements, resource};
   return ref;
@@ -26,8 +25,15 @@ template <typename T>
 template <typename T>
 [[nodiscard]] RefResourceView<T> MakeFixedVecWithCudaMalloc(Context const* ctx,
                                                             std::size_t n_elements, T const& init) {
-  auto ref = MakeFixedVecWithCudaMalloc<T>(ctx, n_elements);
+  auto ref = MakeFixedVecWithCudaMalloc<T>(n_elements);
   thrust::fill_n(ctx->CUDACtx()->CTP(), ref.data(), ref.size(), init);
   return ref;
 }
+
+template <typename T>
+[[nodiscard]] RefResourceView<T> MakeFixedVecWithPinnedMalloc(std::size_t n_elements) {
+  auto resource = std::make_shared<common::CudaPinnedResource>(n_elements * sizeof(T));
+  auto ref = RefResourceView{resource->DataAs<T>(), n_elements, resource};
+  return ref;
+}
 }  // namespace xgboost::common
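MakeFixedVecWithPinnedMalloc pairs a CudaPinnedResource (defined in src/common/resource.cuh below) with a RefResourceView, so callers get a fixed-size pinned host buffer that is never resized after construction. A hedged usage sketch; the wrapper function is illustrative:

    #include <cstddef>

    #include "ref_resource_view.cuh"  // for MakeFixedVecWithPinnedMalloc

    namespace xgboost::common {
    // Fixed-size pinned host scratch buffer; the view shares ownership of the
    // underlying CudaPinnedResource, and its size never changes afterwards.
    inline RefResourceView<float> MakePinnedScratch(std::size_t n) {
      return MakeFixedVecWithPinnedMalloc<float>(n);
    }
    }  // namespace xgboost::common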

src/common/resource.cuh

Lines changed: 20 additions & 3 deletions
@@ -5,9 +5,10 @@
 #include <cstddef>     // for size_t
 #include <functional>  // for function

-#include "device_vector.cuh"      // for DeviceUVector
-#include "io.h"                   // for ResourceHandler, MMAPFile
-#include "xgboost/string_view.h"  // for StringView
+#include "cuda_pinned_allocator.h"  // for SamAllocator
+#include "device_vector.cuh"        // for DeviceUVector
+#include "io.h"                     // for ResourceHandler, MMAPFile
+#include "xgboost/string_view.h"    // for StringView

 namespace xgboost::common {
 /**
@@ -29,6 +30,22 @@ class CudaMallocResource : public ResourceHandler {
   void Resize(std::size_t n_bytes) { this->storage_.resize(n_bytes); }
 };

+class CudaPinnedResource : public ResourceHandler {
+  std::vector<std::byte, cuda_impl::SamAllocator<std::byte>> storage_;
+
+  void Clear() noexcept(true) { this->Resize(0); }
+
+ public:
+  explicit CudaPinnedResource(std::size_t n_bytes) : ResourceHandler{kCudaHostCache} {
+    this->Resize(n_bytes);
+  }
+  ~CudaPinnedResource() noexcept(true) override { this->Clear(); }
+
+  [[nodiscard]] void* Data() override { return storage_.data(); }
+  [[nodiscard]] std::size_t Size() const override { return storage_.size(); }
+  void Resize(std::size_t n_bytes) { this->storage_.resize(n_bytes); }
+};
+
 class CudaMmapResource : public ResourceHandler {
   std::unique_ptr<MMAPFile, std::function<void(MMAPFile*)>> handle_;
   std::size_t n_;

src/data/ellpack_page.cu

Lines changed: 5 additions & 2 deletions
@@ -404,7 +404,7 @@ size_t EllpackPageImpl::Copy(Context const* ctx, EllpackPageImpl const* page, bs
   bst_idx_t num_elements = page->n_rows * page->row_stride;
   CHECK_EQ(this->row_stride, page->row_stride);
   CHECK_EQ(NumSymbols(), page->NumSymbols());
-  CHECK_GE(n_rows * row_stride, offset + num_elements);
+  CHECK_GE(this->n_rows * this->row_stride, offset + num_elements);
   if (page == this) {
     LOG(FATAL) << "Concatenating the same Ellpack.";
     return this->n_rows * this->row_stride;
@@ -542,7 +542,10 @@ void EllpackPageImpl::CreateHistIndices(DeviceOrd device,
 // Return the number of rows contained in this page.
 [[nodiscard]] bst_idx_t EllpackPageImpl::Size() const { return n_rows; }

-std::size_t EllpackPageImpl::MemCostBytes() const { return this->gidx_buffer.size_bytes(); }
+std::size_t EllpackPageImpl::MemCostBytes() const {
+  return this->gidx_buffer.size_bytes() + sizeof(this->n_rows) + sizeof(this->is_dense) +
+         sizeof(this->row_stride) + sizeof(this->base_rowid);
+}

 EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
     DeviceOrd device, common::Span<FeatureType const> feature_types) const {

src/data/ellpack_page.cuh

Lines changed: 3 additions & 2 deletions
@@ -66,6 +66,7 @@ struct EllpackDeviceAccessor {
       min_fvalue = cuts->min_vals_.ConstHostSpan();
     }
   }
+
   /**
    * @brief Given a row index and a feature index, returns the corresponding cut value.
    *
@@ -75,7 +76,7 @@ struct EllpackDeviceAccessor {
   * local to the current batch.
   */
  template <bool global_ridx = true>
-  [[nodiscard]] __device__ bst_bin_t GetBinIndex(size_t ridx, size_t fidx) const {
+  [[nodiscard]] __device__ bst_bin_t GetBinIndex(bst_idx_t ridx, size_t fidx) const {
    if (global_ridx) {
      ridx -= base_rowid;
    }
@@ -114,7 +115,7 @@ struct EllpackDeviceAccessor {
     return idx;
   }

-  [[nodiscard]] __device__ float GetFvalue(size_t ridx, size_t fidx) const {
+  [[nodiscard]] __device__ float GetFvalue(bst_idx_t ridx, size_t fidx) const {
     auto gidx = GetBinIndex(ridx, fidx);
     if (gidx == -1) {
       return std::numeric_limits<float>::quiet_NaN();

src/data/ellpack_page_raw_format.cu

Lines changed: 7 additions & 44 deletions
@@ -39,8 +39,7 @@ template <typename T>
     return false;
   }

-  auto ctx = Context{}.MakeCUDA(common::CurrentDevice());
-  *vec = common::MakeFixedVecWithCudaMalloc<T>(&ctx, n);
+  *vec = common::MakeFixedVecWithCudaMalloc<T>(n);
   dh::safe_cuda(cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, dh::DefaultStream()));
   return true;
 }
@@ -96,57 +95,21 @@ template <typename T>
   CHECK(this->cuts_->cut_values_.DeviceCanRead());
   impl->SetCuts(this->cuts_);

-  // Read vector
-  Context ctx = Context{}.MakeCUDA(common::CurrentDevice());
-  auto read_vec = [&] {
-    common::NvtxScopedRange range{common::NvtxEventAttr{"read-vec", common::NvtxRgb{127, 255, 0}}};
-    bst_idx_t n{0};
-    RET_IF_NOT(fi->Read(&n));
-    if (n == 0) {
-      return true;
-    }
-    impl->gidx_buffer = common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(&ctx, n);
-    RET_IF_NOT(fi->Read(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes()));
-    return true;
-  };
-  RET_IF_NOT(read_vec());
-
-  RET_IF_NOT(fi->Read(&impl->n_rows));
-  RET_IF_NOT(fi->Read(&impl->is_dense));
-  RET_IF_NOT(fi->Read(&impl->row_stride));
-  RET_IF_NOT(fi->Read(&impl->base_rowid));
-
+  fi->Read(page);
   dh::DefaultStream().Sync();
+
   return true;
 }

 [[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page,
                                                       EllpackHostCacheStream* fo) const {
   xgboost_NVTX_FN_RANGE();

-  bst_idx_t bytes{0};
-  auto* impl = page.Impl();
-
-  // Write vector
-  auto write_vec = [&] {
-    common::NvtxScopedRange range{common::NvtxEventAttr{"write-vec", common::NvtxRgb{127, 255, 0}}};
-    bst_idx_t n = impl->gidx_buffer.size();
-    bytes += fo->Write(n);
-
-    if (!impl->gidx_buffer.empty()) {
-      bytes += fo->Write(impl->gidx_buffer.data(), impl->gidx_buffer.size_bytes());
-    }
-  };
-
-  write_vec();
-
-  bytes += fo->Write(impl->n_rows);
-  bytes += fo->Write(impl->is_dense);
-  bytes += fo->Write(impl->row_stride);
-  bytes += fo->Write(impl->base_rowid);
-
+  fo->Write(page);
   dh::DefaultStream().Sync();
-  return bytes;
+
+  auto* impl = page.Impl();
+  return impl->MemCostBytes();
 }

 #undef RET_IF_NOT
