diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 08ba972b12ce6..04c38f9f7b2e3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -355,6 +355,11 @@ struct UnrollMultiReductionPattern LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { + auto resultType = reductionOp->getResult(0).getType(); + if (resultType.isIntOrFloat()) { + return rewriter.notifyMatchFailure(reductionOp, + "Unrolling scalars is not supported"); + } std::optional> targetShape = getTargetShape(options, reductionOp); if (!targetShape) diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index 16d30aec7c041..9c158d05b723c 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -222,6 +222,15 @@ func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> // CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> // CHECK: return %[[V2]] : vector<4xf32> +// This is a negative test case to ensure that further unrolling is not performed. Since the vector.multi_reduction +// operation has already been unrolled, attempting additional unrolling should not be allowed. +func.func @negative_vector_multi_reduction(%v: vector<4x2xf32>, %acc: f32) -> f32 { + %0 = vector.multi_reduction #vector.kind, %v, %acc [0, 1] : vector<4x2xf32> to f32 + return %0 : f32 +} +// CHECK-LABEL: func @negative_vector_multi_reduction +// CHECK-NEXT: %[[R0:.*]] = vector.multi_reduction , %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32 +// CHECK-NEXT: return %[[R0]] : f32 func.func @vector_reduction(%v : vector<8xf32>) -> f32 { %0 = vector.reduction , %v : vector<8xf32> into f32