Skip to content

Commit 99ed635

Browse files
committed
use folders where possible (replace 2 canons)
Signed-off-by: James Newling <[email protected]>
1 parent 2ddaa8e commit 99ed635

File tree

1 file changed

+15
-41
lines changed

1 file changed

+15
-41
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5643,49 +5643,24 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56435643
return bcastOp.getSource();
56445644
}
56455645

5646-
return {};
5647-
}
5648-
5649-
namespace {
5650-
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5651-
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5652-
public:
5653-
using OpRewritePattern::OpRewritePattern;
5654-
5655-
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5656-
PatternRewriter &rewriter) const override {
5657-
auto constantOp =
5658-
shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
5659-
if (!constantOp)
5660-
return failure();
5661-
// Only handle splat for now.
5662-
auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
5663-
if (!dense)
5664-
return failure();
5665-
auto newAttr =
5666-
DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
5667-
dense.getSplatValue<Attribute>());
5668-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
5669-
return success();
5646+
// Replace shape_cast(arith.constant) with arith.constant. Currently only
5647+
// handles splat constants.
5648+
if (auto constantOp = getSource().getDefiningOp<arith::ConstantOp>()) {
5649+
if (auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue())) {
5650+
return DenseElementsAttr::get(cast<VectorType>(getType()),
5651+
dense.getSplatValue<Attribute>());
5652+
}
56705653
}
5671-
};
56725654

5673-
// Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp.
5674-
class ShapeCastPoisonFolder final : public OpRewritePattern<ShapeCastOp> {
5675-
public:
5676-
using OpRewritePattern::OpRewritePattern;
5677-
5678-
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
5679-
PatternRewriter &rewriter) const override {
5655+
// Replace shape_cast(poison) with poison.
5656+
if (getSource().getDefiningOp<ub::PoisonOp>()) {
5657+
return ub::PoisonAttr::get(getContext());
5658+
}
56805659

5681-
if (!shapeCastOp.getSource().getDefiningOp<ub::PoisonOp>())
5682-
return failure();
5660+
return {};
5661+
}
56835662

5684-
rewriter.replaceOpWithNewOp<ub::PoisonOp>(shapeCastOp,
5685-
shapeCastOp.getType());
5686-
return success();
5687-
}
5688-
};
5663+
namespace {
56895664

56905665
/// Helper function that computes a new vector type based on the input vector
56915666
/// type by removing the trailing one dims:
@@ -5846,8 +5821,7 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
58465821
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
58475822
MLIRContext *context) {
58485823
results
5849-
.add<ShapeCastConstantFolder, ShapeCastPoisonFolder,
5850-
ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
5824+
.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
58515825
context);
58525826
}
58535827

0 commit comments

Comments
 (0)