Skip to content

Commit 2ddaa8e

Browse files
committed
add canonicalizer
1 parent fe2eefc commit 2ddaa8e

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5670,6 +5670,23 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
56705670
}
56715671
};
56725672

5673+
// Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp.
5674+
class ShapeCastPoisonFolder final : public OpRewritePattern<ShapeCastOp> {
5675+
public:
5676+
using OpRewritePattern::OpRewritePattern;
5677+
5678+
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5679+
PatternRewriter &rewriter) const override {
5680+
5681+
if (!shapeCastOp.getSource().getDefiningOp<ub::PoisonOp>())
5682+
return failure();
5683+
5684+
rewriter.replaceOpWithNewOp<ub::PoisonOp>(shapeCastOp,
5685+
shapeCastOp.getType());
5686+
return success();
5687+
}
5688+
};
5689+
56735690
/// Helper function that computes a new vector type based on the input vector
56745691
/// type by removing the trailing one dims:
56755692
///
@@ -5828,8 +5845,10 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
58285845

58295846
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
58305847
MLIRContext *context) {
5831-
results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5832-
ShapeCastBroadcastFolder>(context);
5848+
results
5849+
.add<ShapeCastConstantFolder, ShapeCastPoisonFolder,
5850+
ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
5851+
context);
58335852
}
58345853

58355854
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
11671167

11681168
// -----
11691169

1170+
// CHECK-LABEL: shape_cast_poison
1171+
// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
1172+
// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
1173+
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
1174+
func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
1175+
%poison = ub.poison : vector<5x4x2xf32>
1176+
%poison_1 = ub.poison : vector<12x2xi32>
1177+
%0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32>
1178+
%1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32>
1179+
return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
1180+
}
1181+
1182+
// -----
1183+
11701184
// CHECK-LABEL: extract_strided_constant
11711185
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
11721186
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>

0 commit comments

Comments
 (0)