diff --git a/sycl/test-e2e/Regression/static-buffer-dtor.cpp b/sycl/test-e2e/Regression/static-buffer-dtor.cpp index 8cc1ddce43501..474c2b756d3fe 100644 --- a/sycl/test-e2e/Regression/static-buffer-dtor.cpp +++ b/sycl/test-e2e/Regression/static-buffer-dtor.cpp @@ -21,9 +21,6 @@ // UNSUPPORTED: windows && arch-intel_gpu_bmg_g21 // UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/17255 -// UNSUPPORTED: cuda -// UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/17450 - #include int main() { diff --git a/unified-runtime/source/adapters/cuda/adapter.cpp b/unified-runtime/source/adapters/cuda/adapter.cpp index 3ea896bbd6cc3..4bc622d438323 100644 --- a/unified-runtime/source/adapters/cuda/adapter.cpp +++ b/unified-runtime/source/adapters/cuda/adapter.cpp @@ -8,19 +8,16 @@ // //===----------------------------------------------------------------------===// -#include - +#include "adapter.hpp" #include "common.hpp" -#include "logger/ur_logger.hpp" +#include "platform.hpp" #include "tracing.hpp" -struct ur_adapter_handle_t_ { - std::atomic RefCount = 0; - std::mutex Mutex; - struct cuda_tracing_context_t_ *TracingCtx = nullptr; - logger::Logger &logger; - ur_adapter_handle_t_(); -}; +#include + +namespace ur::cuda { +ur_adapter_handle_t adapter; +} // namespace ur::cuda class ur_legacy_sink : public logger::Sink { public: @@ -43,28 +40,33 @@ class ur_legacy_sink : public logger::Sink { ur_adapter_handle_t_::ur_adapter_handle_t_() : logger(logger::get_logger("cuda", /*default_log_level*/ logger::Level::ERR)) { + Platform = std::make_unique(); - if (std::getenv("UR_LOG_CUDA") != nullptr) - return; - - if (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") != nullptr || - std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr) { + if (std::getenv("UR_LOG_CUDA") == nullptr && + (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") != nullptr || + std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr)) { logger.setLegacySink(std::make_unique()); } + + TracingCtx = createCUDATracingContext(); + enableCUDATracing(TracingCtx); +} + +ur_adapter_handle_t_::~ur_adapter_handle_t_() { + disableCUDATracing(TracingCtx); + freeCUDATracingContext(TracingCtx); } -ur_adapter_handle_t_ adapter{}; UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { if (NumEntries > 0 && phAdapters) { - std::lock_guard Lock{adapter.Mutex}; - if (adapter.RefCount++ == 0) { - adapter.TracingCtx = createCUDATracingContext(); - enableCUDATracing(adapter.TracingCtx); - } + static std::once_flag InitFlag; + std::call_once(InitFlag, + [=]() { ur::cuda::adapter = new ur_adapter_handle_t_; }); - *phAdapters = &adapter; + ur::cuda::adapter->RefCount++; + *phAdapters = ur::cuda::adapter; } if (pNumAdapters) { @@ -75,17 +77,14 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - adapter.RefCount++; + ur::cuda::adapter->RefCount++; return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - std::lock_guard Lock{adapter.Mutex}; - if (--adapter.RefCount == 0) { - disableCUDATracing(adapter.TracingCtx); - freeCUDATracingContext(adapter.TracingCtx); - adapter.TracingCtx = nullptr; + if (--ur::cuda::adapter->RefCount == 0) { + delete ur::cuda::adapter; } return UR_RESULT_SUCCESS; } @@ -108,7 +107,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_ADAPTER_BACKEND_CUDA); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(adapter.RefCount.load()); + return ReturnValue(ur::cuda::adapter->RefCount.load()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/cuda/adapter.hpp b/unified-runtime/source/adapters/cuda/adapter.hpp index ba05b86502347..2f6ff157ae9d1 100644 --- a/unified-runtime/source/adapters/cuda/adapter.hpp +++ b/unified-runtime/source/adapters/cuda/adapter.hpp @@ -8,6 +8,30 @@ // //===----------------------------------------------------------------------===// -struct ur_adapter_handle_t_; +#ifndef UR_CUDA_ADAPTER_HPP_INCLUDED +#define UR_CUDA_ADAPTER_HPP_INCLUDED -extern ur_adapter_handle_t_ adapter; +#include "logger/ur_logger.hpp" +#include "platform.hpp" +#include "tracing.hpp" +#include + +#include +#include + +struct ur_adapter_handle_t_ { + std::atomic RefCount = 0; + struct cuda_tracing_context_t_ *TracingCtx = nullptr; + logger::Logger &logger; + std::unique_ptr Platform; + ur_adapter_handle_t_(); + ~ur_adapter_handle_t_(); + ur_adapter_handle_t_(const ur_adapter_handle_t_ &) = delete; +}; + +// Keep the global namespace'd +namespace ur::cuda { +extern ur_adapter_handle_t adapter; +} // namespace ur::cuda + +#endif // UR_CUDA_ADAPTER_HPP_INCLUDED diff --git a/unified-runtime/source/adapters/cuda/context.cpp b/unified-runtime/source/adapters/cuda/context.cpp index 8d1adc900da50..4e4f8720f7124 100644 --- a/unified-runtime/source/adapters/cuda/context.cpp +++ b/unified-runtime/source/adapters/cuda/context.cpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "context.hpp" +#include "platform.hpp" #include "usm.hpp" #include diff --git a/unified-runtime/source/adapters/cuda/context.hpp b/unified-runtime/source/adapters/cuda/context.hpp index b3b40123d2ea8..d22b2b5442201 100644 --- a/unified-runtime/source/adapters/cuda/context.hpp +++ b/unified-runtime/source/adapters/cuda/context.hpp @@ -10,6 +10,7 @@ #pragma once #include +#include #include #include @@ -17,6 +18,7 @@ #include #include +#include "adapter.hpp" #include "common.hpp" #include "device.hpp" #include "umf_helpers.hpp" @@ -127,15 +129,12 @@ struct ur_context_handle_t_ { ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices) : Devices{Devs, Devs + NumDevices}, RefCount{1} { - for (auto &Dev : Devices) { - urDeviceRetain(Dev); - } - // Create UMF CUDA memory provider for the host memory // (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because // it is guaranteed to exist). UR_CHECK_ERROR(CreateHostMemoryProviderPool(Devices[0], &MemoryProviderHost, &MemoryPoolHost)); + UR_CHECK_ERROR(urAdapterRetain(ur::cuda::adapter)); }; ~ur_context_handle_t_() { @@ -145,9 +144,7 @@ struct ur_context_handle_t_ { if (MemoryProviderHost) { umfMemoryProviderDestroy(MemoryProviderHost); } - for (auto &Dev : Devices) { - urDeviceRelease(Dev); - } + UR_CHECK_ERROR(urAdapterRelease(ur::cuda::adapter)); } void invokeExtendedDeleters() { diff --git a/unified-runtime/source/adapters/cuda/device.cpp b/unified-runtime/source/adapters/cuda/device.cpp index aa1c0206b2d47..f1eb41eedfc26 100644 --- a/unified-runtime/source/adapters/cuda/device.cpp +++ b/unified-runtime/source/adapters/cuda/device.cpp @@ -1284,7 +1284,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( // Get list of platforms uint32_t NumPlatforms = 0; - ur_adapter_handle_t AdapterHandle = &adapter; + ur_adapter_handle_t AdapterHandle = ur::cuda::adapter; ur_result_t Result = urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms); if (Result != UR_RESULT_SUCCESS) return Result; diff --git a/unified-runtime/source/adapters/cuda/memory.hpp b/unified-runtime/source/adapters/cuda/memory.hpp index f0fc14f864796..65f065b1ec3c4 100644 --- a/unified-runtime/source/adapters/cuda/memory.hpp +++ b/unified-runtime/source/adapters/cuda/memory.hpp @@ -393,6 +393,9 @@ struct ur_mem_handle_t_ { urMemRelease(std::get(Mem).Parent); return; } + if (LastQueueWritingToMemObj != nullptr) { + urQueueRelease(LastQueueWritingToMemObj); + } urContextRelease(Context); } diff --git a/unified-runtime/source/adapters/cuda/platform.cpp b/unified-runtime/source/adapters/cuda/platform.cpp index 47aa873ca70f8..54a604fd04d85 100644 --- a/unified-runtime/source/adapters/cuda/platform.cpp +++ b/unified-runtime/source/adapters/cuda/platform.cpp @@ -100,7 +100,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo( return ReturnValue(UR_PLATFORM_BACKEND_CUDA); } case UR_PLATFORM_INFO_ADAPTER: { - return ReturnValue(&adapter); + return ReturnValue(ur::cuda::adapter); } default: return UR_RESULT_ERROR_INVALID_ENUMERATION; @@ -116,11 +116,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo( UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries, ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) { - try { static std::once_flag InitFlag; static uint32_t NumPlatforms = 1; - static ur_platform_handle_t_ Platform; UR_ASSERT(phPlatforms || pNumPlatforms, UR_RESULT_ERROR_INVALID_VALUE); UR_ASSERT(!phPlatforms || NumEntries > 0, UR_RESULT_ERROR_INVALID_SIZE); @@ -151,22 +149,24 @@ urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries, // Use default stream to record base event counter UR_CHECK_ERROR(cuEventRecord(EvBase, 0)); - Platform.Devices.emplace_back( - new ur_device_handle_t_{Device, Context, EvBase, &Platform, + ur::cuda::adapter->Platform->Devices.emplace_back( + new ur_device_handle_t_{Device, Context, EvBase, + ur::cuda::adapter->Platform.get(), static_cast(i)}); } - UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools(&Platform)); + UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools( + ur::cuda::adapter->Platform.get())); } catch (const std::bad_alloc &) { // Signal out-of-memory situation for (int i = 0; i < NumDevices; ++i) { - Platform.Devices.clear(); + ur::cuda::adapter->Platform->Devices.clear(); } Result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; } catch (ur_result_t Err) { // Clear and rethrow to allow retry for (int i = 0; i < NumDevices; ++i) { - Platform.Devices.clear(); + ur::cuda::adapter->Platform->Devices.clear(); } Result = Err; throw Err; @@ -182,7 +182,7 @@ urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries, } if (phPlatforms != nullptr) { - *phPlatforms = &Platform; + *phPlatforms = ur::cuda::adapter->Platform.get(); } return Result; diff --git a/unified-runtime/source/adapters/cuda/platform.hpp b/unified-runtime/source/adapters/cuda/platform.hpp index 5da72057abe12..8ecc19c3e9f61 100644 --- a/unified-runtime/source/adapters/cuda/platform.hpp +++ b/unified-runtime/source/adapters/cuda/platform.hpp @@ -7,11 +7,18 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#pragma once +#ifndef UR_CUDA_PLATFORM_HPP_INCLUDED +#define UR_CUDA_PLATFORM_HPP_INCLUDED + +#include "device.hpp" #include + +#include #include struct ur_platform_handle_t_ { std::vector> Devices; }; + +#endif // UR_CUDA_PLATFORM_HPP_INCLUDED diff --git a/unified-runtime/source/adapters/cuda/usm.cpp b/unified-runtime/source/adapters/cuda/usm.cpp index 0b300ed97221e..2c1b81f0527bf 100644 --- a/unified-runtime/source/adapters/cuda/usm.cpp +++ b/unified-runtime/source/adapters/cuda/usm.cpp @@ -242,7 +242,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, // cuda backend has only one platform containing all devices ur_platform_handle_t platform; - ur_adapter_handle_t AdapterHandle = &adapter; + ur_adapter_handle_t AdapterHandle = ur::cuda::adapter; Result = urPlatformGet(AdapterHandle, 1, &platform, nullptr); // get the device from the platform diff --git a/unified-runtime/source/adapters/hip/adapter.cpp b/unified-runtime/source/adapters/hip/adapter.cpp index 9daaee8a29738..414fed4734f6b 100644 --- a/unified-runtime/source/adapters/hip/adapter.cpp +++ b/unified-runtime/source/adapters/hip/adapter.cpp @@ -10,16 +10,14 @@ #include "adapter.hpp" #include "common.hpp" -#include "logger/ur_logger.hpp" -#include #include -struct ur_adapter_handle_t_ { - std::atomic RefCount = 0; - logger::Logger &logger; - ur_adapter_handle_t_(); -}; +#include + +namespace ur::hip { +ur_adapter_handle_t adapter; +} class ur_legacy_sink : public logger::Sink { public: @@ -42,7 +40,7 @@ class ur_legacy_sink : public logger::Sink { ur_adapter_handle_t_::ur_adapter_handle_t_() : logger( logger::get_logger("hip", /*default_log_level*/ logger::Level::ERR)) { - + Platform = std::make_unique(); if (std::getenv("UR_LOG_HIP") != nullptr) return; @@ -52,13 +50,15 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() } } -ur_adapter_handle_t_ adapter{}; - UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) { if (phAdapters) { - adapter.RefCount++; - *phAdapters = &adapter; + static std::once_flag InitFlag; + std::call_once(InitFlag, + [=]() { ur::hip::adapter = new ur_adapter_handle_t_; }); + + ur::hip::adapter->RefCount++; + *phAdapters = ur::hip::adapter; } if (pNumAdapters) { *pNumAdapters = 1; @@ -68,13 +68,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - // No state to clean up so we don't need to check for 0 references - adapter.RefCount--; + if (--ur::hip::adapter->RefCount == 0) { + delete ur::hip::adapter; + } + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - adapter.RefCount++; + ur::hip::adapter->RefCount++; return UR_RESULT_SUCCESS; } @@ -96,7 +98,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_ADAPTER_BACKEND_HIP); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(adapter.RefCount.load()); + return ReturnValue(ur::hip::adapter->RefCount.load()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/hip/adapter.hpp b/unified-runtime/source/adapters/hip/adapter.hpp index 781d8c693c5de..c4a750eee2389 100644 --- a/unified-runtime/source/adapters/hip/adapter.hpp +++ b/unified-runtime/source/adapters/hip/adapter.hpp @@ -8,6 +8,24 @@ // //===----------------------------------------------------------------------===// -struct ur_adapter_handle_t_; +#ifndef UR_HIP_ADAPTER_HPP_INCLUDED +#define UR_HIP_ADAPTER_HPP_INCLUDED -extern ur_adapter_handle_t_ adapter; +#include "logger/ur_logger.hpp" +#include "platform.hpp" + +#include +#include + +struct ur_adapter_handle_t_ { + std::atomic RefCount = 0; + logger::Logger &logger; + std::unique_ptr Platform; + ur_adapter_handle_t_(); +}; + +namespace ur::hip { +extern ur_adapter_handle_t adapter; +} + +#endif // UR_HIP_ADAPTER_HPP_INCLUDED diff --git a/unified-runtime/source/adapters/hip/context.hpp b/unified-runtime/source/adapters/hip/context.hpp index 5af95753b8e32..1d2b94562622b 100644 --- a/unified-runtime/source/adapters/hip/context.hpp +++ b/unified-runtime/source/adapters/hip/context.hpp @@ -11,6 +11,7 @@ #include +#include "adapter.hpp" #include "common.hpp" #include "device.hpp" #include "platform.hpp" @@ -91,12 +92,14 @@ struct ur_context_handle_t_ { ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices) : Devices{Devs, Devs + NumDevices}, RefCount{1} { - for (auto &Dev : Devices) { - urDeviceRetain(Dev); - } + UR_CHECK_ERROR(urAdapterRetain(ur::hip::adapter)); }; - ~ur_context_handle_t_() {} + ~ur_context_handle_t_() { + UR_CHECK_ERROR(urAdapterRelease(ur::hip::adapter)); + } + + ur_context_handle_t_(const ur_context_handle_t_ &) = delete; void invokeExtendedDeleters() { std::lock_guard Guard(Mutex); diff --git a/unified-runtime/source/adapters/hip/device.cpp b/unified-runtime/source/adapters/hip/device.cpp index d53805d206289..7ad9b242050bb 100644 --- a/unified-runtime/source/adapters/hip/device.cpp +++ b/unified-runtime/source/adapters/hip/device.cpp @@ -1181,7 +1181,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( // Get list of platforms uint32_t NumPlatforms = 0; - ur_adapter_handle_t AdapterHandle = &adapter; + ur_adapter_handle_t AdapterHandle = ur::hip::adapter; ur_result_t Result = urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms); if (Result != UR_RESULT_SUCCESS) return Result; diff --git a/unified-runtime/source/adapters/hip/platform.cpp b/unified-runtime/source/adapters/hip/platform.cpp index 8fc44ec4b3858..64b56ca47fba2 100644 --- a/unified-runtime/source/adapters/hip/platform.cpp +++ b/unified-runtime/source/adapters/hip/platform.cpp @@ -36,7 +36,7 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName, return ReturnValue(""); } case UR_PLATFORM_INFO_ADAPTER: { - return ReturnValue(&adapter); + return ReturnValue(ur::hip::adapter); } default: return UR_RESULT_ERROR_INVALID_ENUMERATION; @@ -56,7 +56,6 @@ urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries, try { static std::once_flag InitFlag; static uint32_t NumPlatforms = 1; - static ur_platform_handle_t_ Platform; UR_ASSERT(phPlatforms || pNumPlatforms, UR_RESULT_ERROR_INVALID_VALUE); UR_ASSERT(!phPlatforms || NumEntries > 0, UR_RESULT_ERROR_INVALID_VALUE); @@ -87,18 +86,20 @@ urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries, // Use the default stream to record base event counter UR_CHECK_ERROR(hipEventRecord(EvBase, 0)); - Platform.Devices.emplace_back( - new ur_device_handle_t_{Device, EvBase, &Platform, i}); + ur::hip::adapter->Platform->Devices.emplace_back( + new ur_device_handle_t_{Device, EvBase, + ur::hip::adapter->Platform.get(), i}); - ScopedDevice Active(Platform.Devices.front().get()); + ScopedDevice Active( + ur::hip::adapter->Platform->Devices.front().get()); } } catch (const std::bad_alloc &) { // Signal out-of-memory situation - Platform.Devices.clear(); + ur::hip::adapter->Platform->Devices.clear(); Err = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; } catch (ur_result_t CatchErr) { // Clear and rethrow to allow retry - Platform.Devices.clear(); + ur::hip::adapter->Platform->Devices.clear(); Err = CatchErr; throw CatchErr; } catch (...) { @@ -113,7 +114,7 @@ urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries, } if (phPlatforms != nullptr) { - *phPlatforms = &Platform; + *phPlatforms = ur::hip::adapter->Platform.get(); } return Result; diff --git a/unified-runtime/source/adapters/hip/usm.cpp b/unified-runtime/source/adapters/hip/usm.cpp index 34a2f6774744f..e926ad426856e 100644 --- a/unified-runtime/source/adapters/hip/usm.cpp +++ b/unified-runtime/source/adapters/hip/usm.cpp @@ -199,7 +199,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem, // hip backend has only one platform containing all devices ur_platform_handle_t platform; - ur_adapter_handle_t AdapterHandle = &adapter; + ur_adapter_handle_t AdapterHandle = ur::hip::adapter; UR_CHECK_ERROR(urPlatformGet(AdapterHandle, 1, &platform, nullptr)); // get the device from the platform