Skip to content

Commit 9e5d8bd

Browse files
authored
[Offload] Improve olDestroyQueue logic (llvm#153041)
Previously, `olDestroyQueue` would not actually destroy the queue, instead leaving it for the device to clean up when it was destroyed. Now, the queue is either released immediately if it is complete or put into a list of "pending" queues if it is not. Whenever we create a new queue, we check this list to see if any are now completed. If there are any, we release their resources and use them instead of pulling from the pool. This prevents long running programs that create and drop many queues without syncing them from leaking memory all over the place.
1 parent 4cf9720 commit 9e5d8bd

File tree

3 files changed

+144
-22
lines changed

3 files changed

+144
-22
lines changed

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 126 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,34 +47,111 @@ struct ol_device_impl_t {
4747
ol_platform_handle_t Platform, InfoTreeNode &&DevInfo)
4848
: DeviceNum(DeviceNum), Device(Device), Platform(Platform),
4949
Info(std::forward<InfoTreeNode>(DevInfo)) {}
50+
51+
/// Sanity-checks that the outstanding-queue list is empty when the device
/// object is dropped. The check is an `assert`, so it only fires in builds
/// where assertions are enabled.
~ol_device_impl_t() {
  assert(OutstandingQueues.empty() &&
         "Device object dropped with outstanding queues");
}
55+
5056
int DeviceNum;
5157
GenericDeviceTy *Device;
5258
ol_platform_handle_t Platform;
5359
InfoTreeNode Info;
60+
61+
llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
62+
std::mutex OutstandingQueuesMutex;
63+
64+
/// If the device has any outstanding queues that are now complete, remove it
65+
/// from the list and return it.
66+
///
67+
/// Queues may be added to the outstanding queue list by olDestroyQueue if
68+
/// they are destroyed but not completed.
69+
__tgt_async_info *getOutstandingQueue() {
70+
// Not locking the `size()` access is fine here - In the worst case we
71+
// either miss a queue that exists or loop through an empty array after
72+
// taking the lock. Both are sub-optimal but not that bad.
73+
if (OutstandingQueues.size()) {
74+
std::lock_guard<std::mutex> Lock(OutstandingQueuesMutex);
75+
76+
// As queues are pulled and popped from this list, longer running queues
77+
// naturally bubble to the start of the array. Hence looping backwards.
78+
for (auto Q = OutstandingQueues.rbegin(); Q != OutstandingQueues.rend();
79+
Q++) {
80+
if (!Device->hasPendingWork(*Q)) {
81+
auto OutstandingQueue = *Q;
82+
*Q = OutstandingQueues.back();
83+
OutstandingQueues.pop_back();
84+
return OutstandingQueue;
85+
}
86+
}
87+
}
88+
return nullptr;
89+
}
90+
91+
/// Complete all pending work for this device and perform any needed cleanup.
92+
///
93+
/// After calling this function, no liboffload functions should be called with
94+
/// this device handle.
95+
llvm::Error destroy() {
96+
llvm::Error Result = Plugin::success();
97+
for (auto Q : OutstandingQueues)
98+
if (auto Err = Device->synchronize(Q, /*Release=*/true))
99+
Result = llvm::joinErrors(std::move(Result), std::move(Err));
100+
OutstandingQueues.clear();
101+
return Result;
102+
}
54103
};
55104

56105
struct ol_platform_impl_t {
57106
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
58107
ol_platform_backend_t BackendType)
59108
: Plugin(std::move(Plugin)), BackendType(BackendType) {}
60109
std::unique_ptr<GenericPluginTy> Plugin;
61-
std::vector<ol_device_impl_t> Devices;
110+
llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
62111
ol_platform_backend_t BackendType;
112+
113+
/// Complete all pending work for this platform and perform any needed
114+
/// cleanup.
115+
///
116+
/// After calling this function, no liboffload functions should be called with
117+
/// this platform handle.
118+
llvm::Error destroy() {
119+
llvm::Error Result = Plugin::success();
120+
for (auto &D : Devices)
121+
if (auto Err = D->destroy())
122+
Result = llvm::joinErrors(std::move(Result), std::move(Err));
123+
124+
if (auto Res = Plugin->deinit())
125+
Result = llvm::joinErrors(std::move(Result), std::move(Res));
126+
127+
return Result;
128+
}
63129
};
64130

65131
struct ol_queue_impl_t {
66132
ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
67-
: AsyncInfo(AsyncInfo), Device(Device) {}
133+
: AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
68134
__tgt_async_info *AsyncInfo;
69135
ol_device_handle_t Device;
136+
// A unique identifier for the queue
137+
size_t Id;
138+
static std::atomic<size_t> IdCounter;
70139
};
140+
std::atomic<size_t> ol_queue_impl_t::IdCounter(0);
71141

