diff --git a/Src/Base/AMReX_GpuDevice.H b/Src/Base/AMReX_GpuDevice.H index d291f05300..111c0d3a83 100644 --- a/Src/Base/AMReX_GpuDevice.H +++ b/Src/Base/AMReX_GpuDevice.H @@ -16,6 +16,7 @@ #include #include #include +#include #define AMREX_GPU_MAX_STREAMS 8 @@ -46,8 +47,24 @@ using gpuDeviceProp_t = cudaDeviceProp; } #endif +namespace amrex { + class Arena; +} + namespace amrex::Gpu { +#ifdef AMREX_USE_GPU +class StreamManager { + gpuStream_t m_stream; + std::mutex m_mutex; + Vector> m_free_wait_list; +public: + [[nodiscard]] gpuStream_t& get (); + void sync (); + void stream_free (Arena* arena, void* mem); +}; +#endif + class Device { @@ -57,14 +74,16 @@ public: static void Finalize (); #if defined(AMREX_USE_GPU) - static gpuStream_t gpuStream () noexcept { return gpu_stream[OpenMP::get_thread_num()]; } + static gpuStream_t gpuStream () noexcept { + return gpu_stream_pool[gpu_stream_index[OpenMP::get_thread_num()]].get(); + } #ifdef AMREX_USE_CUDA /** for backward compatibility */ - static cudaStream_t cudaStream () noexcept { return gpu_stream[OpenMP::get_thread_num()]; } + static cudaStream_t cudaStream () noexcept { return gpuStream(); } #endif #ifdef AMREX_USE_SYCL - static sycl::queue& streamQueue () noexcept { return *(gpu_stream[OpenMP::get_thread_num()].queue); } - static sycl::queue& streamQueue (int i) noexcept { return *(gpu_stream_pool[i].queue); } + static sycl::queue& streamQueue () noexcept { return *(gpuStream().queue); } + static sycl::queue& streamQueue (int i) noexcept { return *(gpu_stream_pool[i].get().queue); } #endif #endif @@ -104,6 +123,8 @@ public: */ static void streamSynchronizeAll () noexcept; + static void streamFree (Arena* arena, void* mem) noexcept; + #if defined(__CUDACC__) /** Generic graph selection. These should be called by users. 
*/ static void startGraphRecording(bool first_iter, void* h_ptr, void* d_ptr, size_t sz); @@ -196,10 +217,10 @@ private: static AMREX_EXPORT dim3 numThreadsMin; static AMREX_EXPORT dim3 numBlocksOverride, numThreadsOverride; - static AMREX_EXPORT Vector gpu_stream_pool; // The size of this is max_gpu_stream - // The non-owning gpu_stream is used to store the current stream that will be used. - // gpu_stream is a vector so that it's thread safe to write to it. - static AMREX_EXPORT Vector gpu_stream; // The size of this is omp_max_threads + static AMREX_EXPORT Vector gpu_stream_pool; // The size of this is max_gpu_stream + // The non-owning gpu_stream_index is used to store the current stream index that will be used. + // gpu_stream_index is a vector so that it's thread safe to write to it. + static AMREX_EXPORT Vector gpu_stream_index; // The size of this is omp_max_threads static AMREX_EXPORT gpuDeviceProp_t device_prop; static AMREX_EXPORT int memory_pools_supported; static AMREX_EXPORT unsigned int max_blocks_per_launch; @@ -208,6 +229,8 @@ private: static AMREX_EXPORT std::unique_ptr sycl_context; static AMREX_EXPORT std::unique_ptr sycl_device; #endif + + friend StreamManager; #endif }; @@ -245,6 +268,12 @@ streamSynchronizeAll () noexcept Device::streamSynchronizeAll(); } +inline void +streamFree (Arena* arena, void* mem) noexcept +{ + Device::streamFree(arena, mem); +} + #ifdef AMREX_USE_GPU inline void diff --git a/Src/Base/AMReX_GpuDevice.cpp b/Src/Base/AMReX_GpuDevice.cpp index e7586316f2..416479e07a 100644 --- a/Src/Base/AMReX_GpuDevice.cpp +++ b/Src/Base/AMReX_GpuDevice.cpp @@ -1,4 +1,5 @@ +#include #include #include #include @@ -97,10 +98,10 @@ dim3 Device::numThreadsOverride = dim3(0, 0, 0); dim3 Device::numBlocksOverride = dim3(0, 0, 0); unsigned int Device::max_blocks_per_launch = 2560; -Vector Device::gpu_stream_pool; -Vector Device::gpu_stream; -gpuDeviceProp_t Device::device_prop; -int Device::memory_pools_supported = 0; +Vector 
Device::gpu_stream_pool; +Vector Device::gpu_stream_index; +gpuDeviceProp_t Device::device_prop; +int Device::memory_pools_supported = 0; constexpr int Device::warp_size; @@ -141,6 +142,64 @@ namespace { } } +[[nodiscard]] gpuStream_t& +StreamManager::get () { + return m_stream; +} + +void +StreamManager::sync () { + decltype(m_free_wait_list) new_empty_wait_list{}; + + { + // lock mutex before accessing and modifying member variables + std::lock_guard lock(m_mutex); + m_free_wait_list.swap(new_empty_wait_list); + } + // unlock mutex before stream sync and memory free + // to avoid deadlocks from the CArena mutex + + // actual stream sync +#ifdef AMREX_USE_SYCL + try { + m_stream.queue->wait_and_throw(); + } catch (sycl::exception const& ex) { + amrex::Abort(std::string("streamSynchronize: ")+ex.what()+"!!!!!"); + } +#else + AMREX_HIP_OR_CUDA( AMREX_HIP_SAFE_CALL(hipStreamSynchronize(m_stream));, + AMREX_CUDA_SAFE_CALL(cudaStreamSynchronize(m_stream)); ) +#endif + + // synchronizing the stream may have taken a long time and + // there may be new kernels launched already, so we free memory + // according to the state from before the stream was synced + + for (auto [arena, mem] : new_empty_wait_list) { + arena->free(mem); + } +} + +void +StreamManager::stream_free (Arena* arena, void* mem) { + if (arena->isDeviceAccessible()) { + std::size_t free_wait_list_size = 0; + { + // lock mutex before accessing and modifying member variables + std::lock_guard lock(m_mutex); + m_free_wait_list.emplace_back(arena, mem); + free_wait_list_size = m_free_wait_list.size(); + } + // Limit the number of memory allocations in m_free_wait_list + // in case the stream is never synchronized + if (free_wait_list_size > 100) { + sync(); + } + } else { + arena->free(mem); + } +} + #endif void @@ -384,24 +443,25 @@ void Device::Finalize () { #ifdef AMREX_USE_GPU + streamSynchronizeAll(); Device::profilerStop(); #ifdef AMREX_USE_SYCL for (auto& s : gpu_stream_pool) { - delete s.queue; - s.queue
= nullptr; + delete s.get().queue; + s.get().queue = nullptr; } sycl_context.reset(); sycl_device.reset(); #else for (int i = 0; i < max_gpu_streams; ++i) { - AMREX_HIP_OR_CUDA( AMREX_HIP_SAFE_CALL( hipStreamDestroy(gpu_stream_pool[i]));, - AMREX_CUDA_SAFE_CALL(cudaStreamDestroy(gpu_stream_pool[i])); ); + AMREX_HIP_OR_CUDA( AMREX_HIP_SAFE_CALL( hipStreamDestroy(gpu_stream_pool[i].get()));, + AMREX_CUDA_SAFE_CALL(cudaStreamDestroy(gpu_stream_pool[i].get())); ); } #endif - gpu_stream.clear(); + gpu_stream_index.clear(); #ifdef AMREX_USE_ACC amrex_finalize_acc(); @@ -417,7 +477,10 @@ Device::initialize_gpu (bool minimal) #ifdef AMREX_USE_GPU - gpu_stream_pool.resize(max_gpu_streams); + if (gpu_stream_pool.size() != max_gpu_streams) { + // no copy/move constructor for std::mutex + gpu_stream_pool = Vector(max_gpu_streams); + } #ifdef AMREX_USE_HIP @@ -430,7 +493,7 @@ Device::initialize_gpu (bool minimal) // AMD devices do not support shared cache banking. for (int i = 0; i < max_gpu_streams; ++i) { - AMREX_HIP_SAFE_CALL(hipStreamCreate(&gpu_stream_pool[i])); + AMREX_HIP_SAFE_CALL(hipStreamCreate(&gpu_stream_pool[i].get())); } #ifdef AMREX_GPU_STREAM_ALLOC_SUPPORT @@ -458,9 +521,9 @@ Device::initialize_gpu (bool minimal) #endif for (int i = 0; i < max_gpu_streams; ++i) { - AMREX_CUDA_SAFE_CALL(cudaStreamCreate(&gpu_stream_pool[i])); + AMREX_CUDA_SAFE_CALL(cudaStreamCreate(&gpu_stream_pool[i].get())); #ifdef AMREX_USE_ACC - acc_set_cuda_stream(i, gpu_stream_pool[i]); + acc_set_cuda_stream(i, gpu_stream_pool[i].get()); #endif } @@ -473,7 +536,7 @@ Device::initialize_gpu (bool minimal) sycl_device = std::make_unique(gpu_devices[device_id]); sycl_context = std::make_unique(*sycl_device, amrex_sycl_error_handler); for (int i = 0; i < max_gpu_streams; ++i) { - gpu_stream_pool[i].queue = new sycl::queue(*sycl_context, *sycl_device, + gpu_stream_pool[i].get().queue = new sycl::queue(*sycl_context, *sycl_device, sycl::property_list{sycl::property::queue::in_order{}}); } } @@ 
-556,7 +619,7 @@ Device::initialize_gpu (bool minimal) } #endif - gpu_stream.resize(OpenMP::get_max_threads(), gpu_stream_pool[0]); + gpu_stream_index.resize(OpenMP::get_max_threads(), 0); ParmParse pp("device"); @@ -626,8 +689,13 @@ int Device::numDevicePartners () noexcept int Device::streamIndex (gpuStream_t s) noexcept { - auto it = std::find(std::begin(gpu_stream_pool), std::end(gpu_stream_pool), s); - return static_cast(std::distance(std::begin(gpu_stream_pool), it)); + const int N = gpu_stream_pool.size(); + for (int i = 0; i < N ; ++i) { + if (gpu_stream_pool[i].get() == s) { + return i; + } + } + return N; } #endif @@ -636,7 +704,7 @@ Device::setStreamIndex (int idx) noexcept { amrex::ignore_unused(idx); #ifdef AMREX_USE_GPU - gpu_stream[OpenMP::get_thread_num()] = gpu_stream_pool[idx % max_gpu_streams]; + gpu_stream_index[OpenMP::get_thread_num()] = idx % max_gpu_streams; #ifdef AMREX_USE_ACC amrex_set_acc_stream(idx % max_gpu_streams); #endif @@ -647,16 +715,16 @@ Device::setStreamIndex (int idx) noexcept gpuStream_t Device::resetStream () noexcept { - gpuStream_t r = gpu_stream[OpenMP::get_thread_num()]; - gpu_stream[OpenMP::get_thread_num()] = gpu_stream_pool[0]; + gpuStream_t r = gpuStream(); + gpu_stream_index[OpenMP::get_thread_num()] = 0; return r; } gpuStream_t Device::setStream (gpuStream_t s) noexcept { - gpuStream_t r = gpu_stream[OpenMP::get_thread_num()]; - gpu_stream[OpenMP::get_thread_num()] = s; + gpuStream_t r = gpuStream(); + gpu_stream_index[OpenMP::get_thread_num()] = streamIndex(s); return r; } #endif @@ -665,9 +733,9 @@ void Device::synchronize () noexcept { #ifdef AMREX_USE_SYCL - for (auto const& s : gpu_stream_pool) { + for (auto& s : gpu_stream_pool) { try { - s.queue->wait_and_throw(); + s.get().queue->wait_and_throw(); } catch (sycl::exception const& ex) { amrex::Abort(std::string("synchronize: ")+ex.what()+"!!!!!"); } @@ -681,16 +749,8 @@ Device::synchronize () noexcept void Device::streamSynchronize () noexcept { -#ifdef 
AMREX_USE_SYCL - auto& q = streamQueue(); - try { - q.wait_and_throw(); - } catch (sycl::exception const& ex) { - amrex::Abort(std::string("streamSynchronize: ")+ex.what()+"!!!!!"); - } -#else - AMREX_HIP_OR_CUDA( AMREX_HIP_SAFE_CALL(hipStreamSynchronize(gpuStream()));, - AMREX_CUDA_SAFE_CALL(cudaStreamSynchronize(gpuStream())); ) +#ifdef AMREX_USE_GPU + gpu_stream_pool[gpu_stream_index[OpenMP::get_thread_num()]].sync(); #endif } @@ -698,14 +758,19 @@ void Device::streamSynchronizeAll () noexcept { #ifdef AMREX_USE_GPU -#ifdef AMREX_USE_SYCL - Device::synchronize(); -#else - for (auto const& s : gpu_stream_pool) { - AMREX_HIP_OR_CUDA( AMREX_HIP_SAFE_CALL(hipStreamSynchronize(s));, - AMREX_CUDA_SAFE_CALL(cudaStreamSynchronize(s)); ) + for (auto& s : gpu_stream_pool) { + s.sync(); } #endif +} + +void +Device::streamFree (Arena* arena, void* mem) noexcept +{ +#ifdef AMREX_USE_GPU + gpu_stream_pool[gpu_stream_index[OpenMP::get_thread_num()]].stream_free(arena, mem); +#else + arena->free(mem); #endif } diff --git a/Src/Base/AMReX_PODVector.H b/Src/Base/AMReX_PODVector.H index f089ff117a..e811a10179 100644 --- a/Src/Base/AMReX_PODVector.H +++ b/Src/Base/AMReX_PODVector.H @@ -834,6 +834,18 @@ namespace amrex std::swap(static_cast(a_vector), static_cast(*this)); } + void stream_free () noexcept + { + if (m_data != nullptr) { + if constexpr (IsArenaAllocator::value) { + Gpu::streamFree(Allocator::arena(), m_data); + } else { + deallocate(m_data, capacity()); + } + m_data = nullptr; + } + } + private: void reserve_doit (size_type a_capacity) { diff --git a/Src/Base/AMReX_Scan.H b/Src/Base/AMReX_Scan.H index d150f77a25..84d49d0f41 100644 --- a/Src/Base/AMReX_Scan.H +++ b/Src/Base/AMReX_Scan.H @@ -926,10 +926,14 @@ T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = ret } }); - Gpu::streamSynchronize(); - AMREX_GPU_ERROR_CHECK(); + if (totalsum_p) { + Gpu::streamSynchronize(); + AMREX_GPU_ERROR_CHECK(); - The_Arena()->free(tile_state_p); + 
The_Arena()->free(tile_state_p); + } else { + Gpu::streamFree(The_Arena(), tile_state_p); + } T ret = (a_ret_sum) ? *totalsum_p : T(0); if (totalsum_p) { The_Pinned_Arena()->free(totalsum_p); } diff --git a/Src/Particle/AMReX_ParticleUtil.H b/Src/Particle/AMReX_ParticleUtil.H index ac91573f85..e88b0cabe5 100644 --- a/Src/Particle/AMReX_ParticleUtil.H +++ b/Src/Particle/AMReX_ParticleUtil.H @@ -564,7 +564,8 @@ partitionParticles (PTile& ptile, ParFunc const& is_left) } }); - Gpu::streamSynchronize(); // for index_left and index_right deallocation + index_left.stream_free(); + index_right.stream_free(); return num_left; }