Skip to content

Commit 2f93477

Browse files
committed
Address review comments.
Limit the maximum no. of events. Modify getContext to be a bit more generic.
1 parent 49647b7 commit 2f93477

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

mlir/lib/ExecutionEngine/LevelZeroRuntimeWrappers.cpp

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,12 @@ static ze_device_handle_t getDevice(const uint32_t driverIdx = 0,
102102
}
103103

104104
// Returns the default L0 context of the defult driver.
105-
static ze_context_handle_t getDefaultContext() {
105+
static ze_context_handle_t getContext(ze_driver_handle_t driver) {
106106
thread_local static ze_context_handle_t context;
107107
thread_local static bool isContextInitialised{false};
108108
if (isContextInitialised)
109109
return context;
110110
ze_context_desc_t ctxtDesc = {ZE_STRUCTURE_TYPE_CONTEXT_DESC, nullptr, 0};
111-
auto driver = getDriver();
112111
L0_SAFE_CALL(zeContextCreate(driver, &ctxtDesc, &context));
113112
isContextInitialised = true;
114113
return context;
@@ -137,7 +136,7 @@ using UniqueZeContext =
137136
using UniqueZeCommandList =
138137
std::unique_ptr<std::remove_pointer<ze_command_list_handle_t>::type,
139138
ZeCommandListDeleter>;
140-
struct L0RtContext {
139+
struct L0RTContextWrapper {
141140
ze_driver_handle_t driver{nullptr};
142141
ze_device_handle_t device{nullptr};
143142
UniqueZeContext context;
@@ -149,12 +148,12 @@ struct L0RtContext {
149148
UniqueZeCommandList immCmdListCopy;
150149
uint32_t copyEngineMaxMemoryFillPatternSize{-1u};
151150

152-
L0RtContext() = default;
153-
L0RtContext(const uint32_t driverIdx = 0, const int32_t devIdx = 0)
151+
L0RTContextWrapper() = default;
152+
L0RTContextWrapper(const uint32_t driverIdx = 0, const int32_t devIdx = 0)
154153
: driver(getDriver(driverIdx)), device(getDevice(devIdx)) {
155154
// Create context
156-
ze_context_handle_t defaultCtx = getDefaultContext();
157-
context.reset(defaultCtx);
155+
ze_context_handle_t ctx = getContext(driver);
156+
context.reset(ctx);
158157

159158
// Determine ordinals
160159
uint32_t computeEngineOrdinal = -1u, copyEngineOrdinal = -1u;
@@ -210,12 +209,12 @@ struct L0RtContext {
210209
context.get(), device, &cmdQueueDesc, &rawCmdListCompute));
211210
immCmdListCompute.reset(rawCmdListCompute);
212211
}
213-
L0RtContext(const L0RtContext &) = delete;
214-
L0RtContext &operator=(const L0RtContext &) = delete;
212+
L0RTContextWrapper(const L0RTContextWrapper &) = delete;
213+
L0RTContextWrapper &operator=(const L0RTContextWrapper &) = delete;
215214
// Allow move
216-
L0RtContext(L0RtContext &&) noexcept = default;
217-
L0RtContext &operator=(L0RtContext &&) noexcept = default;
218-
~L0RtContext() = default;
215+
L0RTContextWrapper(L0RTContextWrapper &&) noexcept = default;
216+
L0RTContextWrapper &operator=(L0RTContextWrapper &&) noexcept = default;
217+
~L0RTContextWrapper() = default;
219218
};
220219

221220
struct ZeEventDeleter {
@@ -249,11 +248,15 @@ struct DynamicEventPool {
249248
std::vector<UniqueZeEvent> availableEvents;
250249
std::unordered_map<ze_event_handle_t, UniqueZeEvent> takenEvents;
251250

251+
// Limit the number of events to avoid running out of memory.
252+
// The limit is set to 32K events, which should be sufficient for most use
253+
// cases.
254+
size_t maxEventsCount{32768}; // 32K events
252255
size_t currentEventsLimit{0};
253256
size_t currentEventsCnt{0};
254-
L0RtContext *rtCtx;
257+
L0RTContextWrapper *rtCtx;
255258

256-
DynamicEventPool(L0RtContext *rtCtx) : rtCtx(rtCtx) {
259+
DynamicEventPool(L0RTContextWrapper *rtCtx) : rtCtx(rtCtx) {
257260
createNewPool(numEventsPerPool);
258261
}
259262

@@ -291,6 +294,9 @@ struct DynamicEventPool {
291294
rawEvent = uniqueEvent.get();
292295
takenEvents[rawEvent] = std::move(uniqueEvent);
293296
} else {
297+
if (currentEventsCnt >= maxEventsCount) {
298+
throw std::runtime_error("DynamicEventPool: reached max events limit");
299+
}
294300
if (currentEventsCnt == currentEventsLimit)
295301
createNewPool(numEventsPerPool);
296302

@@ -322,8 +328,8 @@ struct DynamicEventPool {
322328
}
323329
};
324330

325-
L0RtContext &getRtContext() {
326-
thread_local static L0RtContext rtContext(0);
331+
L0RTContextWrapper &getRtContext() {
332+
thread_local static L0RTContextWrapper rtContext(0);
327333
return rtContext;
328334
}
329335

@@ -488,7 +494,7 @@ extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
488494
template <typename PATTERN_TYPE>
489495
void mgpuMemset(void *dst, PATTERN_TYPE value, size_t count,
490496
StreamWrapper *stream) {
491-
L0RtContext &rtContext = getRtContext();
497+
L0RTContextWrapper &rtContext = getRtContext();
492498
auto listType =
493499
rtContext.copyEngineMaxMemoryFillPatternSize >= sizeof(PATTERN_TYPE)
494500
? rtContext.immCmdListCopy.get()
@@ -561,7 +567,7 @@ extern "C" void mgpuSetDefaultDevice(int32_t devIdx) {
561567
catchAll([&]() {
562568
// For now, a user must ensure that streams and events complete
563569
// and are destroyed before switching a device.
564-
getRtContext() = L0RtContext(devIdx);
570+
getRtContext() = L0RTContextWrapper(devIdx);
565571
getDynamicEventPool() = DynamicEventPool(&getRtContext());
566572
});
567573
}

0 commit comments

Comments
 (0)