@@ -167,10 +167,46 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
167167UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp (
168168 ur_kernel_handle_t hKernel, size_t localWorkSize,
169169 size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
170- (void )hKernel;
171- (void )localWorkSize;
172- (void )dynamicSharedMemorySize;
173- *pGroupCountRet = 1 ;
170+ UR_ASSERT (hKernel, UR_RESULT_ERROR_INVALID_KERNEL);
171+
172+ // We need to set the active current device for this kernel explicitly here,
173+ // because the occupancy querying API does not take device parameter.
174+ ur_device_handle_t Device = hKernel->getProgram ()->getDevice ();
175+ ScopedContext Active (Device);
176+ try {
177+ // We need to calculate max num of work-groups using per-device semantics.
178+
179+ int MaxNumActiveGroupsPerCU{0 };
180+ UR_CHECK_ERROR (cuOccupancyMaxActiveBlocksPerMultiprocessor (
181+ &MaxNumActiveGroupsPerCU, hKernel->get (), localWorkSize,
182+ dynamicSharedMemorySize));
183+ detail::ur::assertion (MaxNumActiveGroupsPerCU >= 0 );
184+ // Handle the case where we can't have all SMs active with at least 1 group
185+ // per SM. In that case, the device is still able to run 1 work-group, hence
186+ // we will manually check if it is possible with the available HW resources.
187+ if (MaxNumActiveGroupsPerCU == 0 ) {
188+ size_t MaxWorkGroupSize{};
189+ urKernelGetGroupInfo (
190+ hKernel, Device, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE,
191+ sizeof (MaxWorkGroupSize), &MaxWorkGroupSize, nullptr );
192+ size_t MaxLocalSizeBytes{};
193+ urDeviceGetInfo (Device, UR_DEVICE_INFO_LOCAL_MEM_SIZE,
194+ sizeof (MaxLocalSizeBytes), &MaxLocalSizeBytes, nullptr );
195+ if (localWorkSize > MaxWorkGroupSize ||
196+ dynamicSharedMemorySize > MaxLocalSizeBytes ||
197+ hasExceededMaxRegistersPerBlock (Device, hKernel, localWorkSize))
198+ *pGroupCountRet = 0 ;
199+ else
200+ *pGroupCountRet = 1 ;
201+ } else {
202+ // Multiply by the number of SMs (CUs = compute units) on the device in
203+ // order to retreive the total number of groups/blocks that can be
204+ // launched.
205+ *pGroupCountRet = Device->getNumComputeUnits () * MaxNumActiveGroupsPerCU;
206+ }
207+ } catch (ur_result_t Err) {
208+ return Err;
209+ }
174210 return UR_RESULT_SUCCESS;
175211}
176212
0 commit comments