72142
struct ol_event_impl_t {
73-
ol_event_impl_t(void *EventInfo, ol_queue_handle_t Queue)
74-
: EventInfo(EventInfo), Queue(Queue) {}
143+
ol_event_impl_t(void *EventInfo, ol_device_handle_t Device,
144+
ol_queue_handle_t Queue)
145+
: EventInfo(EventInfo), Device(Device), QueueId(Queue->Id), Queue(Queue) {
146+
}
75147
// EventInfo may be null, in which case the event should be considered always
76148
// complete
77149
void *EventInfo;
150+
ol_device_handle_t Device;
151+
size_t QueueId;
152+
// Events may outlive the queue - don't assume this is always valid.
153+
// It is provided only to implement OL_EVENT_INFO_QUEUE. Use QueueId to check
154+
// for queue equality instead.
78155
ol_queue_handle_t Queue;
79156
};
80157

@@ -131,7 +208,7 @@ struct OffloadContext {
131208

132209
/// Returns the handle of the special host device: the first device of the
/// last registered platform.
ol_device_handle_t HostDevice() {
  // The host platform is always inserted last
  auto &HostPlatform = Platforms.back();
  return HostPlatform.Devices.front().get();
}
136213

137214
static OffloadContext &get() {
@@ -190,16 +267,17 @@ Error initPlugins(OffloadContext &Context) {
190267
auto Info = Device->obtainInfoImpl();
191268
if (auto Err = Info.takeError())
192269
return Err;
193-
Platform.Devices.emplace_back(DevNum, Device, &Platform,
194-
std::move(*Info));
270+
Platform.Devices.emplace_back(std::make_unique<ol_device_impl_t>(
271+
DevNum, Device, &Platform, std::move(*Info)));
195272
}
196273
}
197274
}
198275

199276
// Add the special host device
200277
auto &HostPlatform = Context.Platforms.emplace_back(
201278
ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
202-
HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{});
279+
HostPlatform.Devices.emplace_back(
280+
std::make_unique<ol_device_impl_t>(-1, nullptr, nullptr, InfoTreeNode{}));
203281
Context.HostDevice()->Platform = &HostPlatform;
204282

205283
Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
@@ -240,7 +318,7 @@ Error olShutDown_impl() {
240318
if (!P.Plugin || !P.Plugin->is_initialized())
241319
continue;
242320

243-
if (auto Res = P.Plugin->deinit())
321+
if (auto Res = P.destroy())
244322
Result = llvm::joinErrors(std::move(Result), std::move(Res));
245323
}
246324

@@ -508,7 +586,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
508586
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
509587
for (auto &Platform : OffloadContext::get().Platforms) {
510588
for (auto &Device : Platform.Devices) {
511-
if (!Callback(&Device, UserData)) {
589+
if (!Callback(Device.get(), UserData)) {
512590
break;
513591
}
514592
}
@@ -569,14 +647,46 @@ Error olMemFree_impl(void *Address) {
569647

570648
/// Creates a queue on `Device` and stores its handle in `*Queue`.
///
/// A completed queue from the device's outstanding list is reused when one is
/// available; otherwise a fresh async-info object is initialized. On error
/// `*Queue` is left untouched and the error is returned.
Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
  auto NewQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);

  if (__tgt_async_info *Recycled = Device->getOutstandingQueue()) {
    // The queue is empty, but we still need to sync it to release any
    // temporary memory allocations or do other cleanup.
    if (auto Err = Device->Device->synchronize(Recycled, /*Release=*/false))
      return Err;
    NewQueue->AsyncInfo = Recycled;
  } else {
    // Nothing to recycle: ask the device for a new async-info object.
    if (auto Err = Device->Device->initAsyncInfo(&NewQueue->AsyncInfo))
      return Err;
  }

  // Ownership passes to the caller through the raw handle.
  *Queue = NewQueue.release();
  return Error::success();
}
578667

579-
Error olDestroyQueue_impl(ol_queue_handle_t Queue) { return olDestroy(Queue); }
668+
/// Destroys a queue handle.
///
/// A queue with no pending work is synchronized with its release flag set
/// right away. A queue that is still busy has its async-info object parked on
/// the owning device's outstanding list so it can be picked up later. If the
/// pending-work query or the synchronize call fails, the error is returned and
/// the handle is not destroyed.
Error olDestroyQueue_impl(ol_queue_handle_t Queue) {
  ol_device_handle_t Owner = Queue->Device;
  __tgt_async_info *AsyncInfo = Queue->AsyncInfo;

  // This is safe; as soon as olDestroyQueue is called it is not possible to
  // add any more work to the queue, so if it's finished now it will remain
  // finished forever.
  auto HasPending = Owner->Device->hasPendingWork(AsyncInfo);
  if (auto Err = HasPending.takeError())
    return Err;

  if (*HasPending) {
    // The queue still has outstanding work. Store it so we can check it
    // later. The lock is held only for the duration of the push.
    std::lock_guard<std::mutex> Guard(Owner->OutstandingQueuesMutex);
    Owner->OutstandingQueues.push_back(AsyncInfo);
  } else if (auto Err =
                 Owner->Device->synchronize(AsyncInfo, /*Release=*/true)) {
    // The queue was complete, but syncing it (which throws it back into the
    // pool) failed.
    return Err;
  }

  return olDestroy(Queue);
}
580690

581691
Error olSyncQueue_impl(ol_queue_handle_t Queue) {
582692
// Host plugin doesn't have a queue set so it's not safe to call synchronize
@@ -604,7 +714,7 @@ Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events,
604714
"olWaitEvents asked to wait on a NULL event");
605715

606716
// Do nothing if the event is for this queue or the event is always complete
607-
if (Event->Queue == Queue || !Event->EventInfo)
717+
if (Event->QueueId == Queue->Id || !Event->EventInfo)
608718
continue;
609719

610720
if (auto Err = Device->waitEvent(Event->EventInfo, Queue->AsyncInfo))
@@ -652,15 +762,15 @@ Error olSyncEvent_impl(ol_event_handle_t Event) {
652762
if (!Event->EventInfo)
653763
return Plugin::success();
654764

655-
if (auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo))
765+
if (auto Res = Event->Device->Device->syncEvent(Event->EventInfo))
656766
return Res;
657767

658768
return Error::success();
659769
}
660770

