Skip to content

Commit f3d6211

Browse files
committed
[UR][CUDA] Allow cuda adapter objects to block full adapter teardown.
In the CUDA adapter, the adapter struct itself is currently an extern global defined in adapter.cpp. This means that fully tearing down the adapter is subject to the same destructor ordering as all other static and global variables: first in, last out. This presents a problem because an application can declare a static SYCL object, such as a buffer, right at the top before doing anything else, which results in the SYCL object being destroyed after the CUDA adapter struct. The UR spec doesn't put the onus on users to keep their parent object lifetimes in order, i.e. there is no statement such as "the context you use to create a ur_mem_handle_t must not be released until after the mem handle has been released". It's assumed (by omission rather than explicitly) that adapters will have their objects keep a reference to any parent objects alive for the duration of their own lifetime. This change moves the CUDA adapter struct's ownership into a global shared_ptr, which allows child objects of the adapter to keep their own references to it alive past the point where its initial definition goes out of scope. It also adjusts how some other objects track parent object references so that the destructors correctly cascade back to the top: the mem handle releases its context, which releases its adapter, which releases the platform and devices, etc. Fixes #17450
1 parent 08a74f9 commit f3d6211

File tree

10 files changed

+79
-53
lines changed

10 files changed

+79
-53
lines changed

sycl/test-e2e/Regression/static-buffer-dtor.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@
2121
// UNSUPPORTED: windows && arch-intel_gpu_bmg_g21
2222
// UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/17255
2323

24-
// UNSUPPORTED: cuda
25-
// UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/17450
26-
2724
#include <sycl/detail/core.hpp>
2825

