Skip to content

Conversation

@dcaballe
Copy link
Contributor

This MR moves the canonicalization that elides empty vector.mask ops to folders.

@llvmbot
Copy link
Member

llvmbot commented May 17, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-gpu

Author: Diego Caballero (dcaballe)

Changes

This MR moves the canonicalization that elides empty vector.mask ops to folders.


Full diff: https://github.com/llvm/llvm-project/pull/140324.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+33-35)
  • (modified) mlir/test/Conversion/GPUCommon/lower-vector.mlir (+4-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3aefcea8de994..2e0c9a6de11ae 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2554,7 +2554,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
                                  Location loc);
   }];
 
-  let hasCanonicalizer = 1;
+  let hasCanonicalizer = 0;
   let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79bf87ccd34af..104459850d508 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6631,13 +6631,42 @@ LogicalResult MaskOp::verify() {
   return success();
 }
 
-/// Folds vector.mask ops with an all-true mask.
+/// Folds empty `vector.mask` with no passthru operand and with or without
+/// return values. For example:
+///
+///   %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
+///     vector<8xi1> -> vector<8xf32>
+///   %1 = user_op %0 : vector<8xf32>
+///
+/// becomes:
+///
+///   %0 = user_op %a : vector<8xf32>
+///
+/// `vector.mask` with a passthru is handled by the canonicalizer.
+///
+static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
+                                     SmallVectorImpl<OpFoldResult> &results) {
+  if (!maskOp.isEmpty() || maskOp.hasPassthru())
+    return failure();
+
+  Block *block = maskOp.getMaskBlock();
+  auto terminator = cast<vector::YieldOp>(block->front());
+  if (terminator.getNumOperands() == 0) {
+    // `vector.mask` has no results, just remove the `vector.mask`.
+    return success();
+  }
+
+  // `vector.mask` has results, propagate the results.
+  llvm::append_range(results, terminator.getOperands());
+  return success();
+}
+
 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
-  MaskFormat maskFormat = getMaskFormat(getMask());
-  if (isEmpty())
-    return failure();
+  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
+    return success();
 
+  MaskFormat maskFormat = getMaskFormat(getMask());
   if (maskFormat != MaskFormat::AllTrue)
     return failure();
 
@@ -6650,37 +6679,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
   return success();
 }
 
-// Elides empty vector.mask operations with or without return values. Propagates
-// the yielded values by the vector.yield terminator, if any, or erases the op,
-// otherwise.
-class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(MaskOp maskOp,
-                                PatternRewriter &rewriter) const override {
-    auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
-    if (maskingOp.getMaskableOp())
-      return failure();
-
-    if (!maskOp.isEmpty())
-      return failure();
-
-    Block *block = maskOp.getMaskBlock();
-    auto terminator = cast<vector::YieldOp>(block->front());
-    if (terminator.getNumOperands() == 0)
-      rewriter.eraseOp(maskOp);
-    else
-      rewriter.replaceOp(maskOp, terminator.getOperands());
-
-    return success();
-  }
-};
-
-void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                         MLIRContext *context) {
-  results.add<ElideEmptyMaskOp>(context);
-}
-
 // MaskingOpInterface definitions.
 
 /// Returns the operation masked by this 'vector.mask'.
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 532a2383cea9e..b4e3da9d0dbfe 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,10 +1,12 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
 module {
+  // CHECK-LABEL: func @func
+  // CHECK-SAME: %[[IN:.*]]: vector<11xf32>
   func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
     %cst_41 = arith.constant dense<true> : vector<11xi1>
-    // CHECK: vector.mask
-    // CHECK-SAME: vector.yield %arg0
+    // CHECK-NOT: vector.mask
+    // CHECK: return %[[IN]] : vector<11xf32>
     %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
     return %127 : vector<11xf32>
   }

@llvmbot
Copy link
Member

llvmbot commented May 17, 2025

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This MR moves the canonicalization that elides empty vector.mask ops to folders.


Full diff: https://github.com/llvm/llvm-project/pull/140324.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+33-35)
  • (modified) mlir/test/Conversion/GPUCommon/lower-vector.mlir (+4-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3aefcea8de994..2e0c9a6de11ae 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2554,7 +2554,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
                                  Location loc);
   }];
 
-  let hasCanonicalizer = 1;
+  let hasCanonicalizer = 0;
   let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79bf87ccd34af..104459850d508 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6631,13 +6631,42 @@ LogicalResult MaskOp::verify() {
   return success();
 }
 
