From def5a8bac5e6b6bf1e62a109b15b2783fbc76f89 Mon Sep 17 00:00:00 2001 From: Prakhar Dixit Date: Tue, 4 Mar 2025 17:44:35 +0530 Subject: [PATCH 1/3] [mlir][vector] Add a check to ensure bailing out when reducing to a scalar, as ExtractStridedSliceOp does not support handling scalars --- mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 08ba972b12ce6..f519484fd56c8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -355,6 +355,11 @@ struct UnrollMultiReductionPattern LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { + auto resultType = reductionOp->getResult(0).getType(); + if (mlir::isa(resultType) || + mlir::isa(resultType)) { + return failure(); + } std::optional> targetShape = getTargetShape(options, reductionOp); if (!targetShape) From 4dde55a4fddaf185c831b7be944604e3459f5df9 Mon Sep 17 00:00:00 2001 From: Prakhar Dixit Date: Wed, 5 Mar 2025 10:28:26 +0530 Subject: [PATCH 2/3] Add a negative test and modify the return statement --- mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 6 +++--- mlir/test/Dialect/Vector/vector-unroll-options.mlir | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index f519484fd56c8..04c38f9f7b2e3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -356,9 +356,9 @@ struct UnrollMultiReductionPattern LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { auto resultType = reductionOp->getResult(0).getType(); - if (mlir::isa(resultType) || - mlir::isa(resultType)) { - return failure(); + if (resultType.isIntOrFloat()) { + return rewriter.notifyMatchFailure(reductionOp, + "Unrolling scalars is not supported"); } std::optional> targetShape = getTargetShape(options, reductionOp); diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index 16d30aec7c041..db96e1b66a502 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -222,6 +222,13 @@ func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> // CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> // CHECK: return %[[V2]] : vector<4xf32> +func.func @negative_vector_multi_reduction(%v: vector<4x2xf32>, %acc: f32) -> f32 { + %0 = vector.multi_reduction #vector.kind, %v, %acc [0, 1] : vector<4x2xf32> to f32 + return %0 : f32 +} +// CHECK-LABEL: func @negative_vector_multi_reduction +// CHECK: %[[R0:.*]] = vector.multi_reduction , %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32 +// CHECK: return %[[R0]] : f32 func.func @vector_reduction(%v : vector<8xf32>) -> f32 { %0 = vector.reduction , %v : vector<8xf32> into f32 From 0cf035c08caf658a20a09c864dbb7820a3a2b0dc Mon Sep 17 00:00:00 2001 From: Prakhar Dixit Date: Wed, 5 Mar 2025 14:41:07 +0530 Subject: [PATCH 3/3] modify test --- mlir/test/Dialect/Vector/vector-unroll-options.mlir | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index db96e1b66a502..9c158d05b723c 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -222,13 +222,15 @@ func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> // CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> // CHECK: return %[[V2]] : vector<4xf32> +// This is a negative test case to ensure that further unrolling is not performed. Since the vector.multi_reduction +// operation has already been unrolled, attempting additional unrolling should not be allowed. func.func @negative_vector_multi_reduction(%v: vector<4x2xf32>, %acc: f32) -> f32 { %0 = vector.multi_reduction #vector.kind, %v, %acc [0, 1] : vector<4x2xf32> to f32 return %0 : f32 } // CHECK-LABEL: func @negative_vector_multi_reduction -// CHECK: %[[R0:.*]] = vector.multi_reduction , %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32 -// CHECK: return %[[R0]] : f32 +// CHECK-NEXT: %[[R0:.*]] = vector.multi_reduction , %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32 +// CHECK-NEXT: return %[[R0]] : f32 func.func @vector_reduction(%v : vector<8xf32>) -> f32 { %0 = vector.reduction , %v : vector<8xf32> into f32