Skip to content

Commit ec3f327

Browse files
authored
Add managed memory allocator. (dmlc#10711)
1 parent 8d7fe26 commit ec3f327

File tree

5 files changed: 119 additions and 62 deletions.

jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ class DataIteratorProxy {
132132
bool cache_on_host_{true}; // TODO(Bobby): Make this optional.
133133

134134
template <typename T>
135-
using Alloc = xgboost::common::cuda::pinned_allocator<T>;
135+
using Alloc = xgboost::common::cuda_impl::pinned_allocator<T>;
136136
template <typename U>
137137
using HostVector = std::vector<U, Alloc<U>>;
138138

src/common/cuda_pinned_allocator.h

Lines changed: 80 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -1,93 +1,114 @@
1-
/*!
2-
* Copyright 2022 by XGBoost Contributors
3-
* \file common.h
4-
* \brief cuda pinned allocator for usage with thrust containers
1+
/**
2+
* Copyright 2022-2024, XGBoost Contributors
3+
*
4+
* @brief cuda pinned allocator for usage with thrust containers
55
*/
66

77
#pragma once
88

9-
#include <cstddef>
10-
#include <limits>
9+
#include <cuda_runtime.h>
1110

12-
#include "common.h"
11+
#include <cstddef> // for size_t
12+
#include <limits> // for numeric_limits
1313

14-
namespace xgboost {
15-
namespace common {
16-
namespace cuda {
14+
#include "common.h"
1715

16+
namespace xgboost::common::cuda_impl {
1817
// \p pinned_allocator is a CUDA-specific host memory allocator
1918
// that employs \c cudaMallocHost for allocation.
2019
//
2120
// This implementation is ported from the experimental/pinned_allocator
2221
// that Thrust used to provide.
2322
//
2423
// \see https://en.cppreference.com/w/cpp/memory/allocator
24+
2525
template <typename T>
26-
class pinned_allocator;
26+
struct PinnedAllocPolicy {
27+
using pointer = T*; // NOLINT: The type returned by address() / allocate()
28+
using const_pointer = const T*; // NOLINT: The type returned by address()
29+
using size_type = std::size_t; // NOLINT: The type used for the size of the allocation
30+
using value_type = T; // NOLINT: The type of the elements in the allocator
31+
32+
size_type max_size() const { // NOLINT
33+
return std::numeric_limits<size_type>::max() / sizeof(value_type);
34+
}
2735

28-
template <>
29-
class pinned_allocator<void> {
30-
public:
31-
using value_type = void; // NOLINT: The type of the elements in the allocator
32-
using pointer = void*; // NOLINT: The type returned by address() / allocate()
33-
using const_pointer = const void*; // NOLINT: The type returned by address()
34-
using size_type = std::size_t; // NOLINT: The type used for the size of the allocation
35-
using difference_type = std::ptrdiff_t; // NOLINT: The type of the distance between two pointers
36+
pointer allocate(size_type cnt, const_pointer = nullptr) { // NOLINT
37+
if (cnt > this->max_size()) {
38+
throw std::bad_alloc{};
39+
} // end if
3640

37-
template <typename U>
38-
struct rebind { // NOLINT
39-
using other = pinned_allocator<U>; // NOLINT: The rebound type
40-
};
41-
};
41+
pointer result(nullptr);
42+
dh::safe_cuda(cudaMallocHost(reinterpret_cast<void**>(&result), cnt * sizeof(value_type)));
43+
return result;
44+
}
4245

46+
void deallocate(pointer p, size_type) { dh::safe_cuda(cudaFreeHost(p)); } // NOLINT
47+
};
4348

4449
template <typename T>
45-
class pinned_allocator {
50+
struct ManagedAllocPolicy {
51+
using pointer = T*; // NOLINT: The type returned by address() / allocate()
52+
using const_pointer = const T*; // NOLINT: The type returned by address()
53+
using size_type = std::size_t; // NOLINT: The type used for the size of the allocation
54+
using value_type = T; // NOLINT: The type of the elements in the allocator
55+
56+
size_type max_size() const { // NOLINT
57+
return std::numeric_limits<size_type>::max() / sizeof(value_type);
58+
}
59+
60+
pointer allocate(size_type cnt, const_pointer = nullptr) { // NOLINT
61+
if (cnt > this->max_size()) {
62+
throw std::bad_alloc{};
63+
} // end if
64+
65+
pointer result(nullptr);
66+
dh::safe_cuda(cudaMallocManaged(reinterpret_cast<void**>(&result), cnt * sizeof(value_type)));
67+
return result;
68+
}
69+
70+
void deallocate(pointer p, size_type) { dh::safe_cuda(cudaFree(p)); } // NOLINT
71+
};
72+
73+
template <typename T, template <typename> typename Policy>
74+
class CudaHostAllocatorImpl : public Policy<T> { // NOLINT
4675
public:
47-
using value_type = T; // NOLINT: The type of the elements in the allocator
48-
using pointer = T*; // NOLINT: The type returned by address() / allocate()
49-
using const_pointer = const T*; // NOLINT: The type returned by address()
50-
using reference = T&; // NOLINT: The parameter type for address()
51-
using const_reference = const T&; // NOLINT: The parameter type for address()
52-
using size_type = std::size_t; // NOLINT: The type used for the size of the allocation
76+
using value_type = typename Policy<T>::value_type; // NOLINT
77+
using pointer = typename Policy<T>::pointer; // NOLINT
78+
using const_pointer = typename Policy<T>::const_pointer; // NOLINT
79+
using size_type = typename Policy<T>::size_type; // NOLINT
80+
81+
using reference = T&; // NOLINT: The parameter type for address()
82+
using const_reference = const T&; // NOLINT: The parameter type for address()
83+
5384
using difference_type = std::ptrdiff_t; // NOLINT: The type of the distance between two pointers
5485

5586
template <typename U>
56-
struct rebind { // NOLINT
57-
using other = pinned_allocator<U>; // NOLINT: The rebound type
87+
struct rebind { // NOLINT
88+
using other = CudaHostAllocatorImpl<U, Policy>; // NOLINT: The rebound type
5889
};
5990

60-
XGBOOST_DEVICE inline pinned_allocator() {}; // NOLINT: host/device markup ignored on defaulted functions
61-
XGBOOST_DEVICE inline ~pinned_allocator() {} // NOLINT: host/device markup ignored on defaulted functions
62-
XGBOOST_DEVICE inline pinned_allocator(pinned_allocator const&) {} // NOLINT: host/device markup ignored on defaulted functions
91+
CudaHostAllocatorImpl() = default;
92+
~CudaHostAllocatorImpl() = default;
93+
CudaHostAllocatorImpl(CudaHostAllocatorImpl const&) = default;
6394

64-
pinned_allocator& operator=(pinned_allocator const& that) = default;
65-
pinned_allocator& operator=(pinned_allocator&& that) = default;
95+
CudaHostAllocatorImpl& operator=(CudaHostAllocatorImpl const& that) = default;
96+
CudaHostAllocatorImpl& operator=(CudaHostAllocatorImpl&& that) = default;
6697

6798
template <typename U>
68-
XGBOOST_DEVICE inline pinned_allocator(pinned_allocator<U> const&) {} // NOLINT
69-
70-
XGBOOST_DEVICE inline pointer address(reference r) { return &r; } // NOLINT
71-
XGBOOST_DEVICE inline const_pointer address(const_reference r) { return &r; } // NOLINT
72-
73-
inline pointer allocate(size_type cnt, const_pointer = nullptr) { // NOLINT
74-
if (cnt > this->max_size()) { throw std::bad_alloc(); } // end if
99+
CudaHostAllocatorImpl(CudaHostAllocatorImpl<U, Policy> const&) {} // NOLINT
75100

76-
pointer result(nullptr);
77-
dh::safe_cuda(cudaMallocHost(reinterpret_cast<void**>(&result), cnt * sizeof(value_type)));
78-
return result;
79-
}
101+
pointer address(reference r) { return &r; } // NOLINT
102+
const_pointer address(const_reference r) { return &r; } // NOLINT
80103

81-
inline void deallocate(pointer p, size_type) { dh::safe_cuda(cudaFreeHost(p)); } // NOLINT
104+
bool operator==(CudaHostAllocatorImpl const& x) const { return true; }
82105

83-
inline size_type max_size() const { return (std::numeric_limits<size_type>::max)() / sizeof(T); } // NOLINT
106+
bool operator!=(CudaHostAllocatorImpl const& x) const { return !operator==(x); }
107+
};
84108

85-
XGBOOST_DEVICE inline bool operator==(pinned_allocator const& x) const { return true; }
109+
template <typename T>
110+
using pinned_allocator = CudaHostAllocatorImpl<T, PinnedAllocPolicy>; // NOLINT
86111

87-
XGBOOST_DEVICE inline bool operator!=(pinned_allocator const& x) const {
88-
return !operator==(x);
89-
}
90-
};
91-
} // namespace cuda
92-
} // namespace common
93-
} // namespace xgboost
112+
template <typename T>
113+
using managed_allocator = CudaHostAllocatorImpl<T, ManagedAllocPolicy>; // NOLINT
114+
} // namespace xgboost::common::cuda_impl

src/data/ellpack_page_source.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020

2121
namespace xgboost::data {
2222
struct EllpackHostCache {
23-
thrust::host_vector<std::int8_t, common::cuda::pinned_allocator<std::int8_t>> cache;
23+
thrust::host_vector<std::int8_t, common::cuda_impl::pinned_allocator<std::int8_t>> cache;
2424

2525
void Resize(std::size_t n, dh::CUDAStreamView stream) {
2626
stream.Sync(); // Prevent partial copy inside resize.

src/tree/gpu_hist/evaluate_splits.cuh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ struct CatAccessor {
5757
class GPUHistEvaluator {
5858
using CatST = common::CatBitField::value_type; // categorical storage type
5959
// use pinned memory to stage the categories, used for sort based splits.
60-
using Alloc = xgboost::common::cuda::pinned_allocator<CatST>;
60+
using Alloc = xgboost::common::cuda_impl::pinned_allocator<CatST>;
6161

6262
private:
6363
TreeEvaluator tree_evaluator_;
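(New test file; its path was not captured in this page extract. The `../../../src/common/` includes below indicate it sits three directories below the repository root, and it is a GoogleTest source exercising the pinned and managed allocators.)

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change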
@@ -0,0 +1,36 @@
1+
/**
2+
* Copyright 2024, XGBoost Contributors
3+
*/
4+
#include <gtest/gtest.h>
5+
#include <xgboost/context.h> // for Context
6+
7+
#include <vector>
8+
9+
#include "../../../src/common/cuda_pinned_allocator.h"
10+
#include "../../../src/common/device_helpers.cuh" // for DefaultStream
11+
#include "../../../src/common/numeric.h" // for Iota
12+
13+
namespace xgboost {
14+
TEST(CudaHostMalloc, Pinned) {
15+
std::vector<float, common::cuda_impl::pinned_allocator<float>> vec;
16+
vec.resize(10);
17+
ASSERT_EQ(vec.size(), 10);
18+
Context ctx;
19+
common::Iota(&ctx, vec.begin(), vec.end(), 0);
20+
float k = 0;
21+
for (auto v : vec) {
22+
ASSERT_EQ(v, k);
23+
++k;
24+
}
25+
}
26+
27+
TEST(CudaHostMalloc, Managed) {
28+
std::vector<float, common::cuda_impl::managed_allocator<float>> vec;
29+
vec.resize(10);
30+
#if defined(__linux__)
31+
dh::safe_cuda(
32+
cudaMemPrefetchAsync(vec.data(), vec.size() * sizeof(float), 0, dh::DefaultStream()));
33+
#endif
34+
dh::DefaultStream().Sync();
35+
}
36+
} // namespace xgboost

0 commit comments

Comments
 (0)