@@ -271,13 +271,264 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
271271}
272272
273273UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp (
274- ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
275- const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
276- const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
277- const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
278- return urEnqueueKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
279- pGlobalWorkSize, pLocalWorkSize,
280- numEventsInWaitList, phEventWaitList, phEvent);
274+ ur_queue_handle_t Queue, // /< [in] handle of the queue object
275+ ur_kernel_handle_t Kernel, // /< [in] handle of the kernel object
276+ uint32_t WorkDim, // /< [in] number of dimensions, from 1 to 3, to specify
277+ // /< the global and work-group work-items
278+ const size_t
279+ *GlobalWorkOffset, // /< [in] pointer to an array of workDim unsigned
280+ // /< values that specify the offset used to
281+ // /< calculate the global ID of a work-item
282+ const size_t *GlobalWorkSize, // /< [in] pointer to an array of workDim
283+ // /< unsigned values that specify the number
284+ // /< of global work-items in workDim that
285+ // /< will execute the kernel function
286+ const size_t
287+ *LocalWorkSize, // /< [in][optional] pointer to an array of workDim
288+ // /< unsigned values that specify the number of local
289+ // /< work-items forming a work-group that will execute
290+ // /< the kernel function. If nullptr, the runtime
291+ // /< implementation will choose the work-group size.
292+ uint32_t NumEventsInWaitList, // /< [in] size of the event wait list
293+ const ur_event_handle_t
294+ *EventWaitList, // /< [in][optional][range(0, numEventsInWaitList)]
295+ // /< pointer to a list of events that must be complete
296+ // /< before the kernel execution. If nullptr, the
297+ // /< numEventsInWaitList must be 0, indicating that no
298+ // /< wait event.
299+ ur_event_handle_t
300+ *OutEvent // /< [in,out][optional] return an event object that identifies
301+ // /< this particular kernel execution instance.
302+ ) {
303+ auto ZeDevice = Queue->Device ->ZeDevice ;
304+
305+ ze_kernel_handle_t ZeKernel{};
306+ if (Kernel->ZeKernelMap .empty ()) {
307+ ZeKernel = Kernel->ZeKernel ;
308+ } else {
309+ auto It = Kernel->ZeKernelMap .find (ZeDevice);
310+ if (It == Kernel->ZeKernelMap .end ()) {
311+ /* kernel and queue don't match */
312+ return UR_RESULT_ERROR_INVALID_QUEUE;
313+ }
314+ ZeKernel = It->second ;
315+ }
316+ // Lock automatically releases when this goes out of scope.
317+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
318+ Queue->Mutex , Kernel->Mutex , Kernel->Program ->Mutex );
319+ if (GlobalWorkOffset != NULL ) {
320+ if (!Queue->Device ->Platform ->ZeDriverGlobalOffsetExtensionFound ) {
321+ logger::error (" No global offset extension found on this driver" );
322+ return UR_RESULT_ERROR_INVALID_VALUE;
323+ }
324+
325+ ZE2UR_CALL (zeKernelSetGlobalOffsetExp,
326+ (ZeKernel, GlobalWorkOffset[0 ], GlobalWorkOffset[1 ],
327+ GlobalWorkOffset[2 ]));
328+ }
329+
330+ // If there are any pending arguments set them now.
331+ for (auto &Arg : Kernel->PendingArguments ) {
332+ // The ArgValue may be a NULL pointer in which case a NULL value is used for
333+ // the kernel argument declared as a pointer to global or constant memory.
334+ char **ZeHandlePtr = nullptr ;
335+ if (Arg.Value ) {
336+ UR_CALL (Arg.Value ->getZeHandlePtr (ZeHandlePtr, Arg.AccessMode ,
337+ Queue->Device ));
338+ }
339+ ZE2UR_CALL (zeKernelSetArgumentValue,
340+ (ZeKernel, Arg.Index , Arg.Size , ZeHandlePtr));
341+ }
342+ Kernel->PendingArguments .clear ();
343+
344+ ze_group_count_t ZeThreadGroupDimensions{1 , 1 , 1 };
345+ uint32_t WG[3 ]{};
346+
347+ // New variable needed because GlobalWorkSize parameter might not be of size 3
348+ size_t GlobalWorkSize3D[3 ]{1 , 1 , 1 };
349+ std::copy (GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);
350+
351+ if (LocalWorkSize) {
352+ // L0
353+ UR_ASSERT (LocalWorkSize[0 ] < (std::numeric_limits<uint32_t >::max)(),
354+ UR_RESULT_ERROR_INVALID_VALUE);
355+ UR_ASSERT (LocalWorkSize[1 ] < (std::numeric_limits<uint32_t >::max)(),
356+ UR_RESULT_ERROR_INVALID_VALUE);
357+ UR_ASSERT (LocalWorkSize[2 ] < (std::numeric_limits<uint32_t >::max)(),
358+ UR_RESULT_ERROR_INVALID_VALUE);
359+ WG[0 ] = static_cast <uint32_t >(LocalWorkSize[0 ]);
360+ WG[1 ] = static_cast <uint32_t >(LocalWorkSize[1 ]);
361+ WG[2 ] = static_cast <uint32_t >(LocalWorkSize[2 ]);
362+ } else {
363+ // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize
364+ // values do not fit to 32-bit that the API only supports currently.
365+ bool SuggestGroupSize = true ;
366+ for (int I : {0 , 1 , 2 }) {
367+ if (GlobalWorkSize3D[I] > UINT32_MAX) {
368+ SuggestGroupSize = false ;
369+ }
370+ }
371+ if (SuggestGroupSize) {
372+ ZE2UR_CALL (zeKernelSuggestGroupSize,
373+ (ZeKernel, GlobalWorkSize3D[0 ], GlobalWorkSize3D[1 ],
374+ GlobalWorkSize3D[2 ], &WG[0 ], &WG[1 ], &WG[2 ]));
375+ } else {
376+ for (int I : {0 , 1 , 2 }) {
377+ // Try to find a I-dimension WG size that the GlobalWorkSize[I] is
378+ // fully divisable with. Start with the max possible size in
379+ // each dimension.
380+ uint32_t GroupSize[] = {
381+ Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeX ,
382+ Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeY ,
383+ Queue->Device ->ZeDeviceComputeProperties ->maxGroupSizeZ };
384+ GroupSize[I] = (std::min)(size_t (GroupSize[I]), GlobalWorkSize3D[I]);
385+ while (GlobalWorkSize3D[I] % GroupSize[I]) {
386+ --GroupSize[I];
387+ }
388+
389+ if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) {
390+ logger::error (
391+ " urEnqueueCooperativeKernelLaunchExp: can't find a WG size "
392+ " suitable for global work size > UINT32_MAX" );
393+ return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
394+ }
395+ WG[I] = GroupSize[I];
396+ }
397+ logger::debug (" urEnqueueCooperativeKernelLaunchExp: using computed WG "
398+ " size = {{{}, {}, {}}}" ,
399+ WG[0 ], WG[1 ], WG[2 ]);
400+ }
401+ }
402+
403+ // TODO: assert if sizes do not fit into 32-bit?
404+
405+ switch (WorkDim) {
406+ case 3 :
407+ ZeThreadGroupDimensions.groupCountX =
408+ static_cast <uint32_t >(GlobalWorkSize3D[0 ] / WG[0 ]);
409+ ZeThreadGroupDimensions.groupCountY =
410+ static_cast <uint32_t >(GlobalWorkSize3D[1 ] / WG[1 ]);
411+ ZeThreadGroupDimensions.groupCountZ =
412+ static_cast <uint32_t >(GlobalWorkSize3D[2 ] / WG[2 ]);
413+ break ;
414+ case 2 :
415+ ZeThreadGroupDimensions.groupCountX =
416+ static_cast <uint32_t >(GlobalWorkSize3D[0 ] / WG[0 ]);
417+ ZeThreadGroupDimensions.groupCountY =
418+ static_cast <uint32_t >(GlobalWorkSize3D[1 ] / WG[1 ]);
419+ WG[2 ] = 1 ;
420+ break ;
421+ case 1 :
422+ ZeThreadGroupDimensions.groupCountX =
423+ static_cast <uint32_t >(GlobalWorkSize3D[0 ] / WG[0 ]);
424+ WG[1 ] = WG[2 ] = 1 ;
425+ break ;
426+
427+ default :
428+ logger::error (" urEnqueueCooperativeKernelLaunchExp: unsupported work_dim" );
429+ return UR_RESULT_ERROR_INVALID_VALUE;
430+ }
431+
432+ // Error handling for non-uniform group size case
433+ if (GlobalWorkSize3D[0 ] !=
434+ size_t (ZeThreadGroupDimensions.groupCountX ) * WG[0 ]) {
435+ logger::error (" urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
436+ " range is not a "
437+ " multiple of the group size in the 1st dimension" );
438+ return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
439+ }
440+ if (GlobalWorkSize3D[1 ] !=
441+ size_t (ZeThreadGroupDimensions.groupCountY ) * WG[1 ]) {
442+ logger::error (" urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
443+ " range is not a "
444+ " multiple of the group size in the 2nd dimension" );
445+ return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
446+ }
447+ if (GlobalWorkSize3D[2 ] !=
448+ size_t (ZeThreadGroupDimensions.groupCountZ ) * WG[2 ]) {
449+ logger::debug (" urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
450+ " range is not a "
451+ " multiple of the group size in the 3rd dimension" );
452+ return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
453+ }
454+
455+ ZE2UR_CALL (zeKernelSetGroupSize, (ZeKernel, WG[0 ], WG[1 ], WG[2 ]));
456+
457+ bool UseCopyEngine = false ;
458+ _ur_ze_event_list_t TmpWaitList;
459+ UR_CALL (TmpWaitList.createAndRetainUrZeEventList (
460+ NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));
461+
462+ // Get a new command list to be used on this call
463+ ur_command_list_ptr_t CommandList{};
464+ UR_CALL (Queue->Context ->getAvailableCommandList (
465+ Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList,
466+ true /* AllowBatching */ ));
467+
468+ ze_event_handle_t ZeEvent = nullptr ;
469+ ur_event_handle_t InternalEvent{};
470+ bool IsInternal = OutEvent == nullptr ;
471+ ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
472+
473+ UR_CALL (createEventAndAssociateQueue (Queue, Event, UR_COMMAND_KERNEL_LAUNCH,
474+ CommandList, IsInternal, false ));
475+ UR_CALL (setSignalEvent (Queue, UseCopyEngine, &ZeEvent, Event,
476+ NumEventsInWaitList, EventWaitList,
477+ CommandList->second .ZeQueue ));
478+ (*Event)->WaitList = TmpWaitList;
479+
480+ // Save the kernel in the event, so that when the event is signalled
481+ // the code can do a urKernelRelease on this kernel.
482+ (*Event)->CommandData = (void *)Kernel;
483+
484+ // Increment the reference count of the Kernel and indicate that the Kernel
485+ // is in use. Once the event has been signalled, the code in
486+ // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
487+ // reference count on the kernel, using the kernel saved in CommandData.
488+ UR_CALL (urKernelRetain (Kernel));
489+
490+ // Add to list of kernels to be submitted
491+ if (IndirectAccessTrackingEnabled)
492+ Queue->KernelsToBeSubmitted .push_back (Kernel);
493+
494+ if (Queue->UsingImmCmdLists && IndirectAccessTrackingEnabled) {
495+ // If using immediate commandlists then gathering of indirect
496+ // references and appending to the queue (which means submission)
497+ // must be done together.
498+ std::unique_lock<ur_shared_mutex> ContextsLock (
499+ Queue->Device ->Platform ->ContextsMutex , std::defer_lock);
500+ // We are going to submit kernels for execution. If indirect access flag is
501+ // set for a kernel then we need to make a snapshot of existing memory
502+ // allocations in all contexts in the platform. We need to lock the mutex
503+ // guarding the list of contexts in the platform to prevent creation of new
504+ // memory alocations in any context before we submit the kernel for
505+ // execution.
506+ ContextsLock.lock ();
507+ Queue->CaptureIndirectAccesses ();
508+ // Add the command to the command list, which implies submission.
509+ ZE2UR_CALL (zeCommandListAppendLaunchCooperativeKernel,
510+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
511+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
512+ } else {
513+ // Add the command to the command list for later submission.
514+ // No lock is needed here, unlike the immediate commandlist case above,
515+ // because the kernels are not actually submitted yet. Kernels will be
516+ // submitted only when the comamndlist is closed. Then, a lock is held.
517+ ZE2UR_CALL (zeCommandListAppendLaunchCooperativeKernel,
518+ (CommandList->first , ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
519+ (*Event)->WaitList .Length , (*Event)->WaitList .ZeEventList ));
520+ }
521+
522+ logger::debug (" calling zeCommandListAppendLaunchCooperativeKernel() with"
523+ " ZeEvent {}" ,
524+ ur_cast<std::uintptr_t >(ZeEvent));
525+ printZeEventList ((*Event)->WaitList );
526+
527+ // Execute command list asynchronously, as the event will be used
528+ // to track down its completion.
529+ UR_CALL (Queue->executeCommandList (CommandList, false , true ));
530+
531+ return UR_RESULT_SUCCESS;
281532}
282533
283534UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite (
@@ -829,10 +1080,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
8291080UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp (
8301081 ur_kernel_handle_t hKernel, size_t localWorkSize,
8311082 size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
832- (void )hKernel;
8331083 (void )localWorkSize;
8341084 (void )dynamicSharedMemorySize;
835- *pGroupCountRet = 1 ;
1085+ std::shared_lock<ur_shared_mutex> Guard (hKernel->Mutex );
1086+ uint32_t TotalGroupCount = 0 ;
1087+ ZE2UR_CALL (zeKernelSuggestMaxCooperativeGroupCount,
1088+ (hKernel->ZeKernel , &TotalGroupCount));
1089+ *pGroupCountRet = TotalGroupCount;
8361090 return UR_RESULT_SUCCESS;
8371091}
8381092
0 commit comments