Skip to content

Commit 2908cf8

Browse files
committed
Use std::unique_ptr for DynamicEventPool members.
1 parent 7771f4e commit 2908cf8

File tree

1 file changed

+75
-35
lines changed

1 file changed

+75
-35
lines changed

mlir/lib/ExecutionEngine/LevelZeroRuntimeWrappers.cpp

Lines changed: 75 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@ static ze_driver_handle_t getDriver(uint32_t idx = 0) {
6161
driver_type.pNext = nullptr;
6262
uint32_t driverCount{0};
6363
thread_local static std::vector<ze_driver_handle_t> drivers;
64-
6564
thread_local static bool isDriverInitialised{false};
66-
if (isDriverInitialised)
65+
if (isDriverInitialised && idx < drivers.size())
6766
return drivers[idx];
6867
L0_SAFE_CALL(zeInitDrivers(&driverCount, nullptr, &driver_type));
6968
if (!driverCount)
@@ -83,7 +82,8 @@ static ze_device_handle_t getDefaultDevice(const uint32_t driverIdx = 0,
8382
const int32_t devIdx = 0) {
8483
thread_local static ze_device_handle_t l0Device;
8584
thread_local static int32_t currDevIdx{-1};
86-
if (devIdx == currDevIdx)
85+
thread_local static uint32_t currDriverIdx{0};
86+
if (currDriverIdx == driverIdx && currDevIdx == devIdx)
8787
return l0Device;
8888
auto driver = getDriver(driverIdx);
8989
uint32_t deviceCount{0};
@@ -96,6 +96,7 @@ static ze_device_handle_t getDefaultDevice(const uint32_t driverIdx = 0,
9696
std::vector<ze_device_handle_t> devices(deviceCount);
9797
L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, devices.data()));
9898
l0Device = devices[devIdx];
99+
currDriverIdx = driverIdx;
99100
currDevIdx = devIdx;
100101
return l0Device;
101102
}
@@ -130,20 +131,18 @@ struct ZeCommandListDeleter {
130131
L0_SAFE_CALL(zeCommandListDestroy(cmdList));
131132
}
132133
};
133-
134+
using UniqueZeContext =
135+
std::unique_ptr<std::remove_pointer<ze_context_handle_t>::type,
136+
ZeContextDeleter>;
137+
using UniqueZeCommandList =
138+
std::unique_ptr<std::remove_pointer<ze_command_list_handle_t>::type,
139+
ZeCommandListDeleter>;
134140
struct L0RtContext {
135141
ze_driver_handle_t driver{nullptr};
136142
ze_device_handle_t device{nullptr};
137-
using UniqueZeContext =
138-
std::unique_ptr<std::remove_pointer<ze_context_handle_t>::type,
139-
ZeContextDeleter>;
140143
UniqueZeContext context;
141-
142144
// Usually, one immediate command list with ordinal 0 suffices for
143145
// both copy and compute ops, but leaves HW underutilized.
144-
using UniqueZeCommandList =
145-
std::unique_ptr<std::remove_pointer<ze_command_list_handle_t>::type,
146-
ZeCommandListDeleter>;
147146
UniqueZeCommandList immCmdListCompute;
148147
// Copy engines can be used for both memcpy and memset, but
149148
// they have limitations for memset pattern size (e.g., 1 byte).
@@ -219,13 +218,37 @@ struct L0RtContext {
219218
~L0RtContext() = default;
220219
};
221220

221+
struct ZeEventDeleter {
222+
void operator()(ze_event_handle_t event) const {
223+
if (event)
224+
L0_SAFE_CALL(zeEventDestroy(event));
225+
}
226+
};
227+
228+
struct ZeEventPoolDeleter {
229+
void operator()(ze_event_pool_handle_t pool) const {
230+
if (pool)
231+
L0_SAFE_CALL(zeEventPoolDestroy(pool));
232+
}
233+
};
234+
235+
using UniqueZeEvent =
236+
std::unique_ptr<std::remove_pointer<ze_event_handle_t>::type,
237+
ZeEventDeleter>;
238+
using UniqueZeEventPool =
239+
std::unique_ptr<std::remove_pointer<ze_event_pool_handle_t>::type,
240+
ZeEventPoolDeleter>;
241+
222242
// L0 only supports pre-determined sizes of event pools,
223-
// implement a rt data struct to avoid running out of events.
243+
// implement a runtime data structure to avoid running out of events.
244+
224245
struct DynamicEventPool {
225246
constexpr static size_t numEventsPerPool{128};
226-
std::vector<ze_event_pool_handle_t> eventPools;
227-
std::vector<ze_event_handle_t> availableEvents;
228-
std::unordered_set<ze_event_handle_t> takenEvents;
247+
248+
// Every event pool created so far. Fresh events are always created from the
// most recently added pool (eventPools.back()).
std::vector<UniqueZeEventPool> eventPools;
249+
// Events that were handed back via releaseEvent() and can be reused by
// takeEvent() without creating a new one.
std::vector<UniqueZeEvent> availableEvents;
250+
// Events currently handed out to callers, keyed by their raw handle so that
// releaseEvent() can find the owning pointer again.
std::unordered_map<ze_event_handle_t, UniqueZeEvent> takenEvents;
251+
229252
size_t currentEventsLimit{0};
230253
size_t currentEventsCnt{0};
231254
L0RtContext *rtCtx;
@@ -234,51 +257,68 @@ struct DynamicEventPool {
234257
createNewPool(numEventsPerPool);
235258
}
236259

260+
DynamicEventPool(const DynamicEventPool &) = delete;
261+
DynamicEventPool &operator=(const DynamicEventPool &) = delete;
262+
263+
// Allow move
264+
DynamicEventPool(DynamicEventPool &&) noexcept = default;
265+
DynamicEventPool &operator=(DynamicEventPool &&) noexcept = default;
266+
237267
// Only verifies (in assert-enabled builds) that every event handed out by
// takeEvent() was given back. The handles themselves are destroyed by the
// owning members: members are destroyed in reverse declaration order, so
// `takenEvents` and `availableEvents` go before `eventPools`, i.e. events are
// destroyed before the pools they were created from.
~DynamicEventPool() {
  assert(takenEvents.empty() && "Some events were not released");
}
246270

247271
void createNewPool(size_t numEvents) {
248272
ze_event_pool_desc_t eventPoolDesc = {};
249273
eventPoolDesc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE;
250274
eventPoolDesc.count = numEvents;
251-
eventPools.push_back(nullptr);
275+
276+
ze_event_pool_handle_t rawPool = nullptr;
252277
L0_SAFE_CALL(zeEventPoolCreate(rtCtx->context.get(), &eventPoolDesc, 1,
253-
&rtCtx->device, &eventPools.back()));
278+
&rtCtx->device, &rawPool));
279+
280+
eventPools.emplace_back(UniqueZeEventPool(rawPool));
254281
currentEventsLimit += numEvents;
255282
}
256283

