@@ -47,34 +47,111 @@ struct ol_device_impl_t {
4747 ol_platform_handle_t Platform, InfoTreeNode &&DevInfo)
4848 : DeviceNum(DeviceNum), Device(Device), Platform(Platform),
4949 Info (std::forward<InfoTreeNode>(DevInfo)) {}
50+
51+ ~ol_device_impl_t () {
52+ assert (!OutstandingQueues.size () &&
53+ " Device object dropped with outstanding queues" );
54+ }
55+
5056 int DeviceNum;
5157 GenericDeviceTy *Device;
5258 ol_platform_handle_t Platform;
5359 InfoTreeNode Info;
60+
61+ llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
62+ std::mutex OutstandingQueuesMutex;
63+
64+ // / If the device has any outstanding queues that are now complete, remove it
65+ // / from the list and return it.
66+ // /
67+ // / Queues may be added to the outstanding queue list by olDestroyQueue if
68+ // / they are destroyed but not completed.
69+ __tgt_async_info *getOutstandingQueue () {
70+ // Not locking the `size()` access is fine here - In the worst case we
71+ // either miss a queue that exists or loop through an empty array after
72+ // taking the lock. Both are sub-optimal but not that bad.
73+ if (OutstandingQueues.size ()) {
74+ std::lock_guard<std::mutex> Lock (OutstandingQueuesMutex);
75+
76+ // As queues are pulled and popped from this list, longer running queues
77+ // naturally bubble to the start of the array. Hence looping backwards.
78+ for (auto Q = OutstandingQueues.rbegin (); Q != OutstandingQueues.rend ();
79+ Q++) {
80+ if (!Device->hasPendingWork (*Q)) {
81+ auto OutstandingQueue = *Q;
82+ *Q = OutstandingQueues.back ();
83+ OutstandingQueues.pop_back ();
84+ return OutstandingQueue;
85+ }
86+ }
87+ }
88+ return nullptr ;
89+ }
90+
91+ // / Complete all pending work for this device and perform any needed cleanup.
92+ // /
93+ // / After calling this function, no liboffload functions should be called with
94+ // / this device handle.
95+ llvm::Error destroy () {
96+ llvm::Error Result = Plugin::success ();
97+ for (auto Q : OutstandingQueues)
98+ if (auto Err = Device->synchronize (Q, /* Release=*/ true ))
99+ Result = llvm::joinErrors (std::move (Result), std::move (Err));
100+ OutstandingQueues.clear ();
101+ return Result;
102+ }
54103};
55104
56105struct ol_platform_impl_t {
57106 ol_platform_impl_t (std::unique_ptr<GenericPluginTy> Plugin,
58107 ol_platform_backend_t BackendType)
59108 : Plugin(std::move(Plugin)), BackendType(BackendType) {}
60109 std::unique_ptr<GenericPluginTy> Plugin;
61- std::vector <ol_device_impl_t > Devices;
110+ llvm::SmallVector< std::unique_ptr <ol_device_impl_t > > Devices;
62111 ol_platform_backend_t BackendType;
112+
113+ // / Complete all pending work for this platform and perform any needed
114+ // / cleanup.
115+ // /
116+ // / After calling this function, no liboffload functions should be called with
117+ // / this platform handle.
118+ llvm::Error destroy () {
119+ llvm::Error Result = Plugin::success ();
120+ for (auto &D : Devices)
121+ if (auto Err = D->destroy ())
122+ Result = llvm::joinErrors (std::move (Result), std::move (Err));
123+
124+ if (auto Res = Plugin->deinit ())
125+ Result = llvm::joinErrors (std::move (Result), std::move (Res));
126+
127+ return Result;
128+ }
63129};
64130
65131struct ol_queue_impl_t {
66132 ol_queue_impl_t (__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
67- : AsyncInfo(AsyncInfo), Device(Device) {}
133+ : AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
68134 __tgt_async_info *AsyncInfo;
69135 ol_device_handle_t Device;
136+ // A unique identifier for the queue
137+ size_t Id;
138+ static std::atomic<size_t > IdCounter;
70139};
140+ std::atomic<size_t > ol_queue_impl_t::IdCounter (0 );
71141
72142struct ol_event_impl_t {
73- ol_event_impl_t (void *EventInfo, ol_queue_handle_t Queue)
74- : EventInfo(EventInfo), Queue(Queue) {}
143+ ol_event_impl_t (void *EventInfo, ol_device_handle_t Device,
144+ ol_queue_handle_t Queue)
145+ : EventInfo(EventInfo), Device(Device), QueueId(Queue->Id), Queue(Queue) {
146+ }
75147 // EventInfo may be null, in which case the event should be considered always
76148 // complete
77149 void *EventInfo;
150+ ol_device_handle_t Device;
151+ size_t QueueId;
152+ // Events may outlive the queue - don't assume this is always valid.
153+ // It is provided only to implement OL_EVENT_INFO_QUEUE. Use QueueId to check
154+ // for queue equality instead.
78155 ol_queue_handle_t Queue;
79156};
80157
@@ -131,7 +208,7 @@ struct OffloadContext {
131208
132209 ol_device_handle_t HostDevice () {
133210 // The host platform is always inserted last
134- return & Platforms.back ().Devices [0 ];
211+ return Platforms.back ().Devices [0 ]. get () ;
135212 }
136213
137214 static OffloadContext &get () {
@@ -190,16 +267,17 @@ Error initPlugins(OffloadContext &Context) {
190267 auto Info = Device->obtainInfoImpl ();
191268 if (auto Err = Info.takeError ())
192269 return Err;
193- Platform.Devices .emplace_back (DevNum, Device, &Platform,
194- std::move (*Info));
270+ Platform.Devices .emplace_back (std::make_unique< ol_device_impl_t >(
271+ DevNum, Device, &Platform, std::move (*Info) ));
195272 }
196273 }
197274 }
198275
199276 // Add the special host device
200277 auto &HostPlatform = Context.Platforms .emplace_back (
201278 ol_platform_impl_t {nullptr , OL_PLATFORM_BACKEND_HOST});
202- HostPlatform.Devices .emplace_back (-1 , nullptr , nullptr , InfoTreeNode{});
279+ HostPlatform.Devices .emplace_back (
280+ std::make_unique<ol_device_impl_t >(-1 , nullptr , nullptr , InfoTreeNode{}));
203281 Context.HostDevice ()->Platform = &HostPlatform;
204282
205283 Context.TracingEnabled = std::getenv (" OFFLOAD_TRACE" );
@@ -240,7 +318,7 @@ Error olShutDown_impl() {
240318 if (!P.Plugin || !P.Plugin ->is_initialized ())
241319 continue ;
242320
243- if (auto Res = P.Plugin -> deinit ())
321+ if (auto Res = P.destroy ())
244322 Result = llvm::joinErrors (std::move (Result), std::move (Res));
245323 }
246324
@@ -508,7 +586,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
508586Error olIterateDevices_impl (ol_device_iterate_cb_t Callback, void *UserData) {
509587 for (auto &Platform : OffloadContext::get ().Platforms ) {
510588 for (auto &Device : Platform.Devices ) {
511- if (!Callback (& Device, UserData)) {
589+ if (!Callback (Device. get () , UserData)) {
512590 break ;
513591 }
514592 }
@@ -569,14 +647,46 @@ Error olMemFree_impl(void *Address) {
569647
570648Error olCreateQueue_impl (ol_device_handle_t Device, ol_queue_handle_t *Queue) {
571649 auto CreatedQueue = std::make_unique<ol_queue_impl_t >(nullptr , Device);
572- if (auto Err = Device->Device ->initAsyncInfo (&(CreatedQueue->AsyncInfo )))
650+
651+ auto OutstandingQueue = Device->getOutstandingQueue ();
652+ if (OutstandingQueue) {
653+ // The queue is empty, but we still need to sync it to release any temporary
654+ // memory allocations or do other cleanup.
655+ if (auto Err =
656+ Device->Device ->synchronize (OutstandingQueue, /* Release=*/ false ))
657+ return Err;
658+ CreatedQueue->AsyncInfo = OutstandingQueue;
659+ } else if (auto Err =
660+ Device->Device ->initAsyncInfo (&(CreatedQueue->AsyncInfo ))) {
573661 return Err;
662+ }
574663
575664 *Queue = CreatedQueue.release ();
576665 return Error::success ();
577666}
578667
579- Error olDestroyQueue_impl (ol_queue_handle_t Queue) { return olDestroy (Queue); }
668+ Error olDestroyQueue_impl (ol_queue_handle_t Queue) {
669+ auto *Device = Queue->Device ;
670+ // This is safe; as soon as olDestroyQueue is called it is not possible to add
671+ // any more work to the queue, so if it's finished now it will remain finished
672+ // forever.
673+ auto Res = Device->Device ->hasPendingWork (Queue->AsyncInfo );
674+ if (!Res)
675+ return Res.takeError ();
676+
677+ if (!*Res) {
678+ // The queue is complete, so sync it and throw it back into the pool.
679+ if (auto Err = Device->Device ->synchronize (Queue->AsyncInfo ,
680+ /* Release=*/ true ))
681+ return Err;
682+ } else {
683+ // The queue still has outstanding work. Store it so we can check it later.
684+ std::lock_guard<std::mutex> Lock (Device->OutstandingQueuesMutex );
685+ Device->OutstandingQueues .push_back (Queue->AsyncInfo );
686+ }
687+
688+ return olDestroy (Queue);
689+ }
580690
581691Error olSyncQueue_impl (ol_queue_handle_t Queue) {
582692 // Host plugin doesn't have a queue set so it's not safe to call synchronize
@@ -604,7 +714,7 @@ Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events,
604714 " olWaitEvents asked to wait on a NULL event" );
605715
606716 // Do nothing if the event is for this queue or the event is always complete
607- if (Event->Queue == Queue || !Event->EventInfo )
717+ if (Event->QueueId == Queue-> Id || !Event->EventInfo )
608718 continue ;
609719
610720 if (auto Err = Device->waitEvent (Event->EventInfo , Queue->AsyncInfo ))
@@ -652,15 +762,15 @@ Error olSyncEvent_impl(ol_event_handle_t Event) {
652762 if (!Event->EventInfo )
653763 return Plugin::success ();
654764
655- if (auto Res = Event->Queue -> Device ->Device ->syncEvent (Event->EventInfo ))
765+ if (auto Res = Event->Device ->Device ->syncEvent (Event->EventInfo ))
656766 return Res;
657767
658768 return Error::success ();
659769}
660770
661771Error olDestroyEvent_impl (ol_event_handle_t Event) {
662772 if (Event->EventInfo )
663- if (auto Res = Event->Queue -> Device ->Device ->destroyEvent (Event->EventInfo ))
773+ if (auto Res = Event->Device ->Device ->destroyEvent (Event->EventInfo ))
664774 return Res;
665775
666776 return olDestroy (Event);
@@ -711,7 +821,7 @@ Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
711821 if (auto Err = Pending.takeError ())
712822 return Err;
713823
714- *EventOut = new ol_event_impl_t (nullptr , Queue);
824+ *EventOut = new ol_event_impl_t (nullptr , Queue-> Device , Queue );
715825 if (!*Pending)
716826 // Queue is empty, don't record an event and consider the event always
717827 // complete
0 commit comments