diff --git a/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp b/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp index a8c75e41e44d..fb84f00a3734 100644 --- a/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp +++ b/unified-runtime/source/adapters/level_zero/helpers/kernel_helpers.cpp @@ -96,36 +96,37 @@ ur_result_t calculateKernelWorkDimensions( // New variable needed because GlobalWorkSize parameter might not be of size // 3 size_t GlobalWorkSize3D[3]{1, 1, 1}; - std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D); if (LocalWorkSize) { WG[0] = ur_cast(LocalWorkSize[0]); WG[1] = WorkDim >= 2 ? ur_cast(LocalWorkSize[1]) : 1; WG[2] = WorkDim == 3 ? ur_cast(LocalWorkSize[2]) : 1; } else { + std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D); UR_CALL(getSuggestedLocalWorkSize(Device, Kernel, GlobalWorkSize3D, WG)); } + const size_t *GlobalWorkSizePtr = LocalWorkSize ? GlobalWorkSize : GlobalWorkSize3D; // TODO: assert if sizes do not fit into 32-bit? switch (WorkDim) { case 3: ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize3D[0] / WG[0]); + ur_cast(GlobalWorkSizePtr[0] / WG[0]); ZeThreadGroupDimensions.groupCountY = - ur_cast(GlobalWorkSize3D[1] / WG[1]); + ur_cast(GlobalWorkSizePtr[1] / WG[1]); ZeThreadGroupDimensions.groupCountZ = - ur_cast(GlobalWorkSize3D[2] / WG[2]); + ur_cast(GlobalWorkSizePtr[2] / WG[2]); break; case 2: ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize3D[0] / WG[0]); + ur_cast(GlobalWorkSizePtr[0] / WG[0]); ZeThreadGroupDimensions.groupCountY = - ur_cast(GlobalWorkSize3D[1] / WG[1]); + ur_cast(GlobalWorkSizePtr[1] / WG[1]); WG[2] = 1; break; case 1: ZeThreadGroupDimensions.groupCountX = - ur_cast(GlobalWorkSize3D[0] / WG[0]); + ur_cast(GlobalWorkSizePtr[0] / WG[0]); WG[1] = WG[2] = 1; break; @@ -135,19 +136,19 @@ ur_result_t calculateKernelWorkDimensions( } // Error handling for non-uniform group size case - if (GlobalWorkSize3D[0] != + if (GlobalWorkSizePtr[0] != size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) { UR_LOG(ERR, "calculateKernelWorkDimensions: invalid work_dim. The range " "is not a multiple of the group size in the 1st dimension"); return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; } - if (GlobalWorkSize3D[1] != + if (WorkDim >= 2 && GlobalWorkSizePtr[1] != size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) { UR_LOG(ERR, "calculateKernelWorkDimensions: invalid work_dim. The range " "is not a multiple of the group size in the 2nd dimension"); return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; } - if (GlobalWorkSize3D[2] != + if (WorkDim == 3 && GlobalWorkSizePtr[2] != size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) { UR_LOG(ERR, "calculateKernelWorkDimensions: invalid work_dim. The range " "is not a multiple of the group size in the 3rd dimension");