diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 0a3ef7d5c9890..9c101ee0ed261 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -773,6 +773,10 @@ static DiagnosedSilenceableFailure checkMappingSpec( std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp, ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes, int factor, bool useLinearMapping = false) { + if (llvm::any_of(blockOrGridSizes, [](int64_t i) { return i <= 0; })) { + return definiteFailureHelper(transformOp, forallOp, + "block/grid sizes must be strictly positive"); + } if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) { auto diag = definiteFailureHelper( transformOp, forallOp, @@ -780,15 +784,19 @@ static DiagnosedSilenceableFailure checkMappingSpec( Twine(factor)); return diag; } - if (computeProduct(numParallelIterations) * factor > - computeProduct(blockOrGridSizes)) { + bool hasZeroParallelIteration = + llvm::any_of(numParallelIterations, [](int64_t i) { return i == 0; }); + int64_t required = hasZeroParallelIteration + ? 
0 + : computeProduct(numParallelIterations) * factor; + int64_t available = computeProduct(blockOrGridSizes); + if (required > available) { auto diag = definiteFailureHelper( transformOp, forallOp, Twine("the number of required parallel resources (blocks or " "threads) ") + - Twine(computeProduct(numParallelIterations) * factor) + - " overflows the number of available resources " + - Twine(computeProduct(blockOrGridSizes))); + Twine(required) + " overflows the number of available resources " + + Twine(available)); return diag; } return DiagnosedSilenceableFailure::success(); diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 8d3944f883963..7070acc9db28f 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -141,12 +141,14 @@ std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) { std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) { bool failed = false; - SmallVector<int64_t> res = llvm::map_to_vector(ofrs, [&](OpFoldResult ofr) { + SmallVector<int64_t> res; + res.reserve(ofrs.size()); + for (OpFoldResult ofr : ofrs) { auto cv = getConstantIntValue(ofr); if (!cv.has_value()) failed = true; - return cv.value_or(0); - }); + res.push_back(cv.value_or(0)); + } if (failed) return std::nullopt; return res; diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir index bc052a0230a8e..98f20f0d50765 100644 --- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir @@ -512,3 +512,33 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func @map_nested_forall_to_threads_invalid_block_dims(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> { + %one = arith.constant 1 : index + %c900 = arith.constant 900 : index + %c7 = arith.constant 7 : index + %name 
= gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one) + threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) + { + scf.forall (%i, %j) in (%c7, %c900) { + %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32> + %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> + %6 = math.fma %alpha, %4, %5 : f32 + memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> + } { mapping = [#gpu.thread<y>, #gpu.thread<x>] } + gpu.terminator + } + return %y : memref<2 x 32 x f32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{block/grid sizes must be strictly positive}} + transform.gpu.map_nested_forall_to_threads %funcop block_dims = [128, 0, 1] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} +