From f3d621195981cb5f90cc9882b22556243530daf4 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Fri, 21 Mar 2025 13:09:49 +0000 Subject: [PATCH 1/8] [UR][CUDA] Allow cuda adapter objects to block full adapter teardown. In the cuda adapter the adapter struct itself is currently an extern global defined in adapter.cpp. This means fully tearing down the adapter is subject to the same destructor ordering as all other static and global variables, it's first in last out. This presents a problem because an application can declare a static sycl object like a buffer right up top before doing anything else, which results in the sycl object being destroyed after the cuda adapter struct. The UR spec doesn't put the onus on users to keep their parent object lifetimes in order, i.e. there is no statement about "the context you use to create a ur_mem_handle_t must not be released until after the mem_handle". It's assumed (by omission rather than explicitly) that adapters will have their objects keep a reference to any parent objects alive for the duration of their own lifetime. This change moves the cuda adapter structs ownership into a global shared_ptr, which allows child objects of the adapter to keep their own references to it alive past the point where its initial definition goes out of scope. Also adjusts how some other objects track parent object references so that the destructors correctly cascade back to the top: mem handle releases its context, which releases its adapter, which releases the platform + devices, etc. Fixes https://github.com/intel/llvm/issues/17450 --- .../Regression/static-buffer-dtor.cpp | 3 - .../source/adapters/cuda/adapter.cpp | 56 +++++++++---------- .../source/adapters/cuda/adapter.hpp | 28 +++++++++- .../source/adapters/cuda/context.cpp | 1 + .../source/adapters/cuda/context.hpp | 16 +++--- .../source/adapters/cuda/device.cpp | 2 +- .../source/adapters/cuda/memory.hpp | 3 + .../source/adapters/cuda/platform.cpp | 18 +++--- .../source/adapters/cuda/platform.hpp | 3 + unified-runtime/source/adapters/cuda/usm.cpp | 2 +- 10 files changed, 79 insertions(+), 53 deletions(-) diff --git a/sycl/test-e2e/Regression/static-buffer-dtor.cpp b/sycl/test-e2e/Regression/static-buffer-dtor.cpp index 8cc1ddce43501..474c2b756d3fe 100644 --- a/sycl/test-e2e/Regression/static-buffer-dtor.cpp +++ b/sycl/test-e2e/Regression/static-buffer-dtor.cpp @@ -21,9 +21,6 @@ // UNSUPPORTED: windows && arch-intel_gpu_bmg_g21 // UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/17255 -// UNSUPPORTED: cuda -// UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/17450 - #include int main() { diff --git a/unified-runtime/source/adapters/cuda/adapter.cpp b/unified-runtime/source/adapters/cuda/adapter.cpp index 3ea896bbd6cc3..d0083995542ea 100644 --- a/unified-runtime/source/adapters/cuda/adapter.cpp +++ b/unified-runtime/source/adapters/cuda/adapter.cpp @@ -8,19 +8,16 @@ // //===----------------------------------------------------------------------===// -#include - +#include "adapter.hpp" #include "common.hpp" -#include "logger/ur_logger.hpp" +#include "platform.hpp" #include "tracing.hpp" -struct ur_adapter_handle_t_ { - std::atomic RefCount = 0; - std::mutex Mutex; - struct cuda_tracing_context_t_ *TracingCtx = nullptr; - logger::Logger &logger; - ur_adapter_handle_t_(); -}; +#include + +namespace ur::cuda { +std::shared_ptr adapter; +} // namespace ur::cuda class ur_legacy_sink : public logger::Sink { public: @@ -43,28 +40,29 @@ class ur_legacy_sink : public logger::Sink { ur_adapter_handle_t_::ur_adapter_handle_t_() : logger(logger::get_logger("cuda", /*default_log_level*/ logger::Level::ERR)) { + Platform = std::make_unique(); - if (std::getenv("UR_LOG_CUDA") != nullptr) - return; - - if (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") != nullptr || - std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr) { + if (std::getenv("UR_LOG_CUDA") == nullptr && + (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") != nullptr || + std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr)) { logger.setLegacySink(std::make_unique()); } } -ur_adapter_handle_t_ adapter{}; UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { if (NumEntries > 0 && phAdapters) { - std::lock_guard Lock{adapter.Mutex}; - if (adapter.RefCount++ == 0) { - adapter.TracingCtx = createCUDATracingContext(); - enableCUDATracing(adapter.TracingCtx); - } + static std::once_flag InitFlag; + std::call_once(InitFlag, [=]() { + ur::cuda::adapter = std::make_shared(); + }); + + std::lock_guard Lock{ur::cuda::adapter->Mutex}; + ur::cuda::adapter->TracingCtx = createCUDATracingContext(); + enableCUDATracing(ur::cuda::adapter->TracingCtx); - *phAdapters = &adapter; + *phAdapters = ur::cuda::adapter.get(); } if (pNumAdapters) { @@ -75,17 +73,17 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - adapter.RefCount++; + ur::cuda::adapter->RefCount++; return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - std::lock_guard Lock{adapter.Mutex}; - if (--adapter.RefCount == 0) { - disableCUDATracing(adapter.TracingCtx); - freeCUDATracingContext(adapter.TracingCtx); - adapter.TracingCtx = nullptr; + std::lock_guard Lock{ur::cuda::adapter->Mutex}; + if (--ur::cuda::adapter->RefCount == 0) { + disableCUDATracing(ur::cuda::adapter->TracingCtx); + freeCUDATracingContext(ur::cuda::adapter->TracingCtx); + ur::cuda::adapter.reset(); } return UR_RESULT_SUCCESS; } @@ -108,7 +106,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_ADAPTER_BACKEND_CUDA); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(adapter.RefCount.load()); + return ReturnValue(ur::cuda::adapter->RefCount.load()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/cuda/adapter.hpp b/unified-runtime/source/adapters/cuda/adapter.hpp index ba05b86502347..6cb9522cac10f 100644 --- a/unified-runtime/source/adapters/cuda/adapter.hpp +++ b/unified-runtime/source/adapters/cuda/adapter.hpp @@ -8,6 +8,30 @@ // //===----------------------------------------------------------------------===// -struct ur_adapter_handle_t_; +#include "logger/ur_logger.hpp" +#include "platform.hpp" +#include "tracing.hpp" +#include -extern ur_adapter_handle_t_ adapter; +#include +#include +#include + +// should maybe be an ifdef +#pragma once + +struct ur_platform_handle_t_; + +struct ur_adapter_handle_t_ { + std::atomic RefCount = 0; + std::mutex Mutex; + struct cuda_tracing_context_t_ *TracingCtx = nullptr; + logger::Logger &logger; + std::unique_ptr Platform; + ur_adapter_handle_t_(); +}; + +// Keep the global namespace'd +namespace ur::cuda { +extern std::shared_ptr adapter; +} // namespace ur::cuda diff --git a/unified-runtime/source/adapters/cuda/context.cpp b/unified-runtime/source/adapters/cuda/context.cpp index 8d1adc900da50..4e4f8720f7124 100644 --- a/unified-runtime/source/adapters/cuda/context.cpp +++ b/unified-runtime/source/adapters/cuda/context.cpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "context.hpp" +#include "platform.hpp" #include "usm.hpp" #include diff --git a/unified-runtime/source/adapters/cuda/context.hpp b/unified-runtime/source/adapters/cuda/context.hpp index b3b40123d2ea8..c35a57353e518 100644 --- a/unified-runtime/source/adapters/cuda/context.hpp +++ b/unified-runtime/source/adapters/cuda/context.hpp @@ -10,6 +10,7 @@ #pragma once #include +#include #include #include @@ -17,6 +18,7 @@ #include #include +#include "adapter.hpp" #include "common.hpp" #include "device.hpp" #include "umf_helpers.hpp" @@ -117,6 +119,10 @@ struct ur_context_handle_t_ { void operator()() { Function(UserData); } }; + // Retain an additional reference to the adapter as it keeps the devices + // alive, which end up being used (indirectly) by our destructor. + std::shared_ptr Adapter; + std::vector Devices; std::atomic_uint32_t RefCount; @@ -126,11 +132,8 @@ struct ur_context_handle_t_ { umf_memory_pool_handle_t MemoryPoolHost = nullptr; ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices) - : Devices{Devs, Devs + NumDevices}, RefCount{1} { - for (auto &Dev : Devices) { - urDeviceRetain(Dev); - } - + : Adapter(ur::cuda::adapter), Devices{Devs, Devs + NumDevices}, + RefCount{1} { // Create UMF CUDA memory provider for the host memory // (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because // it is guaranteed to exist). @@ -145,9 +148,6 @@ struct ur_context_handle_t_ { if (MemoryProviderHost) { umfMemoryProviderDestroy(MemoryProviderHost); } - for (auto &Dev : Devices) { - urDeviceRelease(Dev); - } } void invokeExtendedDeleters() { diff --git a/unified-runtime/source/adapters/cuda/device.cpp b/unified-runtime/source/adapters/cuda/device.cpp index 3d413f1cdd715..b057bcf447f5e 100644 --- a/unified-runtime/source/adapters/cuda/device.cpp +++ b/unified-runtime/source/adapters/cuda/device.cpp @@ -1216,7 +1216,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( // Get list of platforms uint32_t NumPlatforms = 0; - ur_adapter_handle_t AdapterHandle = &adapter; + ur_adapter_handle_t AdapterHandle = ur::cuda::adapter.get(); ur_result_t Result = urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms); if (Result != UR_RESULT_SUCCESS) diff --git a/unified-runtime/source/adapters/cuda/memory.hpp b/unified-runtime/source/adapters/cuda/memory.hpp index f0fc14f864796..65f065b1ec3c4 100644 --- a/unified-runtime/source/adapters/cuda/memory.hpp +++ b/unified-runtime/source/adapters/cuda/memory.hpp @@ -393,6 +393,9 @@ struct ur_mem_handle_t_ { urMemRelease(std::get(Mem).Parent); return; } + if (LastQueueWritingToMemObj != nullptr) { + urQueueRelease(LastQueueWritingToMemObj); + } urContextRelease(Context); } diff --git a/unified-runtime/source/adapters/cuda/platform.cpp b/unified-runtime/source/adapters/cuda/platform.cpp index 953c655bedff5..a37d2d754b3ea 100644 --- a/unified-runtime/source/adapters/cuda/platform.cpp +++ b/unified-runtime/source/adapters/cuda/platform.cpp @@ -100,7 +100,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo( return ReturnValue(UR_PLATFORM_BACKEND_CUDA); } case UR_PLATFORM_INFO_ADAPTER: { - return ReturnValue(&adapter); + return ReturnValue(&ur::cuda::adapter); } default: return UR_RESULT_ERROR_INVALID_ENUMERATION; @@ -116,11 +116,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo( UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) { - try { static std::once_flag InitFlag; static uint32_t NumPlatforms = 1; - static ur_platform_handle_t_ Platform; UR_ASSERT(phPlatforms || pNumPlatforms, UR_RESULT_ERROR_INVALID_VALUE); UR_ASSERT(!phPlatforms || NumEntries > 0, UR_RESULT_ERROR_INVALID_SIZE); @@ -151,22 +149,24 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, // Use default stream to record base event counter UR_CHECK_ERROR(cuEventRecord(EvBase, 0)); - Platform.Devices.emplace_back( - new ur_device_handle_t_{Device, Context, EvBase, &Platform, + ur::cuda::adapter->Platform->Devices.emplace_back( + new ur_device_handle_t_{Device, Context, EvBase, + ur::cuda::adapter->Platform.get(), static_cast(i)}); } - UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools(&Platform)); + UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools( + ur::cuda::adapter->Platform.get())); } catch (const std::bad_alloc &) { // Signal out-of-memory situation for (int i = 0; i < NumDevices; ++i) { - Platform.Devices.clear(); + ur::cuda::adapter->Platform->Devices.clear(); } Result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; } catch (ur_result_t Err) { // Clear and rethrow to allow retry for (int i = 0; i < NumDevices; ++i) { - Platform.Devices.clear(); + ur::cuda::adapter->Platform->Devices.clear(); } Result = Err; throw Err; @@ -182,7 +182,7 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, } if (phPlatforms != nullptr) { - *phPlatforms = &Platform; + *phPlatforms = ur::cuda::adapter->Platform.get(); } return Result; diff --git a/unified-runtime/source/adapters/cuda/platform.hpp b/unified-runtime/source/adapters/cuda/platform.hpp index 5da72057abe12..397a2e7a6f6e5 100644 --- a/unified-runtime/source/adapters/cuda/platform.hpp +++ b/unified-runtime/source/adapters/cuda/platform.hpp @@ -9,7 +9,10 @@ //===----------------------------------------------------------------------===// #pragma once +#include "device.hpp" #include + +#include #include struct ur_platform_handle_t_ { diff --git a/unified-runtime/source/adapters/cuda/usm.cpp b/unified-runtime/source/adapters/cuda/usm.cpp index 5ee1ee5aecbae..d8fac6fab2bf8 100644 --- a/unified-runtime/source/adapters/cuda/usm.cpp +++ b/unified-runtime/source/adapters/cuda/usm.cpp @@ -242,7 +242,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, // cuda backend has only one platform containing all devices ur_platform_handle_t platform; - ur_adapter_handle_t AdapterHandle = &adapter; + ur_adapter_handle_t AdapterHandle = ur::cuda::adapter.get(); Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr); // get the device from the platform From a0d36a6d569ba89732e6d6dfab2b7b88d54539d0 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Fri, 21 Mar 2025 14:59:01 +0000 Subject: [PATCH 2/8] Fix platform adapter query. --- unified-runtime/source/adapters/cuda/platform.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unified-runtime/source/adapters/cuda/platform.cpp b/unified-runtime/source/adapters/cuda/platform.cpp index a37d2d754b3ea..6a9a2a691ab37 100644 --- a/unified-runtime/source/adapters/cuda/platform.cpp +++ b/unified-runtime/source/adapters/cuda/platform.cpp @@ -100,7 +100,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo( return ReturnValue(UR_PLATFORM_BACKEND_CUDA); } case UR_PLATFORM_INFO_ADAPTER: { - return ReturnValue(&ur::cuda::adapter); + return ReturnValue(ur::cuda::adapter.get()); } default: return UR_RESULT_ERROR_INVALID_ENUMERATION; From b9b80e1788d3e5a992330a3938a3fe03a53e5d6a Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Fri, 21 Mar 2025 16:57:03 +0000 Subject: [PATCH 3/8] Replace pragma once with ifdef. --- unified-runtime/source/adapters/cuda/adapter.hpp | 8 +++++--- unified-runtime/source/adapters/cuda/platform.hpp | 6 +++++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/unified-runtime/source/adapters/cuda/adapter.hpp b/unified-runtime/source/adapters/cuda/adapter.hpp index 6cb9522cac10f..eab2e224f0eab 100644 --- a/unified-runtime/source/adapters/cuda/adapter.hpp +++ b/unified-runtime/source/adapters/cuda/adapter.hpp @@ -8,6 +8,9 @@ // //===----------------------------------------------------------------------===// +#ifndef UR_CUDA_ADAPTER_HPP_INCLUDED +#define UR_CUDA_ADAPTER_HPP_INCLUDED + #include "logger/ur_logger.hpp" #include "platform.hpp" #include "tracing.hpp" @@ -17,9 +20,6 @@ #include #include -// should maybe be an ifdef -#pragma once - struct ur_platform_handle_t_; struct ur_adapter_handle_t_ { @@ -35,3 +35,5 @@ struct ur_adapter_handle_t_ { namespace ur::cuda { extern std::shared_ptr adapter; } // namespace ur::cuda + +#endif // UR_CUDA_ADAPTER_HPP_INCLUDED diff --git a/unified-runtime/source/adapters/cuda/platform.hpp b/unified-runtime/source/adapters/cuda/platform.hpp index 397a2e7a6f6e5..8ecc19c3e9f61 100644 --- a/unified-runtime/source/adapters/cuda/platform.hpp +++ b/unified-runtime/source/adapters/cuda/platform.hpp @@ -7,7 +7,9 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#pragma once + +#ifndef UR_CUDA_PLATFORM_HPP_INCLUDED +#define UR_CUDA_PLATFORM_HPP_INCLUDED #include "device.hpp" #include @@ -18,3 +20,5 @@ struct ur_platform_handle_t_ { std::vector> Devices; }; + +#endif // UR_CUDA_PLATFORM_HPP_INCLUDED From 61b3bdb62af7ee67beb876783a6c3ca6a11f926c Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Tue, 25 Mar 2025 17:13:32 +0000 Subject: [PATCH 4/8] Fix cuda adapter ref count. --- unified-runtime/source/adapters/cuda/adapter.hpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unified-runtime/source/adapters/cuda/adapter.hpp b/unified-runtime/source/adapters/cuda/adapter.hpp index eab2e224f0eab..4fef3ab99cb2a 100644 --- a/unified-runtime/source/adapters/cuda/adapter.hpp +++ b/unified-runtime/source/adapters/cuda/adapter.hpp @@ -20,10 +20,8 @@ #include #include -struct ur_platform_handle_t_; - struct ur_adapter_handle_t_ { - std::atomic RefCount = 0; + std::atomic RefCount = 1; std::mutex Mutex; struct cuda_tracing_context_t_ *TracingCtx = nullptr; logger::Logger &logger; From 322ae889ba169dc51e3e18c132d895fcfd998df3 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Tue, 25 Mar 2025 17:24:36 +0000 Subject: [PATCH 5/8] Apply similar fix to hip adapter. --- .../source/adapters/hip/adapter.cpp | 31 ++++++++++--------- .../source/adapters/hip/adapter.hpp | 22 +++++++++++-- .../source/adapters/hip/context.hpp | 9 +++--- .../source/adapters/hip/device.cpp | 2 +- .../source/adapters/hip/platform.cpp | 17 +++++----- unified-runtime/source/adapters/hip/usm.cpp | 2 +- 6 files changed, 51 insertions(+), 32 deletions(-) diff --git a/unified-runtime/source/adapters/hip/adapter.cpp b/unified-runtime/source/adapters/hip/adapter.cpp index 9daaee8a29738..7fb3acf411ad9 100644 --- a/unified-runtime/source/adapters/hip/adapter.cpp +++ b/unified-runtime/source/adapters/hip/adapter.cpp @@ -10,16 +10,14 @@ #include "adapter.hpp" #include "common.hpp" -#include "logger/ur_logger.hpp" -#include #include -struct ur_adapter_handle_t_ { - std::atomic RefCount = 0; - logger::Logger &logger; - ur_adapter_handle_t_(); -}; +#include + +namespace ur::hip { +std::shared_ptr adapter; +} class ur_legacy_sink : public logger::Sink { public: @@ -42,7 +40,7 @@ class ur_legacy_sink : public logger::Sink { ur_adapter_handle_t_::ur_adapter_handle_t_() : logger( logger::get_logger("hip", /*default_log_level*/ logger::Level::ERR)) { - + Platform = std::make_unique(); if (std::getenv("UR_LOG_HIP") != nullptr) return; @@ -52,13 +50,16 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() } } -ur_adapter_handle_t_ adapter{}; - UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { if (phAdapters) { - adapter.RefCount++; - *phAdapters = &adapter; + static std::once_flag InitFlag; + std::call_once(InitFlag, [=]() { + ur::hip::adapter = std::make_shared(); + }); + + ur::hip::adapter->RefCount++; + *phAdapters = ur::hip::adapter.get(); } if (pNumAdapters) { *pNumAdapters = 1; @@ -69,12 +70,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { // No state to clean up so we don't need to check for 0 references - adapter.RefCount--; + ur::hip::adapter->RefCount--; return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - adapter.RefCount++; + ur::hip::adapter->RefCount++; return UR_RESULT_SUCCESS; } @@ -96,7 +97,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_ADAPTER_BACKEND_HIP); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(adapter.RefCount.load()); + return ReturnValue(ur::hip::adapter->RefCount.load()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/hip/adapter.hpp b/unified-runtime/source/adapters/hip/adapter.hpp index 781d8c693c5de..3a6305506eda6 100644 --- a/unified-runtime/source/adapters/hip/adapter.hpp +++ b/unified-runtime/source/adapters/hip/adapter.hpp @@ -8,6 +8,24 @@ // //===----------------------------------------------------------------------===// -struct ur_adapter_handle_t_; +#ifndef UR_HIP_ADAPTER_HPP_INCLUDED +#define UR_HIP_ADAPTER_HPP_INCLUDED -extern ur_adapter_handle_t_ adapter; +#include "logger/ur_logger.hpp" +#include "platform.hpp" + +#include +#include + +struct ur_adapter_handle_t_ { + std::atomic RefCount = 1; + logger::Logger &logger; + std::unique_ptr Platform; + ur_adapter_handle_t_(); +}; + +namespace ur::hip { +extern std::shared_ptr adapter; +} + +#endif // UR_HIP_ADAPTER_HPP_INCLUDED diff --git a/unified-runtime/source/adapters/hip/context.hpp b/unified-runtime/source/adapters/hip/context.hpp index 5af95753b8e32..9207e006ef687 100644 --- a/unified-runtime/source/adapters/hip/context.hpp +++ b/unified-runtime/source/adapters/hip/context.hpp @@ -11,6 +11,7 @@ #include +#include "adapter.hpp" #include "common.hpp" #include "device.hpp" #include "platform.hpp" @@ -85,16 +86,14 @@ struct ur_context_handle_t_ { void operator()() { Function(UserData); } }; + std::shared_ptr Adapter; std::vector Devices; std::atomic_uint32_t RefCount; ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices) - : Devices{Devs, Devs + NumDevices}, RefCount{1} { - for (auto &Dev : Devices) { - urDeviceRetain(Dev); - } - }; + : Adapter(ur::hip::adapter), Devices{Devs, Devs + NumDevices}, + RefCount{1} {}; ~ur_context_handle_t_() {} diff --git a/unified-runtime/source/adapters/hip/device.cpp b/unified-runtime/source/adapters/hip/device.cpp index 182727a82d89b..c6718f690201e 100644 --- a/unified-runtime/source/adapters/hip/device.cpp +++ b/unified-runtime/source/adapters/hip/device.cpp @@ -1174,7 +1174,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( // Get list of platforms uint32_t NumPlatforms = 0; - ur_adapter_handle_t AdapterHandle = &adapter; + ur_adapter_handle_t AdapterHandle = ur::hip::adapter.get(); ur_result_t Result = urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms); if (Result != UR_RESULT_SUCCESS) diff --git a/unified-runtime/source/adapters/hip/platform.cpp b/unified-runtime/source/adapters/hip/platform.cpp index fa0b07cc8244a..69d96904cf811 100644 --- a/unified-runtime/source/adapters/hip/platform.cpp +++ b/unified-runtime/source/adapters/hip/platform.cpp @@ -36,7 +36,7 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName, return ReturnValue(""); } case UR_PLATFORM_INFO_ADAPTER: { - return ReturnValue(&adapter); + return ReturnValue(ur::hip::adapter.get()); } default: return UR_RESULT_ERROR_INVALID_ENUMERATION; @@ -56,7 +56,6 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, try { static std::once_flag InitFlag; static uint32_t NumPlatforms = 1; - static ur_platform_handle_t_ Platform; UR_ASSERT(phPlatforms || pNumPlatforms, UR_RESULT_ERROR_INVALID_VALUE); UR_ASSERT(!phPlatforms || NumEntries > 0, UR_RESULT_ERROR_INVALID_VALUE); @@ -87,18 +86,20 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, // Use the default stream to record base event counter UR_CHECK_ERROR(hipEventRecord(EvBase, 0)); - Platform.Devices.emplace_back( - new ur_device_handle_t_{Device, EvBase, &Platform, i}); + ur::hip::adapter->Platform->Devices.emplace_back( + new ur_device_handle_t_{Device, EvBase, + ur::hip::adapter->Platform.get(), i}); - ScopedDevice Active(Platform.Devices.front().get()); + ScopedDevice Active( + ur::hip::adapter->Platform->Devices.front().get()); } } catch (const std::bad_alloc &) { // Signal out-of-memory situation - Platform.Devices.clear(); + ur::hip::adapter->Platform->Devices.clear(); Err = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; } catch (ur_result_t CatchErr) { // Clear and rethrow to allow retry - Platform.Devices.clear(); + ur::hip::adapter->Platform->Devices.clear(); Err = CatchErr; throw CatchErr; } catch (...) { @@ -113,7 +114,7 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries, } if (phPlatforms != nullptr) { - *phPlatforms = &Platform; + *phPlatforms = ur::hip::adapter->Platform.get(); } return Result; diff --git a/unified-runtime/source/adapters/hip/usm.cpp b/unified-runtime/source/adapters/hip/usm.cpp index 7412d4b1eb8b2..f227ab7e4370f 100644 --- a/unified-runtime/source/adapters/hip/usm.cpp +++ b/unified-runtime/source/adapters/hip/usm.cpp @@ -199,7 +199,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, // hip backend has only one platform containing all devices ur_platform_handle_t platform; - ur_adapter_handle_t AdapterHandle = &adapter; + ur_adapter_handle_t AdapterHandle = ur::hip::adapter.get(); UR_CHECK_ERROR(urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr)); // get the device from the platform From dc2a5c01b9b6d2264b2c12740b49f856963871e3 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Thu, 3 Apr 2025 16:48:45 +0100 Subject: [PATCH 6/8] Change cuda adapter back to using dumb pointers. --- .../source/adapters/cuda/adapter.cpp | 26 +++++++++---------- .../source/adapters/cuda/adapter.hpp | 6 ++--- .../source/adapters/cuda/context.hpp | 9 +++---- .../source/adapters/cuda/device.cpp | 2 +- .../source/adapters/cuda/platform.cpp | 2 +- unified-runtime/source/adapters/cuda/usm.cpp | 2 +- 6 files changed, 22 insertions(+), 25 deletions(-) diff --git a/unified-runtime/source/adapters/cuda/adapter.cpp b/unified-runtime/source/adapters/cuda/adapter.cpp index d0083995542ea..0e42d6108647c 100644 --- a/unified-runtime/source/adapters/cuda/adapter.cpp +++ b/unified-runtime/source/adapters/cuda/adapter.cpp @@ -16,7 +16,7 @@ #include namespace ur::cuda { -std::shared_ptr adapter; +ur_adapter_handle_t adapter; } // namespace ur::cuda class ur_legacy_sink : public logger::Sink { @@ -47,6 +47,14 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr)) { logger.setLegacySink(std::make_unique()); } + + TracingCtx = createCUDATracingContext(); + enableCUDATracing(TracingCtx); +} + +ur_adapter_handle_t_::~ur_adapter_handle_t_() { + disableCUDATracing(TracingCtx); + freeCUDATracingContext(TracingCtx); } UR_APIEXPORT ur_result_t UR_APICALL @@ -54,15 +62,10 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { if (NumEntries > 0 && phAdapters) { static std::once_flag InitFlag; - std::call_once(InitFlag, [=]() { - ur::cuda::adapter = std::make_shared(); - }); - - std::lock_guard Lock{ur::cuda::adapter->Mutex}; - ur::cuda::adapter->TracingCtx = createCUDATracingContext(); - enableCUDATracing(ur::cuda::adapter->TracingCtx); + std::call_once(InitFlag, + [=]() { ur::cuda::adapter = new ur_adapter_handle_t_; }); - *phAdapters = ur::cuda::adapter.get(); + *phAdapters = ur::cuda::adapter; } if (pNumAdapters) { @@ -79,11 +82,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - std::lock_guard Lock{ur::cuda::adapter->Mutex}; if (--ur::cuda::adapter->RefCount == 0) { - disableCUDATracing(ur::cuda::adapter->TracingCtx); - freeCUDATracingContext(ur::cuda::adapter->TracingCtx); - ur::cuda::adapter.reset(); + delete ur::cuda::adapter; } return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/cuda/adapter.hpp b/unified-runtime/source/adapters/cuda/adapter.hpp index 4fef3ab99cb2a..873fd7fc99bca 100644 --- a/unified-runtime/source/adapters/cuda/adapter.hpp +++ b/unified-runtime/source/adapters/cuda/adapter.hpp @@ -18,20 +18,20 @@ #include #include -#include struct ur_adapter_handle_t_ { std::atomic RefCount = 1; - std::mutex Mutex; struct cuda_tracing_context_t_ *TracingCtx = nullptr; logger::Logger &logger; std::unique_ptr Platform; ur_adapter_handle_t_(); + ~ur_adapter_handle_t_(); + ur_adapter_handle_t_(const ur_adapter_handle_t_ &) = delete; }; // Keep the global namespace'd namespace ur::cuda { -extern std::shared_ptr adapter; +extern ur_adapter_handle_t adapter; } // namespace ur::cuda #endif // UR_CUDA_ADAPTER_HPP_INCLUDED diff --git a/unified-runtime/source/adapters/cuda/context.hpp b/unified-runtime/source/adapters/cuda/context.hpp index c35a57353e518..d22b2b5442201 100644 --- a/unified-runtime/source/adapters/cuda/context.hpp +++ b/unified-runtime/source/adapters/cuda/context.hpp @@ -119,10 +119,6 @@ struct ur_context_handle_t_ { void operator()() { Function(UserData); } }; - // Retain an additional reference to the adapter as it keeps the devices - // alive, which end up being used (indirectly) by our destructor. - std::shared_ptr Adapter; - std::vector Devices; std::atomic_uint32_t RefCount; @@ -132,13 +128,13 @@ struct ur_context_handle_t_ { umf_memory_pool_handle_t MemoryPoolHost = nullptr; ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices) - : Adapter(ur::cuda::adapter), Devices{Devs, Devs + NumDevices}, - RefCount{1} { + : Devices{Devs, Devs + NumDevices}, RefCount{1} { // Create UMF CUDA memory provider for the host memory // (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because // it is guaranteed to exist). UR_CHECK_ERROR(CreateHostMemoryProviderPool(Devices[0], &MemoryProviderHost, &MemoryPoolHost)); + UR_CHECK_ERROR(urAdapterRetain(ur::cuda::adapter)); }; ~ur_context_handle_t_() { @@ -148,6 +144,7 @@ struct ur_context_handle_t_ { if (MemoryProviderHost) { umfMemoryProviderDestroy(MemoryProviderHost); } + UR_CHECK_ERROR(urAdapterRelease(ur::cuda::adapter)); } void invokeExtendedDeleters() { diff --git a/unified-runtime/source/adapters/cuda/device.cpp b/unified-runtime/source/adapters/cuda/device.cpp index cb49e945fca07..95009fe92c01c 100644 --- a/unified-runtime/source/adapters/cuda/device.cpp +++ b/unified-runtime/source/adapters/cuda/device.cpp @@ -1282,7 +1282,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( // Get list of platforms uint32_t NumPlatforms = 0; - ur_adapter_handle_t AdapterHandle = ur::cuda::adapter.get(); + ur_adapter_handle_t AdapterHandle = ur::cuda::adapter; ur_result_t Result = urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms); if (Result != UR_RESULT_SUCCESS) diff --git a/unified-runtime/source/adapters/cuda/platform.cpp b/unified-runtime/source/adapters/cuda/platform.cpp index 6a9a2a691ab37..c2797b34bbc2d 100644 --- a/unified-runtime/source/adapters/cuda/platform.cpp +++ b/unified-runtime/source/adapters/cuda/platform.cpp @@ -100,7 +100,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo( return ReturnValue(UR_PLATFORM_BACKEND_CUDA); } case UR_PLATFORM_INFO_ADAPTER: { - return ReturnValue(ur::cuda::adapter.get()); + return ReturnValue(ur::cuda::adapter); } default: return UR_RESULT_ERROR_INVALID_ENUMERATION; diff --git a/unified-runtime/source/adapters/cuda/usm.cpp b/unified-runtime/source/adapters/cuda/usm.cpp index d47d790abed5e..f07c4d2a8f532 100644 --- a/unified-runtime/source/adapters/cuda/usm.cpp +++ b/unified-runtime/source/adapters/cuda/usm.cpp @@ -242,7 +242,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, // cuda backend has only one platform containing all devices ur_platform_handle_t platform; - ur_adapter_handle_t AdapterHandle = ur::cuda::adapter.get(); + ur_adapter_handle_t AdapterHandle = ur::cuda::adapter; Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr); // get the device from the platform From fe337f8da984a2264a60d0e57a66c9517fb37a2e Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Fri, 4 Apr 2025 11:08:55 +0100 Subject: [PATCH 7/8] Make Hip also use a dumb pointer. --- unified-runtime/source/adapters/hip/adapter.cpp | 16 ++++++++-------- unified-runtime/source/adapters/hip/adapter.hpp | 2 +- unified-runtime/source/adapters/hip/context.hpp | 12 ++++++++---- unified-runtime/source/adapters/hip/device.cpp | 2 +- unified-runtime/source/adapters/hip/platform.cpp | 2 +- unified-runtime/source/adapters/hip/usm.cpp | 2 +- 6 files changed, 20 insertions(+), 16 deletions(-) diff --git a/unified-runtime/source/adapters/hip/adapter.cpp b/unified-runtime/source/adapters/hip/adapter.cpp index 7fb3acf411ad9..decacbe6e8501 100644 --- a/unified-runtime/source/adapters/hip/adapter.cpp +++ b/unified-runtime/source/adapters/hip/adapter.cpp @@ -16,7 +16,7 @@ #include namespace ur::hip { -std::shared_ptr adapter; +ur_adapter_handle_t adapter; } class ur_legacy_sink : public logger::Sink { @@ -54,12 +54,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { if (phAdapters) { static std::once_flag InitFlag; - std::call_once(InitFlag, [=]() { - ur::hip::adapter = std::make_shared(); - }); + std::call_once(InitFlag, + [=]() { ur::hip::adapter = new ur_adapter_handle_t_; }); - ur::hip::adapter->RefCount++; - *phAdapters = ur::hip::adapter.get(); + *phAdapters = ur::hip::adapter; } if (pNumAdapters) { *pNumAdapters = 1; @@ -69,8 +67,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - // No state to clean up so we don't need to check for 0 references - ur::hip::adapter->RefCount--; + if (--ur::hip::adapter->RefCount == 0) { + delete ur::hip::adapter; + } + return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/hip/adapter.hpp b/unified-runtime/source/adapters/hip/adapter.hpp index 3a6305506eda6..59090e3b6bc60 100644 --- a/unified-runtime/source/adapters/hip/adapter.hpp +++ b/unified-runtime/source/adapters/hip/adapter.hpp @@ -25,7 +25,7 @@ struct ur_adapter_handle_t_ { }; namespace ur::hip { -extern std::shared_ptr adapter; +extern ur_adapter_handle_t adapter; } #endif // UR_HIP_ADAPTER_HPP_INCLUDED diff --git a/unified-runtime/source/adapters/hip/context.hpp b/unified-runtime/source/adapters/hip/context.hpp index 9207e006ef687..1d2b94562622b 100644 --- a/unified-runtime/source/adapters/hip/context.hpp +++ b/unified-runtime/source/adapters/hip/context.hpp @@ -86,16 +86,20 @@ struct ur_context_handle_t_ { void operator()() { Function(UserData); } }; - std::shared_ptr Adapter; std::vector Devices; std::atomic_uint32_t RefCount; ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices) - : Adapter(ur::hip::adapter), Devices{Devs, Devs + NumDevices}, - RefCount{1} {}; + : Devices{Devs, Devs + NumDevices}, RefCount{1} { + UR_CHECK_ERROR(urAdapterRetain(ur::hip::adapter)); + }; + + ~ur_context_handle_t_() { + UR_CHECK_ERROR(urAdapterRelease(ur::hip::adapter)); + } - ~ur_context_handle_t_() {} + ur_context_handle_t_(const ur_context_handle_t_ &) = delete; void invokeExtendedDeleters() { std::lock_guard Guard(Mutex); diff --git a/unified-runtime/source/adapters/hip/device.cpp b/unified-runtime/source/adapters/hip/device.cpp index 24dc213564e41..bdb92997fde6e 100644 --- a/unified-runtime/source/adapters/hip/device.cpp +++ b/unified-runtime/source/adapters/hip/device.cpp @@ -1179,7 +1179,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( // Get list of platforms uint32_t NumPlatforms = 0; - ur_adapter_handle_t AdapterHandle = ur::hip::adapter.get(); + ur_adapter_handle_t AdapterHandle = ur::hip::adapter; ur_result_t Result = urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms); if (Result != UR_RESULT_SUCCESS) diff --git a/unified-runtime/source/adapters/hip/platform.cpp b/unified-runtime/source/adapters/hip/platform.cpp index 69d96904cf811..a75128c40624b 100644 --- a/unified-runtime/source/adapters/hip/platform.cpp +++ b/unified-runtime/source/adapters/hip/platform.cpp @@ -36,7 +36,7 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName, return ReturnValue(""); } case UR_PLATFORM_INFO_ADAPTER: { - return ReturnValue(ur::hip::adapter.get()); + return ReturnValue(ur::hip::adapter); } default: return UR_RESULT_ERROR_INVALID_ENUMERATION; diff --git a/unified-runtime/source/adapters/hip/usm.cpp b/unified-runtime/source/adapters/hip/usm.cpp index 67afc19ce73e1..ee5cfd259f7cb 100644 --- a/unified-runtime/source/adapters/hip/usm.cpp +++ b/unified-runtime/source/adapters/hip/usm.cpp @@ -199,7 +199,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, // hip backend has only one platform containing all devices ur_platform_handle_t platform; - ur_adapter_handle_t AdapterHandle = ur::hip::adapter.get(); + ur_adapter_handle_t AdapterHandle = ur::hip::adapter; UR_CHECK_ERROR(urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr)); // get the device from the platform From fe3af66c6dc828ef16436109e448c2236387adc5 Mon Sep 17 00:00:00 2001 From: Aaron Greig Date: Wed, 9 Apr 2025 14:57:23 +0100 Subject: [PATCH 8/8] Make urAdapterGet increment RefCount properly. --- unified-runtime/source/adapters/cuda/adapter.cpp | 1 + unified-runtime/source/adapters/cuda/adapter.hpp | 2 +- unified-runtime/source/adapters/hip/adapter.cpp | 1 + unified-runtime/source/adapters/hip/adapter.hpp | 2 +- 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/unified-runtime/source/adapters/cuda/adapter.cpp b/unified-runtime/source/adapters/cuda/adapter.cpp index 0e42d6108647c..4bc622d438323 100644 --- a/unified-runtime/source/adapters/cuda/adapter.cpp +++ b/unified-runtime/source/adapters/cuda/adapter.cpp @@ -65,6 +65,7 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, std::call_once(InitFlag, [=]() { ur::cuda::adapter = new ur_adapter_handle_t_; }); + ur::cuda::adapter->RefCount++; *phAdapters = ur::cuda::adapter; } diff --git a/unified-runtime/source/adapters/cuda/adapter.hpp b/unified-runtime/source/adapters/cuda/adapter.hpp index 873fd7fc99bca..2f6ff157ae9d1 100644 --- a/unified-runtime/source/adapters/cuda/adapter.hpp +++ b/unified-runtime/source/adapters/cuda/adapter.hpp @@ -20,7 +20,7 @@ #include struct ur_adapter_handle_t_ { - std::atomic RefCount = 1; + std::atomic RefCount = 0; struct cuda_tracing_context_t_ *TracingCtx = nullptr; logger::Logger &logger; std::unique_ptr Platform; diff --git a/unified-runtime/source/adapters/hip/adapter.cpp b/unified-runtime/source/adapters/hip/adapter.cpp index decacbe6e8501..414fed4734f6b 100644 --- a/unified-runtime/source/adapters/hip/adapter.cpp +++ b/unified-runtime/source/adapters/hip/adapter.cpp @@ -57,6 +57,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( std::call_once(InitFlag, [=]() { ur::hip::adapter = new ur_adapter_handle_t_; }); + ur::hip::adapter->RefCount++; *phAdapters = ur::hip::adapter; } if (pNumAdapters) { diff --git a/unified-runtime/source/adapters/hip/adapter.hpp b/unified-runtime/source/adapters/hip/adapter.hpp index 59090e3b6bc60..c4a750eee2389 100644 --- a/unified-runtime/source/adapters/hip/adapter.hpp +++ b/unified-runtime/source/adapters/hip/adapter.hpp @@ -18,7 +18,7 @@ #include struct ur_adapter_handle_t_ { - std::atomic RefCount = 1; + std::atomic RefCount = 0; logger::Logger &logger; std::unique_ptr Platform; ur_adapter_handle_t_();