Skip to content

Commit 23d08af

Browse files
authored
[Offload][NFC] use unique ptrs for platforms (#160888)
Currently, devices store a raw pointer to back to their owning Platform. Platforms are stored directly inside of a vector. Modifying this vector risks invalidating all the platform pointers stored in devices. This patch allocates platforms individually, and changes devices to store a reference to its platform instead of a pointer. This is safe, because platforms are guaranteed to outlive the devices they contain.
1 parent 9d33b99 commit 23d08af

File tree

1 file changed

+45
-44
lines changed

1 file changed

+45
-44
lines changed

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 45 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,28 @@ using namespace llvm::omp::target;
3939
using namespace llvm::omp::target::plugin;
4040
using namespace error;
4141

42+
struct ol_platform_impl_t {
43+
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
44+
ol_platform_backend_t BackendType)
45+
: Plugin(std::move(Plugin)), BackendType(BackendType) {}
46+
std::unique_ptr<GenericPluginTy> Plugin;
47+
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
48+
ol_platform_backend_t BackendType;
49+
50+
/// Complete all pending work for this platform and perform any needed
51+
/// cleanup.
52+
///
53+
/// After calling this function, no liboffload functions should be called with
54+
/// this platform handle.
55+
llvm::Error destroy();
56+
};
57+
4258
// Handle type definitions. Ideally these would be 1:1 with the plugins, but
4359
// we add some additional data here for now to avoid churn in the plugin
4460
// interface.
4561
struct ol_device_impl_t {
4662
ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
47-
ol_platform_handle_t Platform, InfoTreeNode &&DevInfo)
63+
ol_platform_impl_t &Platform, InfoTreeNode &&DevInfo)
4864
: DeviceNum(DeviceNum), Device(Device), Platform(Platform),
4965
Info(std::forward<InfoTreeNode>(DevInfo)) {}
5066

@@ -55,7 +71,7 @@ struct ol_device_impl_t {
5571

5672
int DeviceNum;
5773
GenericDeviceTy *Device;
58-
ol_platform_handle_t Platform;
74+
ol_platform_impl_t &Platform;
5975
InfoTreeNode Info;
6076

6177
llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
@@ -102,31 +118,17 @@ struct ol_device_impl_t {
102118
}
103119
};
104120

105-
struct ol_platform_impl_t {
106-
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
107-
ol_platform_backend_t BackendType)
108-
: Plugin(std::move(Plugin)), BackendType(BackendType) {}
109-
std::unique_ptr<GenericPluginTy> Plugin;
110-
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
111-
ol_platform_backend_t BackendType;
112-
113-
/// Complete all pending work for this platform and perform any needed
114-
/// cleanup.
115-
///
116-
/// After calling this function, no liboffload functions should be called with
117-
/// this platform handle.
118-
llvm::Error destroy() {
119-
llvm::Error Result = Plugin::success();
120-
for (auto &D : Devices)
121-
if (auto Err = D->destroy())
122-
Result = llvm::joinErrors(std::move(Result), std::move(Err));
121+
llvm::Error ol_platform_impl_t::destroy() {
122+
llvm::Error Result = Plugin::success();
123+
for (auto &D : Devices)
124+
if (auto Err = D->destroy())
125+
Result = llvm::joinErrors(std::move(Result), std::move(Err));
123126

124-
if (auto Res = Plugin->deinit())
125-
Result = llvm::joinErrors(std::move(Result), std::move(Res));
127+
if (auto Res = Plugin->deinit())
128+
Result = llvm::joinErrors(std::move(Result), std::move(Res));
126129

127-
return Result;
128-
}
129-
};
130+
return Result;
131+
}
130132

