Skip to content

Commit dc2a5c0

Browse files
committed
Change cuda adapter back to using dumb pointers.
1 parent 95b8620 commit dc2a5c0

File tree

6 files changed

+22
-25
lines changed

6 files changed

+22
-25
lines changed

unified-runtime/source/adapters/cuda/adapter.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include <memory>
1717

1818
namespace ur::cuda {
19-
std::shared_ptr<ur_adapter_handle_t_> adapter;
19+
ur_adapter_handle_t adapter;
2020
} // namespace ur::cuda
2121

2222
class ur_legacy_sink : public logger::Sink {
@@ -47,22 +47,25 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
4747
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr)) {
4848
logger.setLegacySink(std::make_unique<ur_legacy_sink>());
4949
}
50+
51+
TracingCtx = createCUDATracingContext();
52+
enableCUDATracing(TracingCtx);
53+
}
54+
55+
ur_adapter_handle_t_::~ur_adapter_handle_t_() {
56+
disableCUDATracing(TracingCtx);
57+
freeCUDATracingContext(TracingCtx);
5058
}
5159

5260
UR_APIEXPORT ur_result_t UR_APICALL
5361
urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
5462
uint32_t *pNumAdapters) {
5563
if (NumEntries > 0 && phAdapters) {
5664
static std::once_flag InitFlag;
57-
std::call_once(InitFlag, [=]() {
58-
ur::cuda::adapter = std::make_shared<ur_adapter_handle_t_>();
59-
});
60-
61-
std::lock_guard<std::mutex> Lock{ur::cuda::adapter->Mutex};
62-
ur::cuda::adapter->TracingCtx = createCUDATracingContext();
63-
enableCUDATracing(ur::cuda::adapter->TracingCtx);
65+
std::call_once(InitFlag,
66+
[=]() { ur::cuda::adapter = new ur_adapter_handle_t_; });
6467

65-
*phAdapters = ur::cuda::adapter.get();
68+
*phAdapters = ur::cuda::adapter;
6669
}
6770

6871
if (pNumAdapters) {
@@ -79,11 +82,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
7982
}
8083

8184
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
82-
std::lock_guard<std::mutex> Lock{ur::cuda::adapter->Mutex};
8385
if (--ur::cuda::adapter->RefCount == 0) {
84-
disableCUDATracing(ur::cuda::adapter->TracingCtx);
85-
freeCUDATracingContext(ur::cuda::adapter->TracingCtx);
86-
ur::cuda::adapter.reset();
86+
delete ur::cuda::adapter;
8787
}
8888
return UR_RESULT_SUCCESS;
8989
}

unified-runtime/source/adapters/cuda/adapter.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,20 @@
1818

1919
#include <atomic>
2020
#include <memory>
21-
#include <mutex>
2221

2322
struct ur_adapter_handle_t_ {
2423
std::atomic<uint32_t> RefCount = 1;
25-
std::mutex Mutex;
2624
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
2725
logger::Logger &logger;
2826
std::unique_ptr<ur_platform_handle_t_> Platform;
2927
ur_adapter_handle_t_();
28+
~ur_adapter_handle_t_();
29+
ur_adapter_handle_t_(const ur_adapter_handle_t_ &) = delete;
3030
};
3131

3232
// Keep the global namespace'd
3333
namespace ur::cuda {
34-
extern std::shared_ptr<ur_adapter_handle_t_> adapter;
34+
extern ur_adapter_handle_t adapter;
3535
} // namespace ur::cuda
3636

3737
#endif // UR_CUDA_ADAPTER_HPP_INCLUDED

unified-runtime/source/adapters/cuda/context.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,6 @@ struct ur_context_handle_t_ {
119119
void operator()() { Function(UserData); }
120120
};
121121

122-
// Retain an additional reference to the adapter as it keeps the devices
123-
// alive, which end up being used (indirectly) by our destructor.
124-
std::shared_ptr<ur_adapter_handle_t_> Adapter;
125-
126122
std::vector<ur_device_handle_t> Devices;
127123
std::atomic_uint32_t RefCount;
128124

@@ -132,13 +128,13 @@ struct ur_context_handle_t_ {
132128
umf_memory_pool_handle_t MemoryPoolHost = nullptr;
133129

134130
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
135-
: Adapter(ur::cuda::adapter), Devices{Devs, Devs + NumDevices},
136-
RefCount{1} {
131+
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
137132
// Create UMF CUDA memory provider for the host memory
138133
// (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because
139134
// it is guaranteed to exist).
140135
UR_CHECK_ERROR(CreateHostMemoryProviderPool(Devices[0], &MemoryProviderHost,
141136
&MemoryPoolHost));
137+
UR_CHECK_ERROR(urAdapterRetain(ur::cuda::adapter));
142138
};
143139

144140
~ur_context_handle_t_() {
@@ -148,6 +144,7 @@ struct ur_context_handle_t_ {
148144
if (MemoryProviderHost) {
149145
umfMemoryProviderDestroy(MemoryProviderHost);
150146
}
147+
UR_CHECK_ERROR(urAdapterRelease(ur::cuda::adapter));
151148
}
152149

153150
void invokeExtendedDeleters() {

unified-runtime/source/adapters/cuda/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1282,7 +1282,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
12821282

12831283
// Get list of platforms
12841284
uint32_t NumPlatforms = 0;
1285-
ur_adapter_handle_t AdapterHandle = ur::cuda::adapter.get();
1285+
ur_adapter_handle_t AdapterHandle = ur::cuda::adapter;
12861286
ur_result_t Result =
12871287
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
12881288
if (Result != UR_RESULT_SUCCESS)

unified-runtime/source/adapters/cuda/platform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
100100
return ReturnValue(UR_PLATFORM_BACKEND_CUDA);
101101
}
102102
case UR_PLATFORM_INFO_ADAPTER: {
103-
return ReturnValue(ur::cuda::adapter.get());
103+
return ReturnValue(ur::cuda::adapter);
104104
}
105105
default:
106106
return UR_RESULT_ERROR_INVALID_ENUMERATION;

unified-runtime/source/adapters/cuda/usm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
242242

243243
// cuda backend has only one platform containing all devices
244244
ur_platform_handle_t platform;
245-
ur_adapter_handle_t AdapterHandle = ur::cuda::adapter.get();
245+
ur_adapter_handle_t AdapterHandle = ur::cuda::adapter;
246246
Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr);
247247

248248
// get the device from the platform

0 commit comments

Comments
 (0)