-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][Vector] Canonicalize empty vector.mask into arith.select
#140976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Vector] Canonicalize empty vector.mask into arith.select
#140976
Conversation
|
@llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) Changes: This MR adds a missing canonicalization for empty `vector.mask` ops with a passthru value. Full diff: https://github.com/llvm/llvm-project/pull/140976.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..5e8421ed67d66 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2559,6 +2559,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
Location loc);
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..389950af821a2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6661,6 +6661,7 @@ LogicalResult MaskOp::verify() {
///
/// %0 = user_op %a : vector<8xf32>
///
+/// Empty `vector.mask` ops with a passthru operand are handled by the canonicalizer.
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
if (!maskOp.isEmpty() || maskOp.hasPassthru())
@@ -6696,6 +6697,47 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
return success();
}
+// Canonicalize empty `vector.mask` operations that can't be handled in
+// `MaskOp::fold`.
+//
+// Example 1: Empty `vector.mask` with passthru operand.
+//
+// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
+// vector<8xi1> -> vector<8xf32>
+//
+// becomes:
+//
+// %0 = arith.select %mask, %a, %passthru : vector<8xi1>, vector<8xf32>
+//
+class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(MaskOp maskOp,
+ PatternRewriter &rewriter) const override {
+ if (!maskOp.isEmpty())
+ return failure();
+
+ if (!maskOp.hasPassthru())
+ return failure();
+
+ Block *block = maskOp.getMaskBlock();
+ auto terminator = cast<vector::YieldOp>(block->front());
+ assert(terminator.getNumOperands() == 1 &&
+ "expected one result when passthru is provided");
+
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(
+ maskOp, maskOp.getResultTypes(), maskOp.getMask(),
+ terminator.getOperand(0), maskOp.getPassthru());
+
+ return success();
+ }
+};
+
+void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<CanonializeEmptyMaskOp>(context);
+}
+
// MaskingOpInterface definitions.
/// Returns the operation masked by this 'vector.mask'.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 974f4506a2ef0..a6543aafd1c77 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -719,7 +719,7 @@ func.func @fold_extract_transpose(
// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
+func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
%idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -731,7 +731,7 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
// CHECK: return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
+func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -744,7 +744,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
// CHECK-SAME: %[[A:.*]]: vector<f32>
// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
// CHECK: return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
+func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -780,7 +780,7 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
// CHECK: return %[[R]] : f32
-func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
%idx : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -795,7 +795,7 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
// CHECK: return %[[B]] : vector<4xf32>
// rank(extract_output) < rank(broadcast_input)
-func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
+func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -808,7 +808,7 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
// rank(extract_output) > rank(broadcast_input)
-func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
-> vector<4xf32> {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -822,7 +822,7 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
// CHECK: return %[[R]] : vector<8xf32>
// rank(extract_output) == rank(broadcast_input)
-func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
+func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
-> vector<8xf32> {
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
%r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
@@ -1169,7 +1169,7 @@ func.func @broadcast_poison() -> vector<4x6xi8> {
return %broadcast : vector<4x6xi8>
}
-// -----
+// -----
// CHECK-LABEL: broadcast_splat_constant
// CHECK: %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
@@ -2756,6 +2756,19 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
// -----
+// CHECK-LABEL: func @empty_vector_mask_with_passthru
+// CHECK-SAME: %[[IN:.*]]: vector<8xf32>, %[[MASK:.*]]: vector<8xi1>, %[[PASSTHRU:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_passthru(%a : vector<8xf32>, %mask : vector<8xi1>,
+ %passthru : vector<8xf32>) -> vector<8xf32> {
+// CHECK-NOT: vector.mask
+// CHECK: %[[SEL:.*]] = arith.select %[[MASK]], %[[IN]], %[[PASSTHRU]] : vector<8xi1>, vector<8xf32>
+// CHECK: return %[[SEL]] : vector<8xf32>
+ %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @all_true_vector_mask
// CHECK-SAME: %[[IN:.*]]: tensor<3x4xf32>
func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {
|
|
@llvm/pr-subscribers-mlir-vector Author: Diego Caballero (dcaballe) Changes: This MR adds a missing canonicalization for empty `vector.mask` ops with a passthru value. Full diff: https://github.com/llvm/llvm-project/pull/140976.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..5e8421ed67d66 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2559,6 +2559,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
Location loc);
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..389950af821a2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6661,6 +6661,7 @@ LogicalResult MaskOp::verify() {
///
/// %0 = user_op %a : vector<8xf32>
///
+/// Empty `vector.mask` ops with a passthru operand are handled by the canonicalizer.
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
if (!maskOp.isEmpty() || maskOp.hasPassthru())
@@ -6696,6 +6697,47 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
return success();
}
+// Canonicalize empty `vector.mask` operations that can't be handled in
+// `MaskOp::fold`.
+//
+// Example 1: Empty `vector.mask` with passthru operand.
+//
+// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
+// vector<8xi1> -> vector<8xf32>
+//
+// becomes:
+//
+// %0 = arith.select %mask, %a, %passthru : vector<8xi1>, vector<8xf32>
+//
+class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(MaskOp maskOp,
+ PatternRewriter &rewriter) const override {
+ if (!maskOp.isEmpty())
+ return failure();
+
+ if (!maskOp.hasPassthru())
+ return failure();
+
+ Block *block = maskOp.getMaskBlock();
+ auto terminator = cast<vector::YieldOp>(block->front());
+ assert(terminator.getNumOperands() == 1 &&
+ "expected one result when passthru is provided");
+
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(
+ maskOp, maskOp.getResultTypes(), maskOp.getMask(),
+ terminator.getOperand(0), maskOp.getPassthru());
+
+ return success();
+ }
+};
+
+void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<CanonializeEmptyMaskOp>(context);
+}
+
// MaskingOpInterface definitions.
/// Returns the operation masked by this 'vector.mask'.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 974f4506a2ef0..a6543aafd1c77 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -719,7 +719,7 @@ func.func @fold_extract_transpose(
// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
+func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
%idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -731,7 +731,7 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
// CHECK: return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
+func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -744,7 +744,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
// CHECK-SAME: %[[A:.*]]: vector<f32>
// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
// CHECK: return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
+func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -780,7 +780,7 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
// CHECK: return %[[R]] : f32
-func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
%idx : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -795,7 +795,7 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
// CHECK: return %[[B]] : vector<4xf32>
// rank(extract_output) < rank(broadcast_input)
-func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
+func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -808,7 +808,7 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
// rank(extract_output) > rank(broadcast_input)
-func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
-> vector<4xf32> {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -822,7 +822,7 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
// CHECK: return %[[R]] : vector<8xf32>
// rank(extract_output) == rank(broadcast_input)
-func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
+func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
-> vector<8xf32> {
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
%r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
@@ -1169,7 +1169,7 @@ func.func @broadcast_poison() -> vector<4x6xi8> {
return %broadcast : vector<4x6xi8>
}
-// -----
+// -----
// CHECK-LABEL: broadcast_splat_constant
// CHECK: %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
@@ -2756,6 +2756,19 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
// -----
+// CHECK-LABEL: func @empty_vector_mask_with_passthru
+// CHECK-SAME: %[[IN:.*]]: vector<8xf32>, %[[MASK:.*]]: vector<8xi1>, %[[PASSTHRU:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_passthru(%a : vector<8xf32>, %mask : vector<8xi1>,
+ %passthru : vector<8xf32>) -> vector<8xf32> {
+// CHECK-NOT: vector.mask
+// CHECK: %[[SEL:.*]] = arith.select %[[MASK]], %[[IN]], %[[PASSTHRU]] : vector<8xi1>, vector<8xf32>
+// CHECK: return %[[SEL]] : vector<8xf32>
+ %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @all_true_vector_mask
// CHECK-SAME: %[[IN:.*]]: tensor<3x4xf32>
func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {
|
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % kind request to expand the documentation :)
Btw, did you mean arith.select in the PR title?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Could you document why this can't be handled by the folder?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you document _why_?
vector.mask into vector.select → vector.mask into arith.select
This MR adds a missing canonicalization for empty `vector.mask` ops with
a passthru value.
```
%0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
vector<8xi1> -> vector<8xf32>
becomes:
%0 = arith.select %mask, %a, %passthru : vector<8xi1>, vector<8xf32>
```
9a21ad5 to
fcf2152
Compare
This MR adds a missing canonicalization for empty
`vector.mask` ops with a passthru value.