257284
// Hands out an event and records it in `takenEvents`. The pool keeps
// ownership; the caller receives the raw handle and must give it back with
// releaseEvent().
//
// A previously released event is reused when one is available (it was
// host-reset on release). Otherwise a new event is created from the most
// recent pool, adding another pool first when all existing pools are full.
ze_event_handle_t takeEvent() {
  UniqueZeEvent event;

  if (!availableEvents.empty()) {
    // Recycle the most recently released event.
    event = std::move(availableEvents.back());
    availableEvents.pop_back();
  } else {
    if (currentEventsCnt == currentEventsLimit)
      createNewPool(numEventsPerPool);

    // The index is the slot inside the newest pool.
    ze_event_desc_t eventDesc = {
        ZE_STRUCTURE_TYPE_EVENT_DESC, nullptr,
        static_cast<uint32_t>(currentEventsCnt % numEventsPerPool),
        ZE_EVENT_SCOPE_FLAG_DEVICE, ZE_EVENT_SCOPE_FLAG_HOST};

    ze_event_handle_t newEvent = nullptr;
    L0_SAFE_CALL(
        zeEventCreate(eventPools.back().get(), &eventDesc, &newEvent));
    // Own the handle right away so it cannot leak if the bookkeeping below
    // fails.
    event.reset(newEvent);
    currentEventsCnt++;
  }

  // Single insertion point for both paths; emplace constructs the entry in
  // place instead of default-constructing and then assigning.
  ze_event_handle_t rawEvent = event.get();
  const bool inserted = takenEvents.emplace(rawEvent, std::move(event)).second;
  assert(inserted && "Event handle is already marked as taken");
  (void)inserted;
  return rawEvent;
}
275313

276314
// Gives back an event obtained from takeEvent(): the event is host-reset and
// moved from `takenEvents` to `availableEvents` for later reuse.
//
// Releasing a handle that is not currently taken is a caller bug; it asserts
// in assert-enabled builds and is ignored otherwise.
void releaseEvent(ze_event_handle_t event) {
  auto it = takenEvents.find(event);
  assert(it != takenEvents.end() &&
         "Attempting to release unknown or already released event");
  // With NDEBUG the assert above compiles away; bail out instead of
  // dereferencing end().
  if (it == takenEvents.end())
    return;

  L0_SAFE_CALL(zeEventHostReset(event));
  availableEvents.push_back(std::move(it->second));
  takenEvents.erase(it);
}
283323
};
284324

0 commit comments

Comments
 (0)