@@ -61,9 +61,8 @@ static ze_driver_handle_t getDriver(uint32_t idx = 0) {
6161 driver_type.pNext = nullptr ;
6262 uint32_t driverCount{0 };
6363 thread_local static std::vector<ze_driver_handle_t > drivers;
64-
6564 thread_local static bool isDriverInitialised{false };
66- if (isDriverInitialised)
65+ if (isDriverInitialised && idx < drivers. size () )
6766 return drivers[idx];
6867 L0_SAFE_CALL (zeInitDrivers (&driverCount, nullptr , &driver_type));
6968 if (!driverCount)
@@ -83,7 +82,8 @@ static ze_device_handle_t getDefaultDevice(const uint32_t driverIdx = 0,
8382 const int32_t devIdx = 0 ) {
8483 thread_local static ze_device_handle_t l0Device;
8584 thread_local static int32_t currDevIdx{-1 };
86- if (devIdx == currDevIdx)
85+ thread_local static uint32_t currDriverIdx{0 };
86+ if (currDriverIdx == driverIdx && currDevIdx == devIdx)
8787 return l0Device;
8888 auto driver = getDriver (driverIdx);
8989 uint32_t deviceCount{0 };
@@ -96,6 +96,7 @@ static ze_device_handle_t getDefaultDevice(const uint32_t driverIdx = 0,
9696 std::vector<ze_device_handle_t > devices (deviceCount);
9797 L0_SAFE_CALL (zeDeviceGet (driver, &deviceCount, devices.data ()));
9898 l0Device = devices[devIdx];
99+ currDriverIdx = driverIdx;
99100 currDevIdx = devIdx;
100101 return l0Device;
101102}
@@ -130,20 +131,18 @@ struct ZeCommandListDeleter {
130131 L0_SAFE_CALL (zeCommandListDestroy (cmdList));
131132 }
132133};
133-
134+ using UniqueZeContext =
135+ std::unique_ptr<std::remove_pointer<ze_context_handle_t >::type,
136+ ZeContextDeleter>;
137+ using UniqueZeCommandList =
138+ std::unique_ptr<std::remove_pointer<ze_command_list_handle_t >::type,
139+ ZeCommandListDeleter>;
134140struct L0RtContext {
135141 ze_driver_handle_t driver{nullptr };
136142 ze_device_handle_t device{nullptr };
137- using UniqueZeContext =
138- std::unique_ptr<std::remove_pointer<ze_context_handle_t >::type,
139- ZeContextDeleter>;
140143 UniqueZeContext context;
141-
142144 // Usually, one immediate command list with ordinal 0 suffices for
143145 // both copy and compute ops, but leaves HW underutilized.
144- using UniqueZeCommandList =
145- std::unique_ptr<std::remove_pointer<ze_command_list_handle_t >::type,
146- ZeCommandListDeleter>;
147146 UniqueZeCommandList immCmdListCompute;
148147 // Copy engines can be used for both memcpy and memset, but
149148 // they have limitations for memset pattern size (e.g., 1 byte).
@@ -219,13 +218,37 @@ struct L0RtContext {
219218 ~L0RtContext () = default ;
220219};
221220
221+ struct ZeEventDeleter {
222+ void operator ()(ze_event_handle_t event) const {
223+ if (event)
224+ L0_SAFE_CALL (zeEventDestroy (event));
225+ }
226+ };
227+
228+ struct ZeEventPoolDeleter {
229+ void operator ()(ze_event_pool_handle_t pool) const {
230+ if (pool)
231+ L0_SAFE_CALL (zeEventPoolDestroy (pool));
232+ }
233+ };
234+
235+ using UniqueZeEvent =
236+ std::unique_ptr<std::remove_pointer<ze_event_handle_t >::type,
237+ ZeEventDeleter>;
238+ using UniqueZeEventPool =
239+ std::unique_ptr<std::remove_pointer<ze_event_pool_handle_t >::type,
240+ ZeEventPoolDeleter>;
241+
222242// L0 only supports pre-determined sizes of event pools,
223- // implement a rt data struct to avoid running out of events.
243+ // implement a runtime data structure to avoid running out of events.
244+
224245struct DynamicEventPool {
225246 constexpr static size_t numEventsPerPool{128 };
226- std::vector<ze_event_pool_handle_t > eventPools;
227- std::vector<ze_event_handle_t > availableEvents;
228- std::unordered_set<ze_event_handle_t > takenEvents;
247+
248+ std::vector<UniqueZeEventPool> eventPools;
249+ std::vector<UniqueZeEvent> availableEvents;
250+ std::unordered_map<ze_event_handle_t , UniqueZeEvent> takenEvents;
251+
229252 size_t currentEventsLimit{0 };
230253 size_t currentEventsCnt{0 };
231254 L0RtContext *rtCtx;
@@ -234,51 +257,68 @@ struct DynamicEventPool {
234257 createNewPool (numEventsPerPool);
235258 }
236259
260+ DynamicEventPool (const DynamicEventPool &) = delete ;
261+ DynamicEventPool &operator =(const DynamicEventPool &) = delete ;
262+
263+ // Allow move
264+ DynamicEventPool (DynamicEventPool &&) noexcept = default ;
265+ DynamicEventPool &operator =(DynamicEventPool &&) noexcept = default ;
266+
237267 ~DynamicEventPool () {
238- assert (!takenEvents.size ());
239- // zeEventDestroy will trigger L0_SAFE_CALL if an event is still used by
240- // device
241- for (auto event : availableEvents)
242- L0_SAFE_CALL (zeEventDestroy (event));
243- for (auto pool : eventPools)
244- L0_SAFE_CALL (zeEventPoolDestroy (pool));
268+ assert (takenEvents.empty () && " Some events were not released" );
245269 }
246270
247271 void createNewPool (size_t numEvents) {
248272 ze_event_pool_desc_t eventPoolDesc = {};
249273 eventPoolDesc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE;
250274 eventPoolDesc.count = numEvents;
251- eventPools.push_back (nullptr );
275+
276+ ze_event_pool_handle_t rawPool = nullptr ;
252277 L0_SAFE_CALL (zeEventPoolCreate (rtCtx->context .get (), &eventPoolDesc, 1 ,
253- &rtCtx->device , &eventPools.back ()));
278+ &rtCtx->device , &rawPool));
279+
280+ eventPools.emplace_back (UniqueZeEventPool (rawPool));
254281 currentEventsLimit += numEvents;
255282 }
256283
257284 ze_event_handle_t takeEvent () {
258- ze_event_handle_t event{nullptr };
259- if (availableEvents.size ()) {
260- event = availableEvents.back ();
285+ ze_event_handle_t rawEvent = nullptr ;
286+
287+ if (!availableEvents.empty ()) {
288+ // Reuse one
289+ auto uniqueEvent = std::move (availableEvents.back ());
261290 availableEvents.pop_back ();
291+ rawEvent = uniqueEvent.get ();
292+ takenEvents[rawEvent] = std::move (uniqueEvent);
262293 } else {
263294 if (currentEventsCnt == currentEventsLimit)
264295 createNewPool (numEventsPerPool);
265- currentEventsCnt++;
296+
266297 ze_event_desc_t eventDesc = {
267298 ZE_STRUCTURE_TYPE_EVENT_DESC, nullptr ,
268299 static_cast <uint32_t >(currentEventsCnt % numEventsPerPool),
269300 ZE_EVENT_SCOPE_FLAG_DEVICE, ZE_EVENT_SCOPE_FLAG_HOST};
270- L0_SAFE_CALL (zeEventCreate (eventPools.back (), &eventDesc, &event));
301+
302+ ze_event_handle_t newEvent = nullptr ;
303+ L0_SAFE_CALL (
304+ zeEventCreate (eventPools.back ().get (), &eventDesc, &newEvent));
305+
306+ takenEvents[newEvent] = UniqueZeEvent (newEvent);
307+ rawEvent = newEvent;
308+ currentEventsCnt++;
271309 }
272- takenEvents. insert (event);
273- return event ;
310+
311+ return rawEvent ;
274312 }
275313
276314 void releaseEvent (ze_event_handle_t event) {
277- auto found = takenEvents.find (event);
278- assert (found != takenEvents.end ());
279- takenEvents.erase (found);
315+ auto it = takenEvents.find (event);
316+ assert (it != takenEvents.end () &&
317+ " Attempting to release unknown or already released event" );
318+
280319 L0_SAFE_CALL (zeEventHostReset (event));
281- availableEvents.push_back (event);
320+ availableEvents.emplace_back (std::move (it->second ));
321+ takenEvents.erase (it);
282322 }
283323};
284324
0 commit comments