-/// Folds vector.mask ops with an all-true mask.
+/// Folds empty `vector.mask` with no passthru operand and with or without
+/// return values. For example:
+///
+///   %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
+///     vector<8xi1> -> vector<8xf32>
+///   %1 = user_op %0 : vector<8xf32>
+///
+/// becomes:
+///
+///   %0 = user_op %a : vector<8xf32>
+///
+/// `vector.mask` with a passthru is handled by the canonicalizer.
+///
+static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
+                                     SmallVectorImpl<OpFoldResult> &results) {
+  if (!maskOp.isEmpty() || maskOp.hasPassthru())
+    return failure();
+
+  Block *block = maskOp.getMaskBlock();
+  auto terminator = cast<vector::YieldOp>(block->front());
+  if (terminator.getNumOperands() == 0) {
+    // `vector.mask` has no results, just remove the `vector.mask`.
+    return success();
+  }
+
+  // `vector.mask` has results, propagate the results.
+  llvm::append_range(results, terminator.getOperands());
+  return success();
+}
+
 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
-  MaskFormat maskFormat = getMaskFormat(getMask());
-  if (isEmpty())
-    return failure();
+  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
+    return success();
 
+  MaskFormat maskFormat = getMaskFormat(getMask());
   if (maskFormat != MaskFormat::AllTrue)
     return failure();
 
@@ -6650,37 +6679,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
   return success();
 }
 
-// Elides empty vector.mask operations with or without return values. Propagates
-// the yielded values by the vector.yield terminator, if any, or erases the op,
-// otherwise.
-class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(MaskOp maskOp,
-                                PatternRewriter &rewriter) const override {
-    auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
-    if (maskingOp.getMaskableOp())
-      return failure();
-
-    if (!maskOp.isEmpty())
-      return failure();
-
-    Block *block = maskOp.getMaskBlock();
-    auto terminator = cast<vector::YieldOp>(block->front());
-    if (terminator.getNumOperands() == 0)
-      rewriter.eraseOp(maskOp);
-    else
-      rewriter.replaceOp(maskOp, terminator.getOperands());
-
-    return success();
-  }
-};
-
-void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                         MLIRContext *context) {
-  results.add<ElideEmptyMaskOp>(context);
-}
-
 // MaskingOpInterface definitions.
 
 /// Returns the operation masked by this 'vector.mask'.
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 532a2383cea9e..b4e3da9d0dbfe 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,10 +1,12 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
 module {
+  // CHECK-LABEL: func @func
+  // CHECK-SAME: %[[IN:.*]]: vector<11xf32>
   func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
     %cst_41 = arith.constant dense<true> : vector<11xi1>
-    // CHECK: vector.mask
-    // CHECK-SAME: vector.yield %arg0
+    // CHECK-NOT: vector.mask
+    // CHECK: return %[[IN]] : vector<11xf32>
     %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
     return %127 : vector<11xf32>
   }

// CHECK: vector.mask
// CHECK-SAME: vector.yield %arg0
// CHECK-NOT: vector.mask
// CHECK: return %[[IN]] : vector<11xf32>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what the intent of this test is but I'm updating it accordingly

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a regression test originally added in 6904998 . It’s a great example of why "regression" on its own isn’t a helpful test description 😅

That commit fixed a bug in ConvertVectorToLLVM.cpp - it’s unclear why the test landed under the "GPUCommon" subdirectory. Unfortunately, with no comments and uninformative function/variable names, it’s hard to tell what specific edge case this was meant to cover - which makes changes like this unnecessarily tricky to review.

From the context of the original commit, it seems the intent was to test the lowering of empty vector.mask. If so, I think this would be better placed in: mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

WDYT? Would you mind moving it there instead?

Also - if I’ve understood the original intent correctly, your update here makes sense 🙂

This MR moves the canonicalization that elides empty `vector.mask` ops
to folders.
@dcaballe dcaballe force-pushed the vector-mask-canon-to-fold branch from 1014ee5 to 3bc4083 Compare May 20, 2025 22:33
@dcaballe
Copy link
Contributor Author

ping :)

}];

let hasCanonicalizer = 1;
let hasCanonicalizer = 0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this line anymore? I thibk we can remove it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'm adding that back in my next PR but I removed it from now.

///
/// %0 = user_op %a : vector<8xf32>
///
/// `vector.mask` with a passthru is handled by the canonicalizer.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But vector.mask doesnt gave a canonicalizer anymore, right?

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the ping and sorry for the delay!

Makes sense, thanks for this simplification! I've left one suggestion re the test, otherwise LGTM (from what I can tell, all the prior comments have already been addressed).

// CHECK: vector.mask
// CHECK-SAME: vector.yield %arg0
// CHECK-NOT: vector.mask
// CHECK: return %[[IN]] : vector<11xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a regression test originally added in 6904998 . It’s a great example of why "regression" on its own isn’t a helpful test description 😅

That commit fixed a bug in ConvertVectorToLLVM.cpp - it’s unclear why the test landed under the "GPUCommon" subdirectory. Unfortunately, with no comments and uninformative function/variable names, it’s hard to tell what specific edge case this was meant to cover - which makes changes like this unnecessarily tricky to review.

From the context of the original commit, it seems the intent was to test the lowering of empty vector.mask. If so, I think this would be better placed in: mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

WDYT? Would you mind moving it there instead?

Also - if I’ve understood the original intent correctly, your update here makes sense 🙂

@dcaballe dcaballe merged commit d6f394e into llvm:main May 22, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants