//
// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
//
| 22 | + |
| 23 | +#pragma once |
| 24 | + |
#include <Orochi/OrochiUtils.h>

#include <cstddef>
#include <utility>
#include <vector>
| 27 | + |
| 28 | +namespace Oro |
| 29 | +{ |
| 30 | + |
/// @brief A helper function that casts the address of a pointer to the device memory to a void pointer to be used as an argument for kernel calls.
/// @tparam T The type of the element stored in the device memory.
/// @param ptr The address of a pointer to the device memory.
/// @return The same address as an untyped, non-const void pointer.
template<typename T>
void* arg_cast( T* const* ptr ) noexcept
{
	// const_cast only removes the const on the pointed-to pointer; converting T** to void* is a
	// standard conversion, so static_cast is sufficient and reinterpret_cast is not needed.
	return static_cast<void*>( const_cast<T**>( ptr ) );
}
| 40 | + |
| 41 | +template<typename T> |
| 42 | +class GpuMemory final |
| 43 | +{ |
| 44 | + public: |
| 45 | + GpuMemory() = default; |
| 46 | + |
| 47 | + /// @brief Allocate the device memory with the given size. |
| 48 | + /// @param init_size The initial size which represents the number of elements. |
| 49 | + explicit GpuMemory( const size_t init_size ) |
| 50 | + { |
| 51 | + OrochiUtils::malloc( m_data, init_size ); |
| 52 | + |
| 53 | + m_size = init_size; |
| 54 | + m_capacity = init_size; |
| 55 | + } |
| 56 | + |
| 57 | + GpuMemory( const GpuMemory& ) = delete; |
| 58 | + GpuMemory& operator=( const GpuMemory& other ) = delete; |
| 59 | + |
| 60 | + GpuMemory( GpuMemory&& other ) noexcept : m_data{ std::exchange( other.m_data, nullptr ) }, m_size{ std::exchange( other.m_size, 0ULL ) }, m_capacity{ std::exchange( other.m_capacity, 0ULL ) } {} |
| 61 | + |
| 62 | + GpuMemory& operator=( GpuMemory&& other ) noexcept |
| 63 | + { |
| 64 | + GpuMemory tmp( std::move( *this ) ); |
| 65 | + |
| 66 | + swap( *this, other ); |
| 67 | + |
| 68 | + return *this; |
| 69 | + } |
| 70 | + |
| 71 | + ~GpuMemory() |
| 72 | + { |
| 73 | + if( m_data ) |
| 74 | + { |
| 75 | + OrochiUtils::free( m_data ); |
| 76 | + m_data = nullptr; |
| 77 | + } |
| 78 | + m_size = 0ULL; |
| 79 | + m_capacity = 0ULL; |
| 80 | + } |
| 81 | + |
| 82 | + /// @brief Get the size of the device memory. |
| 83 | + /// @return The size of the device memory. |
| 84 | + size_t size() const noexcept { return m_size; } |
| 85 | + |
| 86 | + /// @brief Get the pointer to the device memory. |
| 87 | + /// @return The pointer to the device memory. |
| 88 | + T* ptr() const noexcept { return m_data; } |
| 89 | + |
| 90 | + /// @brief Get the address of the pointer to the device memory. Useful for passing arguments to the kernel call. |
| 91 | + /// @return The address of the pointer to the device memory. |
| 92 | + T* const* address() const noexcept { return &m_data; } |
| 93 | + |
| 94 | + /// @brief Resize the device memory. Its capacity is unchanged if the new size is smaller than the current one. |
| 95 | + /// The old data should be considered invalid to be used after the function is called unless @c copy is set to True. |
| 96 | + /// @param new_size The new memory size after the function is called. |
| 97 | + /// @param copy If true, the function will copy the data to the newly created memory space as well. |
| 98 | + void resize( const size_t new_size, const bool copy = false ) noexcept |
| 99 | + { |
| 100 | + if( new_size <= m_capacity ) |
| 101 | + { |
| 102 | + m_size = new_size; |
| 103 | + return; |
| 104 | + } |
| 105 | + |
| 106 | + GpuMemory tmp( new_size ); |
| 107 | + |
| 108 | + if( copy ) |
| 109 | + { |
| 110 | + OrochiUtils::copyDtoD( tmp.m_data, m_data, m_size ); |
| 111 | + } |
| 112 | + |
| 113 | + *this = std::move( tmp ); |
| 114 | + } |
| 115 | + |
| 116 | + /// @brief Asynchronous version of 'resize' using a given Orochi stream. |
| 117 | + /// @param new_size The new memory size after the function is called. |
| 118 | + /// @param copy If true, the function will copy the data to the newly created memory space as well. |
| 119 | + /// @param stream The Orochi stream used for the underlying operations. |
| 120 | + void resizeAsync( const size_t new_size, const bool copy = false, oroStream stream = 0 ) noexcept |
| 121 | + { |
| 122 | + if( new_size <= m_capacity ) |
| 123 | + { |
| 124 | + m_size = new_size; |
| 125 | + return; |
| 126 | + } |
| 127 | + |
| 128 | + GpuMemory tmp( new_size ); |
| 129 | + |
| 130 | + if( copy ) |
| 131 | + { |
| 132 | + OrochiUtils::copyDtoDAsync( tmp.m_data, m_data, m_size, stream ); |
| 133 | + } |
| 134 | + |
| 135 | + *this = std::move( tmp ); |
| 136 | + } |
| 137 | + |
| 138 | + /// @brief Reset the memory space so that all bits inside are cleared to zero. |
| 139 | + void reset() noexcept { OrochiUtils::memset( m_data, 0, m_size * sizeof( T ) ); } |
| 140 | + |
| 141 | + /// @brief Asynchronous version of 'reset' using a given Orochi stream. |
| 142 | + /// @param stream The Orochi stream used for the underlying operations. |
| 143 | + void resetAsync( oroStream stream = 0 ) noexcept { OrochiUtils::memsetAsync( m_data, 0, m_size * sizeof( T ), stream ); } |
| 144 | + |
| 145 | + /// @brief Copy the data from device memory to host. |
| 146 | + /// @param host_ptr The host pointer. |
| 147 | + /// @param host_data_size The size of the host memory which represents the number of elements. |
| 148 | + void copyFromHost( const T* host_ptr, const size_t host_data_size ) noexcept |
| 149 | + { |
| 150 | + resize( host_data_size ); |
| 151 | + OrochiUtils::copyHtoD( m_data, host_ptr, host_data_size ); |
| 152 | + } |
| 153 | + |
| 154 | + /// @brief Get the content of the first element stored in the device memory. |
| 155 | + /// @return The content of the first element in the device memory. |
| 156 | + T getSingle() const noexcept |
| 157 | + { |
| 158 | + T result{}; |
| 159 | + |
| 160 | + OrochiUtils::copyDtoH( &result, m_data, 1ULL ); |
| 161 | + |
| 162 | + return result; |
| 163 | + } |
| 164 | + |
| 165 | + /// @brief Get all the data stored in the device memory. |
| 166 | + /// @return A vector which contains all the data stored in the device memory. |
| 167 | + std::vector<T> getData() const noexcept |
| 168 | + { |
| 169 | + std::vector<T> result{}; |
| 170 | + result.resize( m_size ); |
| 171 | + |
| 172 | + OrochiUtils::copyDtoH( result.data(), m_data, m_size ); |
| 173 | + |
| 174 | + return result; |
| 175 | + } |
| 176 | + |
| 177 | + private: |
| 178 | + static void swap( GpuMemory& lhs, GpuMemory& rhs ) noexcept |
| 179 | + { |
| 180 | + std::swap( lhs.m_data, rhs.m_data ); |
| 181 | + std::swap( lhs.m_size, rhs.m_size ); |
| 182 | + std::swap( lhs.m_capacity, rhs.m_capacity ); |
| 183 | + } |
| 184 | + |
| 185 | + T* m_data{ nullptr }; |
| 186 | + size_t m_size{ 0ULL }; |
| 187 | + size_t m_capacity{ 0ULL }; |
| 188 | +}; |
| 189 | + |
| 190 | +} // namespace Oro |
0 commit comments