Skip to content

Commit 7771f4e

Browse files
committed
Use std::unique_ptr for L0 handles.
Fix CMake whitespace issue.
1 parent 8e83c9c commit 7771f4e

File tree

2 files changed

+80
-45
lines changed

2 files changed

+80
-45
lines changed

mlir/lib/ExecutionEngine/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,4 +512,4 @@ if(LLVM_ENABLE_PIC)
512512
${Vulkan_LIBRARY}
513513
)
514514
endif()
515-
endif()
515+
endif()

mlir/lib/ExecutionEngine/LevelZeroRuntimeWrappers.cpp

Lines changed: 79 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ auto catchAll(F &&func) {
4545
std::abort(); \
4646
} \
4747
}
48-
4948
} // namespace
5049

5150
//===----------------------------------------------------------------------===//
@@ -118,23 +117,49 @@ static ze_context_handle_t getDefaultContext() {
118117
// L0 RT helper structs
119118
//===----------------------------------------------------------------------===//
120119

120+
struct ZeContextDeleter {
121+
void operator()(ze_context_handle_t ctx) const {
122+
if (ctx)
123+
L0_SAFE_CALL(zeContextDestroy(ctx));
124+
}
125+
};
126+
127+
struct ZeCommandListDeleter {
128+
void operator()(ze_command_list_handle_t cmdList) const {
129+
if (cmdList)
130+
L0_SAFE_CALL(zeCommandListDestroy(cmdList));
131+
}
132+
};
133+
121134
struct L0RtContext {
122135
ze_driver_handle_t driver{nullptr};
123136
ze_device_handle_t device{nullptr};
124-
ze_context_handle_t context{nullptr};
137+
using UniqueZeContext =
138+
std::unique_ptr<std::remove_pointer<ze_context_handle_t>::type,
139+
ZeContextDeleter>;
140+
UniqueZeContext context;
141+
125142
// Usually, one immediate command list with ordinal 0 suffices for
126143
// both copy and compute ops, but leaves HW underutilized.
127-
ze_command_list_handle_t immCmdListCompute{nullptr};
144+
using UniqueZeCommandList =
145+
std::unique_ptr<std::remove_pointer<ze_command_list_handle_t>::type,
146+
ZeCommandListDeleter>;
147+
UniqueZeCommandList immCmdListCompute;
128148
// Copy engines can be used for both memcpy and memset, but
129149
// they have limitations for memset pattern size (e.g., 1 byte).
130-
ze_command_list_handle_t immCmdListCopy{nullptr};
150+
UniqueZeCommandList immCmdListCopy;
131151
uint32_t copyEngineMaxMemoryFillPatternSize{-1u};
132152

153+
L0RtContext() = default;
133154
L0RtContext(const int32_t devIdx = 0)
134-
: driver(getDriver()), device(getDefaultDevice(devIdx)),
135-
context(getDefaultContext()) {
155+
: driver(getDriver()), device(getDefaultDevice(devIdx)) {
156+
// Create context
157+
ze_context_handle_t defaultCtx = getDefaultContext();
158+
context.reset(defaultCtx);
159+
160+
// Determine ordinals
136161
uint32_t computeEngineOrdinal = -1u, copyEngineOrdinal = -1u;
137-
ze_device_properties_t deviceProperties = {};
162+
ze_device_properties_t deviceProperties{};
138163
L0_SAFE_CALL(zeDeviceGetProperties(device, &deviceProperties));
139164
uint32_t queueGroupCount = 0;
140165
L0_SAFE_CALL(zeDeviceGetCommandQueueGroupProperties(
@@ -143,6 +168,7 @@ struct L0RtContext {
143168
queueGroupCount);
144169
L0_SAFE_CALL(zeDeviceGetCommandQueueGroupProperties(
145170
device, &queueGroupCount, queueGroupProperties.data()));
171+
146172
for (uint32_t queueGroupIdx = 0; queueGroupIdx < queueGroupCount;
147173
++queueGroupIdx) {
148174
const auto &group = queueGroupProperties[queueGroupIdx];
@@ -155,11 +181,15 @@ struct L0RtContext {
155181
if (copyEngineOrdinal != -1u && computeEngineOrdinal != -1u)
156182
break;
157183
}
184+
158185
// Fallback to the default queue if no dedicated copy queue is available.
159186
if (copyEngineOrdinal == -1u)
160187
copyEngineOrdinal = computeEngineOrdinal;
188+
161189
assert(copyEngineOrdinal != -1u && computeEngineOrdinal != -1u &&
162190
"Expected two engines to be available.");
191+
192+
// Create copy command list
163193
ze_command_queue_desc_t cmdQueueDesc{
164194
ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
165195
nullptr,
@@ -168,18 +198,25 @@ struct L0RtContext {
168198
0, // flags
169199
ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,
170200
ZE_COMMAND_QUEUE_PRIORITY_NORMAL};
171-
L0_SAFE_CALL(zeCommandListCreateImmediate(context, device, &cmdQueueDesc,
172-
&immCmdListCopy));
201+
202+
ze_command_list_handle_t rawCmdListCopy = nullptr;
203+
L0_SAFE_CALL(zeCommandListCreateImmediate(context.get(), device,
204+
&cmdQueueDesc, &rawCmdListCopy));
205+
immCmdListCopy.reset(rawCmdListCopy);
206+
207+
// Create compute command list
173208
cmdQueueDesc.ordinal = computeEngineOrdinal;
174-
L0_SAFE_CALL(zeCommandListCreateImmediate(context, device, &cmdQueueDesc,
175-
&immCmdListCompute));
176-
}
177-
void cleanup() {
178-
L0_SAFE_CALL(zeCommandListDestroy(immCmdListCopy));
179-
L0_SAFE_CALL(zeCommandListDestroy(immCmdListCompute));
180-
L0_SAFE_CALL(zeContextDestroy(context));
209+
ze_command_list_handle_t rawCmdListCompute = nullptr;
210+
L0_SAFE_CALL(zeCommandListCreateImmediate(
211+
context.get(), device, &cmdQueueDesc, &rawCmdListCompute));
212+
immCmdListCompute.reset(rawCmdListCompute);
181213
}
182-
~L0RtContext() { cleanup(); }
214+
L0RtContext(const L0RtContext &) = delete;
215+
L0RtContext &operator=(const L0RtContext &) = delete;
216+
// Allow move
217+
L0RtContext(L0RtContext &&) noexcept = default;
218+
L0RtContext &operator=(L0RtContext &&) noexcept = default;
219+
~L0RtContext() = default;
183220
};
184221

185222
// L0 only supports pre-determined sizes of event pools,
@@ -212,7 +249,7 @@ struct DynamicEventPool {
212249
eventPoolDesc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE;
213250
eventPoolDesc.count = numEvents;
214251
eventPools.push_back(nullptr);
215-
L0_SAFE_CALL(zeEventPoolCreate(rtCtx->context, &eventPoolDesc, 1,
252+
L0_SAFE_CALL(zeEventPoolCreate(rtCtx->context.get(), &eventPoolDesc, 1,
216253
&rtCtx->device, &eventPools.back()));
217254
currentEventsLimit += numEvents;
218255
}
@@ -246,7 +283,7 @@ struct DynamicEventPool {
246283
};
247284

248285
L0RtContext &getRtContext() {
249-
thread_local static L0RtContext rtContext;
286+
thread_local static L0RtContext rtContext(0);
250287
return rtContext;
251288
}
252289

@@ -286,13 +323,13 @@ struct StreamWrapper {
286323
implicitEventStack.clear();
287324
}
288325

289-
void enqueueOp(
290-
std::function<void(ze_event_handle_t, uint32_t, ze_event_handle_t *)>
291-
op) {
326+
template <typename Func>
327+
void enqueueOp(Func &&op) {
292328
ze_event_handle_t newImplicitEvent = dynEventPool.takeEvent();
293329
ze_event_handle_t *lastImplicitEventPtr = getLastImplicitEventPtr();
294330
const uint32_t numWaitEvents = lastImplicitEventPtr ? 1 : 0;
295-
op(newImplicitEvent, numWaitEvents, lastImplicitEventPtr);
331+
std::forward<Func>(op)(newImplicitEvent, numWaitEvents,
332+
lastImplicitEventPtr);
296333
implicitEventStack.push_back(newImplicitEvent);
297334
}
298335
};
@@ -309,7 +346,7 @@ static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
309346
nullptr};
310347
ze_module_build_log_handle_t buildLogHandle;
311348
ze_result_t result =
312-
zeModuleCreate(getRtContext().context, getRtContext().device, &desc,
349+
zeModuleCreate(getRtContext().context.get(), getRtContext().device, &desc,
313350
&zeModule, &buildLogHandle);
314351
if (result != ZE_RESULT_SUCCESS) {
315352
std::cerr << "Error creating module, error code: " << result << std::endl;
@@ -337,14 +374,12 @@ extern "C" void mgpuStreamSynchronize(StreamWrapper *stream) {
337374
stream->sync();
338375
}
339376

340-
extern "C" void mgpuStreamDestroy(StreamWrapper *stream) {
341-
if (stream)
342-
delete stream;
343-
}
377+
extern "C" void mgpuStreamDestroy(StreamWrapper *stream) { delete stream; }
344378

345379
extern "C" void mgpuStreamWaitEvent(StreamWrapper *stream,
346380
ze_event_handle_t event) {
347-
assert(stream && event);
381+
assert(stream && "Invalid stream");
382+
assert(event && "Invalid event");
348383
stream->sync(event);
349384
}
350385

@@ -364,10 +399,10 @@ extern "C" void mgpuEventSynchronize(ze_event_handle_t event) {
364399

365400
extern "C" void mgpuEventRecord(ze_event_handle_t event,
366401
StreamWrapper *stream) {
367-
L0_SAFE_CALL(
368-
zeCommandListAppendSignalEvent(getRtContext().immCmdListCopy, event));
369-
L0_SAFE_CALL(
370-
zeCommandListAppendSignalEvent(getRtContext().immCmdListCompute, event));
402+
L0_SAFE_CALL(zeCommandListAppendSignalEvent(
403+
getRtContext().immCmdListCopy.get(), event));
404+
L0_SAFE_CALL(zeCommandListAppendSignalEvent(
405+
getRtContext().immCmdListCompute.get(), event));
371406
}
372407

373408
extern "C" void *mgpuMemAlloc(uint64_t size, StreamWrapper *stream,
@@ -380,12 +415,13 @@ extern "C" void *mgpuMemAlloc(uint64_t size, StreamWrapper *stream,
380415
if (isShared) {
381416
ze_host_mem_alloc_desc_t hostDesc = {};
382417
hostDesc.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC;
383-
L0_SAFE_CALL(zeMemAllocShared(getRtContext().context, &deviceDesc,
418+
L0_SAFE_CALL(zeMemAllocShared(getRtContext().context.get(), &deviceDesc,
384419
&hostDesc, size, alignment,
385420
getRtContext().device, &memPtr));
386421
} else {
387-
L0_SAFE_CALL(zeMemAllocDevice(getRtContext().context, &deviceDesc, size,
388-
alignment, getRtContext().device, &memPtr));
422+
L0_SAFE_CALL(zeMemAllocDevice(getRtContext().context.get(), &deviceDesc,
423+
size, alignment, getRtContext().device,
424+
&memPtr));
389425
}
390426
if (!memPtr)
391427
throw std::runtime_error("mem allocation failed!");
@@ -396,16 +432,16 @@ extern "C" void *mgpuMemAlloc(uint64_t size, StreamWrapper *stream,
396432
extern "C" void mgpuMemFree(void *ptr, StreamWrapper *stream) {
397433
stream->sync();
398434
if (ptr)
399-
L0_SAFE_CALL(zeMemFree(getRtContext().context, ptr));
435+
L0_SAFE_CALL(zeMemFree(getRtContext().context.get(), ptr));
400436
}
401437

402438
extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
403439
StreamWrapper *stream) {
404440
stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
405441
ze_event_handle_t *waitEvents) {
406-
L0_SAFE_CALL(zeCommandListAppendMemoryCopy(getRtContext().immCmdListCopy,
407-
dst, src, sizeBytes, newEvent,
408-
numWaitEvents, waitEvents));
442+
L0_SAFE_CALL(zeCommandListAppendMemoryCopy(
443+
getRtContext().immCmdListCopy.get(), dst, src, sizeBytes, newEvent,
444+
numWaitEvents, waitEvents));
409445
});
410446
}
411447

@@ -414,8 +450,8 @@ void mgpuMemset(void *dst, PATTERN_TYPE value, size_t count,
414450
StreamWrapper *stream) {
415451
auto listType =
416452
getRtContext().copyEngineMaxMemoryFillPatternSize >= sizeof(PATTERN_TYPE)
417-
? getRtContext().immCmdListCopy
418-
: getRtContext().immCmdListCompute;
453+
? getRtContext().immCmdListCopy.get()
454+
: getRtContext().immCmdListCompute.get();
419455
stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
420456
ze_event_handle_t *waitEvents) {
421457
L0_SAFE_CALL(zeCommandListAppendMemoryFill(
@@ -471,7 +507,7 @@ extern "C" void mgpuLaunchKernel(ze_kernel_handle_t kernel, size_t gridX,
471507
stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
472508
ze_event_handle_t *waitEvents) {
473509
L0_SAFE_CALL(zeCommandListAppendLaunchKernel(
474-
getRtContext().immCmdListCompute, kernel, &dispatch, newEvent,
510+
getRtContext().immCmdListCompute.get(), kernel, &dispatch, newEvent,
475511
numWaitEvents, waitEvents));
476512
});
477513
}
@@ -484,7 +520,6 @@ extern "C" void mgpuSetDefaultDevice(int32_t devIdx) {
484520
catchAll([&]() {
485521
// For now, a user must ensure that streams and events complete
486522
// and are destroyed before switching a device.
487-
getRtContext().cleanup();
488523
getRtContext() = L0RtContext(devIdx);
489524
getDynamicEventPool() = DynamicEventPool(&getRtContext());
490525
});

0 commit comments

Comments
 (0)