@@ -1054,8 +1054,9 @@ ur_result_t urKernelGetNativeHandle(
10541054}
10551055
10561056ur_result_t urKernelSuggestMaxCooperativeGroupCountExp (
1057- ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
1058- size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
1057+ ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim,
1058+ const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize,
1059+ uint32_t *pGroupCountRet) {
10591060 (void )dynamicSharedMemorySize;
10601061 std::shared_lock<ur_shared_mutex> Guard (hKernel->Mutex );
10611062
@@ -1066,8 +1067,10 @@ ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
10661067 ZE2UR_CALL (zeKernelSetGroupSize, (hKernel->ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
10671068
10681069 uint32_t TotalGroupCount = 0 ;
1070+ ze_kernel_handle_t ZeKernel;
1071+ UR_CALL (getZeKernel (hDevice->ZeDevice , hKernel, &ZeKernel));
10691072 ZE2UR_CALL (zeKernelSuggestMaxCooperativeGroupCount,
1070- (hKernel-> ZeKernel , &TotalGroupCount));
1073+ (ZeKernel, &TotalGroupCount));
10711074 *pGroupCountRet = TotalGroupCount;
10721075 return UR_RESULT_SUCCESS;
10731076}
0 commit comments