Skip to content

Commit 322ae88

Browse files
committed
Apply similar fix to hip adapter.
1 parent 61b3bdb commit 322ae88

File tree

6 files changed

+51
-32
lines changed

6 files changed

+51
-32
lines changed

unified-runtime/source/adapters/hip/adapter.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,14 @@
1010

1111
#include "adapter.hpp"
1212
#include "common.hpp"
13-
#include "logger/ur_logger.hpp"
1413

15-
#include <atomic>
1614
#include <ur_api.h>
1715

18-
struct ur_adapter_handle_t_ {
19-
std::atomic<uint32_t> RefCount = 0;
20-
logger::Logger &logger;
21-
ur_adapter_handle_t_();
22-
};
16+
#include <memory>
17+
18+
namespace ur::hip {
19+
std::shared_ptr<ur_adapter_handle_t_> adapter;
20+
}
2321

2422
class ur_legacy_sink : public logger::Sink {
2523
public:
@@ -42,7 +40,7 @@ class ur_legacy_sink : public logger::Sink {
4240
ur_adapter_handle_t_::ur_adapter_handle_t_()
4341
: logger(
4442
logger::get_logger("hip", /*default_log_level*/ logger::Level::ERR)) {
45-
43+
Platform = std::make_unique<ur_platform_handle_t_>();
4644
if (std::getenv("UR_LOG_HIP") != nullptr)
4745
return;
4846

@@ -52,13 +50,16 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
5250
}
5351
}
5452

55-
ur_adapter_handle_t_ adapter{};
56-
5753
UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
5854
uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) {
5955
if (phAdapters) {
60-
adapter.RefCount++;
61-
*phAdapters = &adapter;
56+
static std::once_flag InitFlag;
57+
std::call_once(InitFlag, [=]() {
58+
ur::hip::adapter = std::make_shared<ur_adapter_handle_t_>();
59+
});
60+
61+
ur::hip::adapter->RefCount++;
62+
*phAdapters = ur::hip::adapter.get();
6263
}
6364
if (pNumAdapters) {
6465
*pNumAdapters = 1;
@@ -69,12 +70,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
6970

7071
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
7172
// No state to clean up so we don't need to check for 0 references
72-
adapter.RefCount--;
73+
ur::hip::adapter->RefCount--;
7374
return UR_RESULT_SUCCESS;
7475
}
7576

7677
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
77-
adapter.RefCount++;
78+
ur::hip::adapter->RefCount++;
7879
return UR_RESULT_SUCCESS;
7980
}
8081

@@ -96,7 +97,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
9697
case UR_ADAPTER_INFO_BACKEND:
9798
return ReturnValue(UR_ADAPTER_BACKEND_HIP);
9899
case UR_ADAPTER_INFO_REFERENCE_COUNT:
99-
return ReturnValue(adapter.RefCount.load());
100+
return ReturnValue(ur::hip::adapter->RefCount.load());
100101
case UR_ADAPTER_INFO_VERSION:
101102
return ReturnValue(uint32_t{1});
102103
default:

unified-runtime/source/adapters/hip/adapter.hpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,24 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11-
struct ur_adapter_handle_t_;
11+
#ifndef UR_HIP_ADAPTER_HPP_INCLUDED
12+
#define UR_HIP_ADAPTER_HPP_INCLUDED
1213

13-
extern ur_adapter_handle_t_ adapter;
14+
#include "logger/ur_logger.hpp"
15+
#include "platform.hpp"
16+
17+
#include <atomic>
18+
#include <memory>
19+
20+
struct ur_adapter_handle_t_ {
21+
std::atomic<uint32_t> RefCount = 1;
22+
logger::Logger &logger;
23+
std::unique_ptr<ur_platform_handle_t_> Platform;
24+
ur_adapter_handle_t_();
25+
};
26+
27+
namespace ur::hip {
28+
extern std::shared_ptr<ur_adapter_handle_t_> adapter;
29+
}
30+
31+
#endif // UR_HIP_ADAPTER_HPP_INCLUDED