2926
int main() {

unified-runtime/source/adapters/cuda/adapter.cpp

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,16 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11-
#include <ur_api.h>
12-
11+
#include "adapter.hpp"
1312
#include "common.hpp"
14-
#include "logger/ur_logger.hpp"
13+
#include "platform.hpp"
1514
#include "tracing.hpp"
1615

17-
struct ur_adapter_handle_t_ {
18-
std::atomic<uint32_t> RefCount = 0;
19-
std::mutex Mutex;
20-
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
21-
logger::Logger &logger;
22-
ur_adapter_handle_t_();
23-
};
16+
#include <memory>
17+
18+
namespace ur::cuda {
19+
std::shared_ptr<ur_adapter_handle_t_> adapter;
20+
} // namespace ur::cuda
2421

2522
class ur_legacy_sink : public logger::Sink {
2623
public:
@@ -43,28 +40,29 @@ class ur_legacy_sink : public logger::Sink {
4340
ur_adapter_handle_t_::ur_adapter_handle_t_()
4441
: logger(logger::get_logger("cuda",
4542
/*default_log_level*/ logger::Level::ERR)) {
43+
Platform = std::make_unique<ur_platform_handle_t_>();
4644

47-
if (std::getenv("UR_LOG_CUDA") != nullptr)
48-
return;
49-
50-
if (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") != nullptr ||
51-
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr) {
45+
if (std::getenv("UR_LOG_CUDA") == nullptr &&
46+
(std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") != nullptr ||
47+
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr)) {
5248
logger.setLegacySink(std::make_unique<ur_legacy_sink>());
5349
}
5450
}
55-
ur_adapter_handle_t_ adapter{};
5651

5752
UR_APIEXPORT ur_result_t UR_APICALL
5853
urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
5954
uint32_t *pNumAdapters) {
6055
if (NumEntries > 0 && phAdapters) {
61-
std::lock_guard<std::mutex> Lock{adapter.Mutex};
62-
if (adapter.RefCount++ == 0) {
63-
adapter.TracingCtx = createCUDATracingContext();
64-
enableCUDATracing(adapter.TracingCtx);
65-
}
56+
static std::once_flag InitFlag;
57+
std::call_once(InitFlag, [=]() {
58+
ur::cuda::adapter = std::make_shared<ur_adapter_handle_t_>();
59+
});
60+
61+
std::lock_guard<std::mutex> Lock{ur::cuda::adapter->Mutex};
62+
ur::cuda::adapter->TracingCtx = createCUDATracingContext();
63+
enableCUDATracing(ur::cuda::adapter->TracingCtx);
6664

67-
*phAdapters = &adapter;
65+
*phAdapters = ur::cuda::adapter.get();
6866
}
6967

7068
if (pNumAdapters) {
@@ -75,17 +73,17 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
7573
}
7674

7775
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
78-
adapter.RefCount++;
76+
ur::cuda::adapter->RefCount++;
7977

8078
return UR_RESULT_SUCCESS;
8179
}
8280

8381
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
84-
std::lock_guard<std::mutex> Lock{adapter.Mutex};
85-
if (--adapter.RefCount == 0) {
86-
disableCUDATracing(adapter.TracingCtx);
87-
freeCUDATracingContext(adapter.TracingCtx);
88-
adapter.TracingCtx = nullptr;
82+
std::lock_guard<std::mutex> Lock{ur::cuda::adapter->Mutex};
83+
if (--ur::cuda::adapter->RefCount == 0) {
84+
disableCUDATracing(ur::cuda::adapter->TracingCtx);
85+
freeCUDATracingContext(ur::cuda::adapter->TracingCtx);
86+
ur::cuda::adapter.reset();
8987
}
9088
return UR_RESULT_SUCCESS;
9189
}
@@ -108,7 +106,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
108106
case UR_ADAPTER_INFO_BACKEND:
109107
return ReturnValue(UR_ADAPTER_BACKEND_CUDA);
110108
case UR_ADAPTER_INFO_REFERENCE_COUNT:
111-
return ReturnValue(adapter.RefCount.load());
109+
return ReturnValue(ur::cuda::adapter->RefCount.load());
112110
case UR_ADAPTER_INFO_VERSION:
113111
return ReturnValue(uint32_t{1});
114112
default:

unified-runtime/source/adapters/cuda/adapter.hpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,30 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11-
struct ur_adapter_handle_t_;
11+
#include "logger/ur_logger.hpp"
12+
#include "platform.hpp"
13+
#include "tracing.hpp"
14+
#include <ur_api.h>
1215

13-
extern ur_adapter_handle_t_ adapter;
16+
#include <atomic>
17+
#include <memory>
18+
#include <mutex>
19+
20+
// should maybe be an ifdef
21+
#pragma once
22+
23+
struct ur_platform_handle_t_;
24+
25+
struct ur_adapter_handle_t_ {
26+
std::atomic<uint32_t> RefCount = 0;
27+
std::mutex Mutex;
28+
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
29+
logger::Logger &logger;
30+
std::unique_ptr<ur_platform_handle_t_> Platform;
31+
ur_adapter_handle_t_();
32+
};
33+
34+
// Keep the global namespace'd
35+
namespace ur::cuda {
36+
extern std::shared_ptr<ur_adapter_handle_t_> adapter;
37+
} // namespace ur::cuda

unified-runtime/source/adapters/cuda/context.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "context.hpp"
12+
#include "platform.hpp"
1213
#include "usm.hpp"
1314

1415
#include <cassert>

unified-runtime/source/adapters/cuda/context.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
#pragma once
1111

1212
#include <cuda.h>
13+
#include <memory>
1314
#include <ur_api.h>
1415

1516
#include <atomic>
1617
#include <mutex>
1718
#include <set>
1819
#include <vector>
1920

21+
#include "adapter.hpp"
2022
#include "common.hpp"
2123
#include "device.hpp"
2224
#include "umf_helpers.hpp"
@@ -117,6 +119,10 @@ struct ur_context_handle_t_ {
117119
void operator()() { Function(UserData); }
118120
};
119121

122+
// Retain an additional reference to the adapter as it keeps the devices
123+
// alive, which end up being used (indirectly) by our destructor.
124+
std::shared_ptr<ur_adapter_handle_t_> Adapter;
125+
120126
std::vector<ur_device_handle_t> Devices;
121127
std::atomic_uint32_t RefCount;
122128

@@ -126,11 +132,8 @@ struct ur_context_handle_t_ {
126132
umf_memory_pool_handle_t MemoryPoolHost = nullptr;
127133

128134
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
129-
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
130-
for (auto &Dev : Devices) {
131-
urDeviceRetain(Dev);
132-
}
133-
135+
: Adapter(ur::cuda::adapter), Devices{Devs, Devs + NumDevices},
136+
RefCount{1} {
134137
// Create UMF CUDA memory provider for the host memory
135138
// (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because
136139
// it is guaranteed to exist).
@@ -145,9 +148,6 @@ struct ur_context_handle_t_ {
145148
if (MemoryProviderHost) {
146149
umfMemoryProviderDestroy(MemoryProviderHost);
147150
}
148-
for (auto &Dev : Devices) {
149-
urDeviceRelease(Dev);
150-
}
151151
}
152152

153153
void invokeExtendedDeleters() {

unified-runtime/source/adapters/cuda/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1216,7 +1216,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
12161216

12171217
// Get list of platforms
12181218
uint32_t NumPlatforms = 0;
1219-
ur_adapter_handle_t AdapterHandle = &adapter;
1219+
ur_adapter_handle_t AdapterHandle = ur::cuda::adapter.get();
12201220
ur_result_t Result =
12211221
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
12221222
if (Result != UR_RESULT_SUCCESS)

unified-runtime/source/adapters/cuda/memory.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,9 @@ struct ur_mem_handle_t_ {
393393
urMemRelease(std::get<BufferMem>(Mem).Parent);
394394
return;
395395
}
396+
if (LastQueueWritingToMemObj != nullptr) {
397+
urQueueRelease(LastQueueWritingToMemObj);
398+
}
396399
urContextRelease(Context);
397400
}
398401

unified-runtime/source/adapters/cuda/platform.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
100100
return ReturnValue(UR_PLATFORM_BACKEND_CUDA);
101101
}
102102
case UR_PLATFORM_INFO_ADAPTER: {
103-
return ReturnValue(&adapter);
103+
return ReturnValue(&ur::cuda::adapter);
104104
}
105105
default:
106106
return UR_RESULT_ERROR_INVALID_ENUMERATION;
@@ -116,11 +116,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
116116
UR_APIEXPORT ur_result_t UR_APICALL
117117
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
118118
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {
119-
120119
try {
121120
static std::once_flag InitFlag;
122121
static uint32_t NumPlatforms = 1;
123-
static ur_platform_handle_t_ Platform;
124122

125123
UR_ASSERT(phPlatforms || pNumPlatforms, UR_RESULT_ERROR_INVALID_VALUE);
126124
UR_ASSERT(!phPlatforms || NumEntries > 0, UR_RESULT_ERROR_INVALID_SIZE);
@@ -151,22 +149,24 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
151149
// Use default stream to record base event counter
152150
UR_CHECK_ERROR(cuEventRecord(EvBase, 0));
153151

154-
Platform.Devices.emplace_back(
155-
new ur_device_handle_t_{Device, Context, EvBase, &Platform,
152+
ur::cuda::adapter->Platform->Devices.emplace_back(
153+
new ur_device_handle_t_{Device, Context, EvBase,
154+
ur::cuda::adapter->Platform.get(),
156155
static_cast<uint32_t>(i)});
157156
}
158157

159-
UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools(&Platform));
158+
UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools(
159+
ur::cuda::adapter->Platform.get()));
160160
} catch (const std::bad_alloc &) {
161161
// Signal out-of-memory situation
162162
for (int i = 0; i < NumDevices; ++i) {
163-
Platform.Devices.clear();
163+
ur::cuda::adapter->Platform->Devices.clear();
164164
}
165165
Result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
166166
} catch (ur_result_t Err) {
167167
// Clear and rethrow to allow retry
168168
for (int i = 0; i < NumDevices; ++i) {
169-
Platform.Devices.clear();
169+
ur::cuda::adapter->Platform->Devices.clear();
170170
}
171171
Result = Err;
172172
throw Err;
@@ -182,7 +182,7 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
182182
}
183183

184184
if (phPlatforms != nullptr) {
185-
*phPlatforms = &Platform;
185+
*phPlatforms = ur::cuda::adapter->Platform.get();
186186
}
187187

188188
return Result;

unified-runtime/source/adapters/cuda/platform.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
//===----------------------------------------------------------------------===//
1010
#pragma once
1111

12+
#include "device.hpp"
1213
#include <ur/ur.hpp>
14+
15+
#include <memory>
1316
#include <vector>
1417

1518
struct ur_platform_handle_t_ {

unified-runtime/source/adapters/cuda/usm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
242242

243243
// cuda backend has only one platform containing all devices
244244
ur_platform_handle_t platform;
245-
ur_adapter_handle_t AdapterHandle = &adapter;
245+
ur_adapter_handle_t AdapterHandle = ur::cuda::adapter.get();
246246
Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr);
247247

248248
// get the device from the platform

0 commit comments

Comments
 (0)