Conversation

@dcaballe
Contributor

This PR adds a missing canonicalization for empty vector.mask ops with a passthru value.

   %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
     vector<8xi1> -> vector<8xf32>

 becomes:

   %0 = arith.select %mask, %a, %passthru : vector<8xf32>
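
For anyone wanting to try the rewrite locally, here is a minimal, self-contained test sketch. It is not taken from the patch: the function name and RUN line are illustrative, and the in-tree coverage added by this PR lives in mlir/test/Dialect/Vector/canonicalize.mlir.

```mlir
// RUN: mlir-opt %s -canonicalize | FileCheck %s

// Illustrative standalone test; mirrors the case this PR adds to canonicalize.mlir.
// CHECK-LABEL: func @empty_mask_with_passthru
//   CHECK-NOT:   vector.mask
//       CHECK:   %[[SEL:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<8xi1>, vector<8xf32>
//       CHECK:   return %[[SEL]] : vector<8xf32>
func.func @empty_mask_with_passthru(%a: vector<8xf32>, %mask: vector<8xi1>,
                                    %passthru: vector<8xf32>) -> vector<8xf32> {
  // The masked region only yields %a, so the whole op reduces to a select
  // between %a and %passthru under %mask.
  %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
  return %0 : vector<8xf32>
}
```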

@llvmbot
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR adds a missing canonicalization for empty vector.mask ops with a passthru value.

   %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
     vector<8xi1> -> vector<8xf32>

 becomes:

   %0 = arith.select %mask, %a, %passthru : vector<8xf32>

Full diff: https://github.com/llvm/llvm-project/pull/140976.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+42)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+21-8)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..5e8421ed67d66 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2559,6 +2559,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
                                  Location loc);
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..389950af821a2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6661,6 +6661,7 @@ LogicalResult MaskOp::verify() {
 ///
 ///   %0 = user_op %a : vector<8xf32>
 ///
+/// Empty `vector.mask` ops with a passthru operand are handled by the canonicalizer.
 static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
                                      SmallVectorImpl<OpFoldResult> &results) {
   if (!maskOp.isEmpty() || maskOp.hasPassthru())
@@ -6696,6 +6697,47 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
   return success();
 }
 
+// Canonicalize empty `vector.mask` operations that can't be handled in
+// `MaskOp::fold`.
+//
+// Example 1: Empty `vector.mask` with passthru operand.
+//
+//   %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
+//     vector<8xi1> -> vector<8xf32>
+//
+// becomes:
+//
+//   %0 = arith.select %mask, %a, %passthru : vector<8xf32>
+//
+class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MaskOp maskOp,
+                                PatternRewriter &rewriter) const override {
+    if (!maskOp.isEmpty())
+      return failure();
+
+    if (!maskOp.hasPassthru())
+      return failure();
+
+    Block *block = maskOp.getMaskBlock();
+    auto terminator = cast<vector::YieldOp>(block->front());
+    assert(terminator.getNumOperands() == 1 &&
+           "expected one result when passthru is provided");
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(
+        maskOp, maskOp.getResultTypes(), maskOp.getMask(),
+        terminator.getOperand(0), maskOp.getPassthru());
+
+    return success();
+  }
+};
+
+void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<CanonializeEmptyMaskOp>(context);
+}
+
 // MaskingOpInterface definitions.
 
 /// Returns the operation masked by this 'vector.mask'.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 974f4506a2ef0..a6543aafd1c77 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -719,7 +719,7 @@ func.func @fold_extract_transpose(
 // CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32, 
+func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
   %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -731,7 +731,7 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
 // CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
 //       CHECK:   return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>, 
+func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
   %idx0 : index, %idx1 : index) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -744,7 +744,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
 //  CHECK-SAME:   %[[A:.*]]: vector<f32>
 //       CHECK:   %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
 //       CHECK:   return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>, 
+func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
   %idx0 : index, %idx1 : index, %idx2: index) -> f32 {
   %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -780,7 +780,7 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
 //  CHECK-SAME:   %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
 //       CHECK:   %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
 //       CHECK:   return %[[R]] : f32
-func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, 
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
   %idx : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
   %r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -795,7 +795,7 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
 //       CHECK:   %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
 // rank(extract_output) < rank(broadcast_input)
-func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>, 
+func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
   %idx0 : index, %idx1 : index) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -808,7 +808,7 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
 // rank(extract_output) > rank(broadcast_input)
-func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index) 
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
   -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -822,7 +822,7 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1
 //       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
 //       CHECK:   return %[[R]] : vector<8xf32>
 // rank(extract_output) == rank(broadcast_input)
-func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index) 
+func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
   -> vector<8xf32> {
   %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
   %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
@@ -1169,7 +1169,7 @@ func.func @broadcast_poison() -> vector<4x6xi8> {
   return %broadcast : vector<4x6xi8>
 }
 
-// ----- 
+// -----
 
 // CHECK-LABEL:  broadcast_splat_constant
 //       CHECK:  %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
@@ -2756,6 +2756,19 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
 
 // -----
 
+// CHECK-LABEL: func @empty_vector_mask_with_passthru
+//  CHECK-SAME:     %[[IN:.*]]: vector<8xf32>, %[[MASK:.*]]: vector<8xi1>, %[[PASSTHRU:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_passthru(%a : vector<8xf32>, %mask : vector<8xi1>,
+                                           %passthru : vector<8xf32>) -> vector<8xf32> {
+//   CHECK-NOT:   vector.mask
+//       CHECK:   %[[SEL:.*]] = arith.select %[[MASK]], %[[IN]], %[[PASSTHRU]] : vector<8xi1>, vector<8xf32>
+//       CHECK:   return %[[SEL]] : vector<8xf32>
+  %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @all_true_vector_mask
 //  CHECK-SAME:     %[[IN:.*]]: tensor<3x4xf32>
 func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {
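
For context, the new rewrite is also reachable programmatically through the op's canonicalization hook, not only via the -canonicalize pass. The following is a hypothetical C++ sketch, not part of this patch; it assumes MLIR's greedy pattern rewrite driver, whose entry-point name has changed across releases.

```c++
// Hypothetical helper (not from the patch): collect vector::MaskOp's
// canonicalization patterns -- including the new empty-mask-with-passthru
// rewrite -- and apply them greedily to a given operation.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult canonicalizeVectorMaskOps(mlir::Operation *root) {
  mlir::MLIRContext *ctx = root->getContext();
  mlir::RewritePatternSet patterns(ctx);
  // Available because the patch sets `hasCanonicalizer = 1` on Vector_MaskOp.
  mlir::vector::MaskOp::getCanonicalizationPatterns(patterns, ctx);
  // Greedy driver entry point; older MLIR versions spell this
  // applyPatternsAndFoldGreedily.
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}
```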

@llvmbot
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir-vector

@banach-space (Contributor) left a comment:

LGTM % kind request to expand the documentation :)

Btw, did you mean arith.select in the PR title?

Comment on lines 6700 to 6701 (Contributor):

[nit] Could you document why this can't be handled by the folder?

Contributor:

Could you document _why_?

@dcaballe changed the title from "[mlir][Vector] Canonicalize empty vector.mask into vector.select" to "[mlir][Vector] Canonicalize empty vector.mask into arith.select" on May 22, 2025
dcaballe added 2 commits May 22, 2025 16:01
This PR adds a missing canonicalization for empty `vector.mask` ops with
a passthru value.

```
   %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
     vector<8xi1> -> vector<8xf32>

 becomes:

   %0 = arith.select %mask, %a, %passthru : vector<8xf32>
```
@dcaballe force-pushed the dcaballero/fold-empty-vector-mask-passthru branch from 9a21ad5 to fcf2152 on May 22, 2025 16:06
@dcaballe merged commit 204eb70 into llvm:main on May 23, 2025
9 of 11 checks passed