unified-runtime/source/adapters/hip/context.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include <set>
1313

14+
#include "adapter.hpp"
1415
#include "common.hpp"
1516
#include "device.hpp"
1617
#include "platform.hpp"
@@ -85,16 +86,14 @@ struct ur_context_handle_t_ {
8586
void operator()() { Function(UserData); }
8687
};
8788

89+
std::shared_ptr<ur_adapter_handle_t_> Adapter;
8890
std::vector<ur_device_handle_t> Devices;
8991

9092
std::atomic_uint32_t RefCount;
9193

9294
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
93-
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
94-
for (auto &Dev : Devices) {
95-
urDeviceRetain(Dev);
96-
}
97-
};
95+
: Adapter(ur::hip::adapter), Devices{Devs, Devs + NumDevices},
96+
RefCount{1} {};
9897

9998
~ur_context_handle_t_() {}
10099

unified-runtime/source/adapters/hip/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
11741174

11751175
// Get list of platforms
11761176
uint32_t NumPlatforms = 0;
1177-
ur_adapter_handle_t AdapterHandle = &adapter;
1177+
ur_adapter_handle_t AdapterHandle = ur::hip::adapter.get();
11781178
ur_result_t Result =
11791179
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
11801180
if (Result != UR_RESULT_SUCCESS)

unified-runtime/source/adapters/hip/platform.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName,
3636
return ReturnValue("");
3737
}
3838
case UR_PLATFORM_INFO_ADAPTER: {
39-
return ReturnValue(&adapter);
39+
return ReturnValue(ur::hip::adapter.get());
4040
}
4141
default:
4242
return UR_RESULT_ERROR_INVALID_ENUMERATION;
@@ -56,7 +56,6 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
5656
try {
5757
static std::once_flag InitFlag;
5858
static uint32_t NumPlatforms = 1;
59-
static ur_platform_handle_t_ Platform;
6059

6160
UR_ASSERT(phPlatforms || pNumPlatforms, UR_RESULT_ERROR_INVALID_VALUE);
6261
UR_ASSERT(!phPlatforms || NumEntries > 0, UR_RESULT_ERROR_INVALID_VALUE);
@@ -87,18 +86,20 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
8786

8887
// Use the default stream to record base event counter
8988
UR_CHECK_ERROR(hipEventRecord(EvBase, 0));
90-
Platform.Devices.emplace_back(
91-
new ur_device_handle_t_{Device, EvBase, &Platform, i});
89+
ur::hip::adapter->Platform->Devices.emplace_back(
90+
new ur_device_handle_t_{Device, EvBase,
91+
ur::hip::adapter->Platform.get(), i});
9292

93-
ScopedDevice Active(Platform.Devices.front().get());
93+
ScopedDevice Active(
94+
ur::hip::adapter->Platform->Devices.front().get());
9495
}
9596
} catch (const std::bad_alloc &) {
9697
// Signal out-of-memory situation
97-
Platform.Devices.clear();
98+
ur::hip::adapter->Platform->Devices.clear();
9899
Err = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
99100
} catch (ur_result_t CatchErr) {
100101
// Clear and rethrow to allow retry
101-
Platform.Devices.clear();
102+
ur::hip::adapter->Platform->Devices.clear();
102103
Err = CatchErr;
103104
throw CatchErr;
104105
} catch (...) {
@@ -113,7 +114,7 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
113114
}
114115

115116
if (phPlatforms != nullptr) {
116-
*phPlatforms = &Platform;
117+
*phPlatforms = ur::hip::adapter->Platform.get();
117118
}
118119

119120
return Result;

unified-runtime/source/adapters/hip/usm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
199199

200200
// hip backend has only one platform containing all devices
201201
ur_platform_handle_t platform;
202-
ur_adapter_handle_t AdapterHandle = &adapter;
202+
ur_adapter_handle_t AdapterHandle = ur::hip::adapter.get();
203203
UR_CHECK_ERROR(urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr));
204204

205205
// get the device from the platform

0 commit comments

Comments
 (0)