-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir] canonicalizer: shape_cast(poison) -> poison #133988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: James Newling (newling) ChangesBased on the ShapeCastConstantFolder, this pattern replaces %0 = ub.poison : vector<2x3xf32> with %1 = ub.poison : vector<6xf32> Full diff: https://github.com/llvm/llvm-project/pull/133988.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..ee7df8a943d24 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5646,6 +5646,23 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
}
};
+// Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp.
+class ShapeCastPoisonFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+
+ if (!shapeCastOp.getSource().getDefiningOp<ub::PoisonOp>())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<ub::PoisonOp>(shapeCastOp,
+ shapeCastOp.getType());
+ return success();
+ }
+};
+
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
@@ -5804,8 +5821,10 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
- ShapeCastBroadcastFolder>(context);
+ results
+ .add<ShapeCastConstantFolder, ShapeCastPoisonFolder,
+ ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..72064fb42741a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
// -----
+// CHECK-LABEL: shape_cast_poison
+// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
+// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
+// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
+func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+ %poison = ub.poison : vector<5x4x2xf32>
+ %poison_1 = ub.poison : vector<12x2xi32>
+ %0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32>
+ %1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32>
+ return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
// CHECK-LABEL: extract_strided_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>
|
Groverkss
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Just a request for change in implementation, the idea of the change by itself is good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be a folder, not a canonicalizer. You can just return UBPoisonAttr in the folder and the materialization will automatically create a ub::PoisonOp if needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense. I've moved the equivalent constant canonicalizer to a folder for consistency
287e581 to
356204d
Compare
| // Replace shape_cast(arith.constant) with arith.constant. Currently only | ||
| // handles splat constants. | ||
| if (auto constantOp = getSource().getDefiningOp<arith::ConstantOp>()) { | ||
| if (auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue())) { | ||
| return DenseElementsAttr::get(cast<VectorType>(getType()), | ||
| dense.getSplatValue<Attribute>()); | ||
| } | ||
| } | ||
|
|
||
| // Replace shape_cast(poison) with poison. | ||
| if (getSource().getDefiningOp<ub::PoisonOp>()) { | ||
| return ub::PoisonAttr::get(getContext()); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need to check getDefiningOp, the adaptor will automatically give you an attribute if it can. Instead, you can do:
if (auto dense = llvm::dyn_cast(adaptor.getSource()) {
...
}
Same for ub::PoisonOp, but use ub::PosionAttr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice. Note I needed to '_if_present' it (adaptor might not have an attribute)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some NFC here (remove redundant casts)
dcaballe
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would keep the resultType variable since the type is used multiple times and makes the reading easier
Groverkss
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
e774103 to
d47886c
Compare
Signed-off-by: James Newling <[email protected]>
Signed-off-by: James Newling <[email protected]>
Signed-off-by: James Newling <[email protected]>
Following on from #133988 --------- Signed-off-by: James Newling <[email protected]>
Following on from llvm/llvm-project#133988 --------- Signed-off-by: James Newling <[email protected]>
Based on the ShapeCastConstantFolder, this pattern replaces
%0 = ub.poison : vector<2x3xf32>
%1 = vector.shape_cast %0 vector<2x3xf32> to vector<6xf32>
with
%1 = ub.poison : vector<6xf32>