Skip to content

Conversation

@newling
Copy link
Contributor

@newling newling commented Apr 1, 2025

Based on the ShapeCastConstantFolder, this pattern replaces

%0 = ub.poison : vector<2x3xf32>
%1 = vector.shape_cast %0 vector<2x3xf32> to vector<6xf32>

with

%1 = ub.poison : vector<6xf32>

@llvmbot
Copy link
Member

llvmbot commented Apr 1, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Based on the ShapeCastConstantFolder, this pattern replaces

%0 = ub.poison : vector<2x3xf32>
%1 = vector.shape_cast %0 vector<2x3xf32> to vector<6xf32>

with

%1 = ub.poison : vector<6xf32>


Full diff: https://github.com/llvm/llvm-project/pull/133988.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+21-2)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+14)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..ee7df8a943d24 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5646,6 +5646,23 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
   }
 };
 
+// Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp.
+class ShapeCastPoisonFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (!shapeCastOp.getSource().getDefiningOp<ub::PoisonOp>())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ub::PoisonOp>(shapeCastOp,
+                                              shapeCastOp.getType());
+    return success();
+  }
+};
+
 /// Helper function that computes a new vector type based on the input vector
 /// type by removing the trailing one dims:
 ///
@@ -5804,8 +5821,10 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
-              ShapeCastBroadcastFolder>(context);
+  results
+      .add<ShapeCastConstantFolder, ShapeCastPoisonFolder,
+           ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..72064fb42741a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
 
 // -----
 
+// CHECK-LABEL: shape_cast_poison
+//       CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
+//       CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
+//       CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
+func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+  %poison = ub.poison : vector<5x4x2xf32>
+  %poison_1 = ub.poison : vector<12x2xi32>
+  %0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32>
+  %1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32>
+  return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_strided_constant
 //       CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
 //       CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>

Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Just a request for change in implementation, the idea of the change by itself is good.

Comment on lines 5673 to 5664
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a folder, not a canonicalizer. You can just return UBPoisonAttr in the folder and the materialization will automatically create a ub::PoisonOp if needed

https://mlir.llvm.org/docs/Canonicalization/#when-to-use-the-fold-method-vs-rewriterpatterns-for-canonicalizations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I've moved the equivalent constant canonicalizer to a folder for consistency

@newling newling force-pushed the shape_cast_poison_canon branch from 287e581 to 356204d Compare April 7, 2025 16:00
Comment on lines 5646 to 5658
// Replace shape_cast(arith.constant) with arith.constant. Currently only
// handles splat constants.
if (auto constantOp = getSource().getDefiningOp<arith::ConstantOp>()) {
if (auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue())) {
return DenseElementsAttr::get(cast<VectorType>(getType()),
dense.getSplatValue<Attribute>());
}
}

// Replace shape_cast(poison) with poison.
if (getSource().getDefiningOp<ub::PoisonOp>()) {
return ub::PoisonAttr::get(getContext());
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to check getDefiningOp, the adaptor will automatically give you an attribute if it can. Instead, you can do:

if (auto dense = llvm::dyn_cast(adaptor.getSource()) {
...
}

Same for ub::PoisonOp, but use ub::PosionAttr

Copy link
Contributor Author

@newling newling Apr 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. Note I needed to '_if_present' it (adaptor might not have an attribute)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some NFC here (remove redundant casts)

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep the resultType variable since the type is used multiple times and makes the reading easier

Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@newling newling force-pushed the shape_cast_poison_canon branch from e774103 to d47886c Compare April 9, 2025 20:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants