@@ -649,4 +649,24 @@ ur_result_t urKernelGetSuggestedLocalWorkSize(
649649 std::copy (localWorkSize, localWorkSize + workDim, pSuggestedLocalWorkSize);
650650 return UR_RESULT_SUCCESS;
651651}
652+
653+ ur_result_t urKernelSuggestMaxCooperativeGroupCountExp (
654+ ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim,
655+ const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize,
656+ uint32_t *pGroupCountRet) {
657+ (void )dynamicSharedMemorySize;
658+
659+ uint32_t wg[3 ];
660+ wg[0 ] = ur_cast<uint32_t >(pLocalWorkSize[0 ]);
661+ wg[1 ] = workDim >= 2 ? ur_cast<uint32_t >(pLocalWorkSize[1 ]) : 1 ;
662+ wg[2 ] = workDim == 3 ? ur_cast<uint32_t >(pLocalWorkSize[2 ]) : 1 ;
663+ ZE2UR_CALL (zeKernelSetGroupSize,
664+ (hKernel->getZeHandle (hDevice), wg[0 ], wg[1 ], wg[2 ]));
665+
666+ uint32_t totalGroupCount = 0 ;
667+ ZE2UR_CALL (zeKernelSuggestMaxCooperativeGroupCount,
668+ (hKernel->getZeHandle (hDevice), &totalGroupCount));
669+ *pGroupCountRet = totalGroupCount;
670+ return UR_RESULT_SUCCESS;
671+ }
652672} // namespace ur::level_zero
0 commit comments