@@ -45,7 +45,6 @@ auto catchAll(F &&func) {
4545 std::abort (); \
4646 } \
4747 }
48-
4948} // namespace
5049
5150// ===----------------------------------------------------------------------===//
@@ -118,23 +117,49 @@ static ze_context_handle_t getDefaultContext() {
118117// L0 RT helper structs
119118// ===----------------------------------------------------------------------===//
120119
120+ struct ZeContextDeleter {
121+ void operator ()(ze_context_handle_t ctx) const {
122+ if (ctx)
123+ L0_SAFE_CALL (zeContextDestroy (ctx));
124+ }
125+ };
126+
127+ struct ZeCommandListDeleter {
128+ void operator ()(ze_command_list_handle_t cmdList) const {
129+ if (cmdList)
130+ L0_SAFE_CALL (zeCommandListDestroy (cmdList));
131+ }
132+ };
133+
121134struct L0RtContext {
122135 ze_driver_handle_t driver{nullptr };
123136 ze_device_handle_t device{nullptr };
124- ze_context_handle_t context{nullptr };
137+ using UniqueZeContext =
138+ std::unique_ptr<std::remove_pointer<ze_context_handle_t >::type,
139+ ZeContextDeleter>;
140+ UniqueZeContext context;
141+
125142 // Usually, one immediate command list with ordinal 0 suffices for
126143 // both copy and compute ops, but leaves HW underutilized.
127- ze_command_list_handle_t immCmdListCompute{nullptr };
144+ using UniqueZeCommandList =
145+ std::unique_ptr<std::remove_pointer<ze_command_list_handle_t >::type,
146+ ZeCommandListDeleter>;
147+ UniqueZeCommandList immCmdListCompute;
128148 // Copy engines can be used for both memcpy and memset, but
129149 // they have limitations for memset pattern size (e.g., 1 byte).
130- ze_command_list_handle_t immCmdListCopy{ nullptr } ;
150+ UniqueZeCommandList immCmdListCopy;
131151 uint32_t copyEngineMaxMemoryFillPatternSize{-1u };
132152
// No separate `L0RtContext() = default;` here: the constructor below already
// accepts zero arguments through its `devIdx = 0` default. Declaring both
// makes every zero-argument construction (`L0RtContext ctx;`,
// `L0RtContext{}`) an ambiguous overload, so the defaulted one could never
// be selected.
133154 L0RtContext (const int32_t devIdx = 0 )
134- : driver(getDriver()), device(getDefaultDevice(devIdx)),
135- context (getDefaultContext()) {
155+ : driver(getDriver()), device(getDefaultDevice(devIdx)) {
156+ // Create context
157+ ze_context_handle_t defaultCtx = getDefaultContext ();
158+ context.reset (defaultCtx);
159+
160+ // Determine ordinals
136161 uint32_t computeEngineOrdinal = -1u , copyEngineOrdinal = -1u ;
137- ze_device_properties_t deviceProperties = {};
162+ ze_device_properties_t deviceProperties{};
138163 L0_SAFE_CALL (zeDeviceGetProperties (device, &deviceProperties));
139164 uint32_t queueGroupCount = 0 ;
140165 L0_SAFE_CALL (zeDeviceGetCommandQueueGroupProperties (
@@ -143,6 +168,7 @@ struct L0RtContext {
143168 queueGroupCount);
144169 L0_SAFE_CALL (zeDeviceGetCommandQueueGroupProperties (
145170 device, &queueGroupCount, queueGroupProperties.data ()));
171+
146172 for (uint32_t queueGroupIdx = 0 ; queueGroupIdx < queueGroupCount;
147173 ++queueGroupIdx) {
148174 const auto &group = queueGroupProperties[queueGroupIdx];
@@ -155,11 +181,15 @@ struct L0RtContext {
155181 if (copyEngineOrdinal != -1u && computeEngineOrdinal != -1u )
156182 break ;
157183 }
184+
158185 // Fallback to the default queue if no dedicated copy queue is available.
159186 if (copyEngineOrdinal == -1u )
160187 copyEngineOrdinal = computeEngineOrdinal;
188+
161189 assert (copyEngineOrdinal != -1u && computeEngineOrdinal != -1u &&
162190 " Expected two engines to be available." );
191+
192+ // Create copy command list
163193 ze_command_queue_desc_t cmdQueueDesc{
164194 ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
165195 nullptr ,
@@ -168,18 +198,25 @@ struct L0RtContext {
168198 0 , // flags
169199 ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,
170200 ZE_COMMAND_QUEUE_PRIORITY_NORMAL};
171- L0_SAFE_CALL (zeCommandListCreateImmediate (context, device, &cmdQueueDesc,
172- &immCmdListCopy));
201+
202+ ze_command_list_handle_t rawCmdListCopy = nullptr ;
203+ L0_SAFE_CALL (zeCommandListCreateImmediate (context.get (), device,
204+ &cmdQueueDesc, &rawCmdListCopy));
205+ immCmdListCopy.reset (rawCmdListCopy);
206+
207+ // Create compute command list
173208 cmdQueueDesc.ordinal = computeEngineOrdinal;
174- L0_SAFE_CALL (zeCommandListCreateImmediate (context, device, &cmdQueueDesc,
175- &immCmdListCompute));
176- }
177- void cleanup () {
178- L0_SAFE_CALL (zeCommandListDestroy (immCmdListCopy));
179- L0_SAFE_CALL (zeCommandListDestroy (immCmdListCompute));
180- L0_SAFE_CALL (zeContextDestroy (context));
209+ ze_command_list_handle_t rawCmdListCompute = nullptr ;
210+ L0_SAFE_CALL (zeCommandListCreateImmediate (
211+ context.get (), device, &cmdQueueDesc, &rawCmdListCompute));
212+ immCmdListCompute.reset (rawCmdListCompute);
181213 }
182- ~L0RtContext () { cleanup (); }
214+ L0RtContext (const L0RtContext &) = delete ;
215+ L0RtContext &operator =(const L0RtContext &) = delete ;
216+ // Allow move
217+ L0RtContext (L0RtContext &&) noexcept = default ;
218+ L0RtContext &operator =(L0RtContext &&) noexcept = default ;
219+ ~L0RtContext () = default ;
183220};
184221
185222// L0 only supports pre-determined sizes of event pools,
@@ -212,7 +249,7 @@ struct DynamicEventPool {
212249 eventPoolDesc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE;
213250 eventPoolDesc.count = numEvents;
214251 eventPools.push_back (nullptr );
215- L0_SAFE_CALL (zeEventPoolCreate (rtCtx->context , &eventPoolDesc, 1 ,
252+ L0_SAFE_CALL (zeEventPoolCreate (rtCtx->context . get () , &eventPoolDesc, 1 ,
216253 &rtCtx->device , &eventPools.back ()));
217254 currentEventsLimit += numEvents;
218255 }
@@ -246,7 +283,7 @@ struct DynamicEventPool {
246283};
247284
248285L0RtContext &getRtContext () {
249- thread_local static L0RtContext rtContext;
286+ thread_local static L0RtContext rtContext ( 0 ) ;
250287 return rtContext;
251288}
252289
@@ -286,13 +323,13 @@ struct StreamWrapper {
286323 implicitEventStack.clear ();
287324 }
288325
289- void enqueueOp (
290- std::function<void (ze_event_handle_t , uint32_t , ze_event_handle_t *)>
291- op) {
326+ template <typename Func>
327+ void enqueueOp (Func &&op) {
292328 ze_event_handle_t newImplicitEvent = dynEventPool.takeEvent ();
293329 ze_event_handle_t *lastImplicitEventPtr = getLastImplicitEventPtr ();
294330 const uint32_t numWaitEvents = lastImplicitEventPtr ? 1 : 0 ;
295- op (newImplicitEvent, numWaitEvents, lastImplicitEventPtr);
331+ std::forward<Func>(op)(newImplicitEvent, numWaitEvents,
332+ lastImplicitEventPtr);
296333 implicitEventStack.push_back (newImplicitEvent);
297334 }
298335};
@@ -309,7 +346,7 @@ static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
309346 nullptr };
310347 ze_module_build_log_handle_t buildLogHandle;
311348 ze_result_t result =
312- zeModuleCreate (getRtContext ().context , getRtContext ().device , &desc,
349+ zeModuleCreate (getRtContext ().context . get () , getRtContext ().device , &desc,
313350 &zeModule, &buildLogHandle);
314351 if (result != ZE_RESULT_SUCCESS) {
315352 std::cerr << " Error creating module, error code: " << result << std::endl;
@@ -337,14 +374,12 @@ extern "C" void mgpuStreamSynchronize(StreamWrapper *stream) {
337374 stream->sync ();
338375}
339376
340- extern " C" void mgpuStreamDestroy (StreamWrapper *stream) {
341- if (stream)
342- delete stream;
343- }
377+ extern " C" void mgpuStreamDestroy (StreamWrapper *stream) { delete stream; }
344378
345379extern " C" void mgpuStreamWaitEvent (StreamWrapper *stream,
346380 ze_event_handle_t event) {
347- assert (stream && event);
381+ assert (stream && " Invalid stream" );
382+ assert (event && " Invalid event" );
348383 stream->sync (event);
349384}
350385
@@ -364,10 +399,10 @@ extern "C" void mgpuEventSynchronize(ze_event_handle_t event) {
364399
365400extern " C" void mgpuEventRecord (ze_event_handle_t event,
366401 StreamWrapper *stream) {
367- L0_SAFE_CALL (
368- zeCommandListAppendSignalEvent ( getRtContext ().immCmdListCopy , event));
369- L0_SAFE_CALL (
370- zeCommandListAppendSignalEvent ( getRtContext ().immCmdListCompute , event));
402+ L0_SAFE_CALL (zeCommandListAppendSignalEvent (
403+ getRtContext ().immCmdListCopy . get () , event));
404+ L0_SAFE_CALL (zeCommandListAppendSignalEvent (
405+ getRtContext ().immCmdListCompute . get () , event));
371406}
372407
373408extern " C" void *mgpuMemAlloc (uint64_t size, StreamWrapper *stream,
@@ -380,12 +415,13 @@ extern "C" void *mgpuMemAlloc(uint64_t size, StreamWrapper *stream,
380415 if (isShared) {
381416 ze_host_mem_alloc_desc_t hostDesc = {};
382417 hostDesc.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC;
383- L0_SAFE_CALL (zeMemAllocShared (getRtContext ().context , &deviceDesc,
418+ L0_SAFE_CALL (zeMemAllocShared (getRtContext ().context . get () , &deviceDesc,
384419 &hostDesc, size, alignment,
385420 getRtContext ().device , &memPtr));
386421 } else {
387- L0_SAFE_CALL (zeMemAllocDevice (getRtContext ().context , &deviceDesc, size,
388- alignment, getRtContext ().device , &memPtr));
422+ L0_SAFE_CALL (zeMemAllocDevice (getRtContext ().context .get (), &deviceDesc,
423+ size, alignment, getRtContext ().device ,
424+ &memPtr));
389425 }
390426 if (!memPtr)
391427 throw std::runtime_error (" mem allocation failed!" );
@@ -396,16 +432,16 @@ extern "C" void *mgpuMemAlloc(uint64_t size, StreamWrapper *stream,
396432extern " C" void mgpuMemFree (void *ptr, StreamWrapper *stream) {
397433 stream->sync ();
398434 if (ptr)
399- L0_SAFE_CALL (zeMemFree (getRtContext ().context , ptr));
435+ L0_SAFE_CALL (zeMemFree (getRtContext ().context . get () , ptr));
400436}
401437
402438extern " C" void mgpuMemcpy (void *dst, void *src, size_t sizeBytes,
403439 StreamWrapper *stream) {
404440 stream->enqueueOp ([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
405441 ze_event_handle_t *waitEvents) {
406- L0_SAFE_CALL (zeCommandListAppendMemoryCopy (getRtContext (). immCmdListCopy ,
407- dst, src, sizeBytes, newEvent,
408- numWaitEvents, waitEvents));
442+ L0_SAFE_CALL (zeCommandListAppendMemoryCopy (
443+ getRtContext (). immCmdListCopy . get (), dst, src, sizeBytes, newEvent,
444+ numWaitEvents, waitEvents));
409445 });
410446}
411447
@@ -414,8 +450,8 @@ void mgpuMemset(void *dst, PATTERN_TYPE value, size_t count,
414450 StreamWrapper *stream) {
415451 auto listType =
416452 getRtContext ().copyEngineMaxMemoryFillPatternSize >= sizeof (PATTERN_TYPE)
417- ? getRtContext ().immCmdListCopy
418- : getRtContext ().immCmdListCompute ;
453+ ? getRtContext ().immCmdListCopy . get ()
454+ : getRtContext ().immCmdListCompute . get () ;
419455 stream->enqueueOp ([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
420456 ze_event_handle_t *waitEvents) {
421457 L0_SAFE_CALL (zeCommandListAppendMemoryFill (
@@ -471,7 +507,7 @@ extern "C" void mgpuLaunchKernel(ze_kernel_handle_t kernel, size_t gridX,
471507 stream->enqueueOp ([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
472508 ze_event_handle_t *waitEvents) {
473509 L0_SAFE_CALL (zeCommandListAppendLaunchKernel (
474- getRtContext ().immCmdListCompute , kernel, &dispatch, newEvent,
510+ getRtContext ().immCmdListCompute . get () , kernel, &dispatch, newEvent,
475511 numWaitEvents, waitEvents));
476512 });
477513}
@@ -484,7 +520,6 @@ extern "C" void mgpuSetDefaultDevice(int32_t devIdx) {
484520 catchAll ([&]() {
485521 // For now, a user must ensure that streams and events complete
486522 // and are destroyed before switching a device.
487- getRtContext ().cleanup ();
488523 getRtContext () = L0RtContext (devIdx);
489524 getDynamicEventPool () = DynamicEventPool (&getRtContext ());
490525 });
0 commit comments