@@ -174,15 +174,36 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
174174 ur_device_handle_t Device = hKernel->getProgram ()->getDevice ();
175175 ScopedContext Active (Device);
176176 try {
177+ // We need to calculate max num of work-groups using per-device semantics.
178+
177179 int MaxNumActiveGroupsPerCU{0 };
178180 UR_CHECK_ERROR (cuOccupancyMaxActiveBlocksPerMultiprocessor (
179181 &MaxNumActiveGroupsPerCU, hKernel->get (), localWorkSize,
180182 dynamicSharedMemorySize));
181183 detail::ur::assertion (MaxNumActiveGroupsPerCU >= 0 );
182-
183- // Multiply by the number of SMs (CUs = compute units) on the device in
184- // order to retreive the total number of groups/blocks that can be launched.
185- *pGroupCountRet = Device->getNumComputeUnits () * MaxNumActiveGroupsPerCU;
184+ // Handle the case where we can't have all SMs active with at least 1 group
185+ // per SM. In that case, the device is still able to run 1 work-group, hence
186+ // we will manually check if it is possible with the available HW resources.
187+ if (MaxNumActiveGroupsPerCU == 0 ) {
188+ size_t MaxWorkGroupSize{};
189+ urKernelGetGroupInfo (
190+ hKernel, Device, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE,
191+ sizeof (MaxWorkGroupSize), &MaxWorkGroupSize, nullptr );
192+ size_t MaxLocalSizeBytes{};
193+ urDeviceGetInfo (Device, UR_DEVICE_INFO_LOCAL_MEM_SIZE,
194+ sizeof (MaxLocalSizeBytes), &MaxLocalSizeBytes, nullptr );
195+ if (localWorkSize > MaxWorkGroupSize ||
196+ dynamicSharedMemorySize > MaxLocalSizeBytes ||
197+ hasExceededMaxRegistersPerBlock (Device, hKernel, localWorkSize))
198+ *pGroupCountRet = 0 ;
199+ else
200+ *pGroupCountRet = 1 ;
201+ } else {
202+ // Multiply by the number of SMs (CUs = compute units) on the device in
203+ // order to retrieve the total number of groups/blocks that can be
204+ // launched.
205+ *pGroupCountRet = Device->getNumComputeUnits () * MaxNumActiveGroupsPerCU;
206+ }
186207 } catch (ur_result_t Err) {
187208 return Err;
188209 }
0 commit comments