131133
struct ol_queue_impl_t {
132134
ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
@@ -206,12 +208,12 @@ struct OffloadContext {
206208
// Partitioned list of memory base addresses. Each element in this list is a
207209
// key in AllocInfoMap
208210
llvm::SmallVector<void *> AllocBases{};
209-
SmallVector<ol_platform_impl_t, 4> Platforms{};
211+
SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
210212
size_t RefCount;
211213

212214
ol_device_handle_t HostDevice() {
213215
// The host platform is always inserted last
214-
return Platforms.back().Devices[0].get();
216+
return Platforms.back()->Devices[0].get();
215217
}
216218

217219
static OffloadContext &get() {
@@ -251,35 +253,34 @@ Error initPlugins(OffloadContext &Context) {
251253
#define PLUGIN_TARGET(Name) \
252254
do { \
253255
if (StringRef(#Name) != "host") \
254-
Context.Platforms.emplace_back(ol_platform_impl_t{ \
256+
Context.Platforms.emplace_back(std::make_unique<ol_platform_impl_t>( \
255257
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
256-
pluginNameToBackend(#Name)}); \
258+
pluginNameToBackend(#Name))); \
257259
} while (false);
258260
#include "Shared/Targets.def"
259261

260262
// Preemptively initialize all devices in the plugin
261263
for (auto &Platform : Context.Platforms) {
262-
auto Err = Platform.Plugin->init();
264+
auto Err = Platform->Plugin->init();
263265
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
264-
for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices();
266+
for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices();
265267
DevNum++) {
266-
if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
267-
auto Device = &Platform.Plugin->getDevice(DevNum);
268+
if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
269+
auto Device = &Platform->Plugin->getDevice(DevNum);
268270
auto Info = Device->obtainInfoImpl();
269271
if (auto Err = Info.takeError())
270272
return Err;
271-
Platform.Devices.emplace_back(std::make_unique<ol_device_impl_t>(
272-
DevNum, Device, &Platform, std::move(*Info)));
273+
Platform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
274+
DevNum, Device, *Platform, std::move(*Info)));
273275
}
274276
}
275277
}
276278

277279
// Add the special host device
278280
auto &HostPlatform = Context.Platforms.emplace_back(
279-
ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
280-
HostPlatform.Devices.emplace_back(
281-
std::make_unique<ol_device_impl_t>(-1, nullptr, nullptr, InfoTreeNode{}));
282-
Context.HostDevice()->Platform = &HostPlatform;
281+
std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST));
282+
HostPlatform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
283+
-1, nullptr, *HostPlatform, InfoTreeNode{}));
283284

284285
Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
285286
Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
@@ -316,10 +317,10 @@ Error olShutDown_impl() {
316317

317318
for (auto &P : OldContext->Platforms) {
318319
// Host plugin is nullptr and has no deinit
319-
if (!P.Plugin || !P.Plugin->is_initialized())
320+
if (!P->Plugin || !P->Plugin->is_initialized())
320321
continue;
321322

322-
if (auto Res = P.destroy())
323+
if (auto Res = P->destroy())
323324
Result = llvm::joinErrors(std::move(Result), std::move(Res));
324325
}
325326

@@ -384,7 +385,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
384385
// These are not implemented by the plugin interface
385386
switch (PropName) {
386387
case OL_DEVICE_INFO_PLATFORM:
387-
return Info.write<void *>(Device->Platform);
388+
return Info.write<void *>(&Device->Platform);
388389

389390
case OL_DEVICE_INFO_TYPE:
390391
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
@@ -517,7 +518,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
517518

518519
switch (PropName) {
519520
case OL_DEVICE_INFO_PLATFORM:
520-
return Info.write<void *>(Device->Platform);
521+
return Info.write<void *>(&Device->Platform);
521522
case OL_DEVICE_INFO_TYPE:
522523
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
523524
case OL_DEVICE_INFO_NAME:
@@ -595,7 +596,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
595596

596597
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
597598
for (auto &Platform : OffloadContext::get().Platforms) {
598-
for (auto &Device : Platform.Devices) {
599+
for (auto &Device : Platform->Devices) {
599600
if (!Callback(Device.get(), UserData)) {
600601
break;
601602
}

0 commit comments

Comments
 (0)