|
| 1 | +/*! |
| 2 | + * Copyright 2022 by XGBoost Contributors |
| 3 | + * \file common.h |
| 4 | + * \brief cuda pinned allocator for usage with thrust containers |
| 5 | + */ |
| 6 | + |
| 7 | +#pragma once |
| 8 | + |
| 9 | +#include <cstddef> |
| 10 | +#include <limits> |
| 11 | + |
| 12 | +#include "common.h" |
| 13 | + |
| 14 | +namespace xgboost { |
| 15 | +namespace common { |
| 16 | +namespace cuda { |
| 17 | + |
| 18 | +// \p pinned_allocator is a CUDA-specific host memory allocator |
| 19 | +// that employs \c cudaMallocHost for allocation. |
| 20 | +// |
| 21 | +// This implementation is ported from the experimental/pinned_allocator |
| 22 | +// that Thrust used to provide. |
| 23 | +// |
| 24 | +// \see https://en.cppreference.com/w/cpp/memory/allocator |
| 25 | +template <typename T> |
| 26 | +class pinned_allocator; |
| 27 | + |
| 28 | +template <> |
| 29 | +class pinned_allocator<void> { |
| 30 | + public: |
| 31 | + using value_type = void; // NOLINT: The type of the elements in the allocator |
| 32 | + using pointer = void*; // NOLINT: The type returned by address() / allocate() |
| 33 | + using const_pointer = const void*; // NOLINT: The type returned by address() |
| 34 | + using size_type = std::size_t; // NOLINT: The type used for the size of the allocation |
| 35 | + using difference_type = std::ptrdiff_t; // NOLINT: The type of the distance between two pointers |
| 36 | + |
| 37 | + template <typename U> |
| 38 | + struct rebind { // NOLINT |
| 39 | + using other = pinned_allocator<U>; // NOLINT: The rebound type |
| 40 | + }; |
| 41 | +}; |
| 42 | + |
| 43 | + |
| 44 | +template <typename T> |
| 45 | +class pinned_allocator { |
| 46 | + public: |
| 47 | + using value_type = T; // NOLINT: The type of the elements in the allocator |
| 48 | + using pointer = T*; // NOLINT: The type returned by address() / allocate() |
| 49 | + using const_pointer = const T*; // NOLINT: The type returned by address() |
| 50 | + using reference = T&; // NOLINT: The parameter type for address() |
| 51 | + using const_reference = const T&; // NOLINT: The parameter type for address() |
| 52 | + using size_type = std::size_t; // NOLINT: The type used for the size of the allocation |
| 53 | + using difference_type = std::ptrdiff_t; // NOLINT: The type of the distance between two pointers |
| 54 | + |
| 55 | + template <typename U> |
| 56 | + struct rebind { // NOLINT |
| 57 | + using other = pinned_allocator<U>; // NOLINT: The rebound type |
| 58 | + }; |
| 59 | + |
| 60 | + XGBOOST_DEVICE inline pinned_allocator() {}; // NOLINT: host/device markup ignored on defaulted functions |
| 61 | + XGBOOST_DEVICE inline ~pinned_allocator() {} // NOLINT: host/device markup ignored on defaulted functions |
| 62 | + XGBOOST_DEVICE inline pinned_allocator(pinned_allocator const&) {} // NOLINT: host/device markup ignored on defaulted functions |
| 63 | + |
| 64 | + |
| 65 | + template <typename U> |
| 66 | + XGBOOST_DEVICE inline pinned_allocator(pinned_allocator<U> const&) {} // NOLINT |
| 67 | + |
| 68 | + XGBOOST_DEVICE inline pointer address(reference r) { return &r; } // NOLINT |
| 69 | + XGBOOST_DEVICE inline const_pointer address(const_reference r) { return &r; } // NOLINT |
| 70 | + |
| 71 | + inline pointer allocate(size_type cnt, const_pointer = nullptr) { // NOLINT |
| 72 | + if (cnt > this->max_size()) { throw std::bad_alloc(); } // end if |
| 73 | + |
| 74 | + pointer result(nullptr); |
| 75 | + dh::safe_cuda(cudaMallocHost(reinterpret_cast<void**>(&result), cnt * sizeof(value_type))); |
| 76 | + return result; |
| 77 | + } |
| 78 | + |
| 79 | + inline void deallocate(pointer p, size_type) { dh::safe_cuda(cudaFreeHost(p)); } // NOLINT |
| 80 | + |
| 81 | + inline size_type max_size() const { return (std::numeric_limits<size_type>::max)() / sizeof(T); } // NOLINT |
| 82 | + |
| 83 | + XGBOOST_DEVICE inline bool operator==(pinned_allocator const& x) const { return true; } |
| 84 | + |
| 85 | + XGBOOST_DEVICE inline bool operator!=(pinned_allocator const& x) const { |
| 86 | + return !operator==(x); |
| 87 | + } |
| 88 | +}; |
| 89 | +} // namespace cuda |
| 90 | +} // namespace common |
| 91 | +} // namespace xgboost |
0 commit comments