Skip to content

Commit fe337f8

Browse files
committed
Make Hip also use a dumb pointer.
1 parent dc2a5c0 commit fe337f8

File tree

6 files changed

+20
-16
lines changed

6 files changed

+20
-16
lines changed

unified-runtime/source/adapters/hip/adapter.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include <memory>
1717

1818
namespace ur::hip {
19-
std::shared_ptr<ur_adapter_handle_t_> adapter;
19+
ur_adapter_handle_t adapter;
2020
}
2121

2222
class ur_legacy_sink : public logger::Sink {
@@ -54,12 +54,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
5454
uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) {
5555
if (phAdapters) {
5656
static std::once_flag InitFlag;
57-
std::call_once(InitFlag, [=]() {
58-
ur::hip::adapter = std::make_shared<ur_adapter_handle_t_>();
59-
});
57+
std::call_once(InitFlag,
58+
[=]() { ur::hip::adapter = new ur_adapter_handle_t_; });
6059

61-
ur::hip::adapter->RefCount++;
62-
*phAdapters = ur::hip::adapter.get();
60+
*phAdapters = ur::hip::adapter;
6361
}
6462
if (pNumAdapters) {
6563
*pNumAdapters = 1;
@@ -69,8 +67,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
6967
}
7068

7169
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
72-
// No state to clean up so we don't need to check for 0 references
73-
ur::hip::adapter->RefCount--;
70+
if (--ur::hip::adapter->RefCount == 0) {
71+
delete ur::hip::adapter;
72+
}
73+
7474
return UR_RESULT_SUCCESS;
7575
}
7676

unified-runtime/source/adapters/hip/adapter.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ struct ur_adapter_handle_t_ {
2525
};
2626

2727
namespace ur::hip {
28-
extern std::shared_ptr<ur_adapter_handle_t_> adapter;
28+
extern ur_adapter_handle_t adapter;
2929
}
3030

3131
#endif // UR_HIP_ADAPTER_HPP_INCLUDED

unified-runtime/source/adapters/hip/context.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,20 @@ struct ur_context_handle_t_ {
8686
void operator()() { Function(UserData); }
8787
};
8888

89-
std::shared_ptr<ur_adapter_handle_t_> Adapter;
9089
std::vector<ur_device_handle_t> Devices;
9190

9291
std::atomic_uint32_t RefCount;
9392

9493
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
95-
: Adapter(ur::hip::adapter), Devices{Devs, Devs + NumDevices},
96-
RefCount{1} {};
94+
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
95+
UR_CHECK_ERROR(urAdapterRetain(ur::hip::adapter));
96+
};
97+
98+
~ur_context_handle_t_() {
99+
UR_CHECK_ERROR(urAdapterRelease(ur::hip::adapter));
100+
}
97101

98-
~ur_context_handle_t_() {}
102+
ur_context_handle_t_(const ur_context_handle_t_ &) = delete;
99103

100104
void invokeExtendedDeleters() {
101105
std::lock_guard<std::mutex> Guard(Mutex);

unified-runtime/source/adapters/hip/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1179,7 +1179,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
11791179

11801180
// Get list of platforms
11811181
uint32_t NumPlatforms = 0;
1182-
ur_adapter_handle_t AdapterHandle = ur::hip::adapter.get();
1182+
ur_adapter_handle_t AdapterHandle = ur::hip::adapter;
11831183
ur_result_t Result =
11841184
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
11851185
if (Result != UR_RESULT_SUCCESS)

unified-runtime/source/adapters/hip/platform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName,
3636
return ReturnValue("");
3737
}
3838
case UR_PLATFORM_INFO_ADAPTER: {
39-
return ReturnValue(ur::hip::adapter.get());
39+
return ReturnValue(ur::hip::adapter);
4040
}
4141
default:
4242
return UR_RESULT_ERROR_INVALID_ENUMERATION;

unified-runtime/source/adapters/hip/usm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
199199

200200
// hip backend has only one platform containing all devices
201201
ur_platform_handle_t platform;
202-
ur_adapter_handle_t AdapterHandle = ur::hip::adapter.get();
202+
ur_adapter_handle_t AdapterHandle = ur::hip::adapter;
203203
UR_CHECK_ERROR(urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr));
204204

205205
// get the device from the platform

0 commit comments

Comments
 (0)