661771
Error olDestroyEvent_impl(ol_event_handle_t Event) {
662772
if (Event->EventInfo)
663-
if (auto Res = Event->Queue->Device->Device->destroyEvent(Event->EventInfo))
773+
if (auto Res = Event->Device->Device->destroyEvent(Event->EventInfo))
664774
return Res;
665775

666776
return olDestroy(Event);
@@ -711,7 +821,7 @@ Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
711821
if (auto Err = Pending.takeError())
712822
return Err;
713823

714-
*EventOut = new ol_event_impl_t(nullptr, Queue);
824+
*EventOut = new ol_event_impl_t(nullptr, Queue->Device, Queue);
715825
if (!*Pending)
716826
// Queue is empty, don't record an event and consider the event always
717827
// complete

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,16 +1341,19 @@ Error PinnedAllocationMapTy::unlockUnmappedHostBuffer(void *HstPtr) {
13411341

13421342
Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo,
13431343
bool ReleaseQueue) {
1344+
if (!AsyncInfo)
1345+
return Plugin::error(ErrorCode::INVALID_ARGUMENT,
1346+
"invalid async info queue");
1347+
13441348
SmallVector<void *> AllocsToDelete{};
13451349
{
13461350
std::lock_guard<std::mutex> AllocationGuard{AsyncInfo->Mutex};
13471351

1348-
if (!AsyncInfo || !AsyncInfo->Queue)
1349-
return Plugin::error(ErrorCode::INVALID_ARGUMENT,
1350-
"invalid async info queue");
1351-
1352-
if (auto Err = synchronizeImpl(*AsyncInfo, ReleaseQueue))
1353-
return Err;
1352+
// This can be false when no work has been added to the AsyncInfo. In which
1353+
// case, the device has nothing to synchronize.
1354+
if (AsyncInfo->Queue)
1355+
if (auto Err = synchronizeImpl(*AsyncInfo, ReleaseQueue))
1356+
return Err;
13541357

13551358
std::swap(AllocsToDelete, AsyncInfo->AssociatedAllocations);
13561359
}

offload/unittests/OffloadAPI/queue/olDestroyQueue.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ TEST_P(olDestroyQueueTest, Success) {
1818
Queue = nullptr;
1919
}
2020

21+
// Destroying a queue must succeed even when work enqueued on it has not
// finished yet.
TEST_P(olDestroyQueueTest, SuccessDelayedResolution) {
  // Enqueue a task that presumably stays pending until trigger() is called
  // (TODO confirm against ManuallyTriggeredTask).
  ManuallyTriggeredTask Manual;
  ASSERT_SUCCESS(Manual.enqueue(Queue));

  // Destroy the queue while that task is still enqueued, then clear the
  // fixture's handle so it is not used again.
  ASSERT_SUCCESS(olDestroyQueue(Queue));
  Queue = nullptr;

  // Only now trigger the task.
  ASSERT_SUCCESS(Manual.trigger());
}
29+
2130
TEST_P(olDestroyQueueTest, InvalidNullHandle) {
2231
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olDestroyQueue(nullptr));
2332
}

0 commit comments

Comments
 (0)