@@ -5643,49 +5643,24 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56435643 return bcastOp.getSource ();
56445644 }
56455645
5646- return {};
5647- }
5648-
5649- namespace {
5650- // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5651- class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5652- public:
5653- using OpRewritePattern::OpRewritePattern;
5654-
5655- LogicalResult matchAndRewrite (ShapeCastOp shapeCastOp,
5656- PatternRewriter &rewriter) const override {
5657- auto constantOp =
5658- shapeCastOp.getSource ().getDefiningOp <arith::ConstantOp>();
5659- if (!constantOp)
5660- return failure ();
5661- // Only handle splat for now.
5662- auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue ());
5663- if (!dense)
5664- return failure ();
5665- auto newAttr =
5666- DenseElementsAttr::get (llvm::cast<VectorType>(shapeCastOp.getType ()),
5667- dense.getSplatValue <Attribute>());
5668- rewriter.replaceOpWithNewOp <arith::ConstantOp>(shapeCastOp, newAttr);
5669- return success ();
5646+ // Replace shape_cast(arith.constant) with arith.constant. Currently only
5647+ // handles splat constants.
5648+ if (auto constantOp = getSource ().getDefiningOp <arith::ConstantOp>()) {
5649+ if (auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue ())) {
5650+ return DenseElementsAttr::get (cast<VectorType>(getType ()),
5651+ dense.getSplatValue <Attribute>());
5652+ }
56705653 }
5671- };
56725654
5673- // Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp.
5674- class ShapeCastPoisonFolder final : public OpRewritePattern<ShapeCastOp> {
5675- public:
5676- using OpRewritePattern::OpRewritePattern;
5677-
5678- LogicalResult matchAndRewrite (ShapeCastOp shapeCastOp,
5679- PatternRewriter &rewriter) const override {
5655+ // Replace shape_cast(poison) with poison.
5656+ if (getSource ().getDefiningOp <ub::PoisonOp>()) {
5657+ return ub::PoisonAttr::get (getContext ());
5658+ }
56805659
5681- if (!shapeCastOp. getSource (). getDefiningOp <ub::PoisonOp>())
5682- return failure ();
5660+ return {};
5661+ }
56835662
5684- rewriter.replaceOpWithNewOp <ub::PoisonOp>(shapeCastOp,
5685- shapeCastOp.getType ());
5686- return success ();
5687- }
5688- };
5663+ namespace {
56895664
56905665// / Helper function that computes a new vector type based on the input vector
56915666// / type by removing the trailing one dims:
@@ -5846,8 +5821,7 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
58465821void ShapeCastOp::getCanonicalizationPatterns (RewritePatternSet &results,
58475822 MLIRContext *context) {
58485823 results
5849- .add <ShapeCastConstantFolder, ShapeCastPoisonFolder,
5850- ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
5824+ .add <ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
58515825 context);
58525826}
58535827
0 commit comments