4242#include " llvm/ADT/SmallVector.h"
4343#include " llvm/ADT/StringSet.h"
4444#include " llvm/ADT/TypeSwitch.h"
45+ #include " llvm/Support/Casting.h"
4546
4647#include < cassert>
4748#include < cstdint>
@@ -5611,18 +5612,20 @@ LogicalResult ShapeCastOp::verify() {
56115612}
56125613
56135614OpFoldResult ShapeCastOp::fold (FoldAdaptor adaptor) {
5615+
56145616 // No-op shape cast.
5615- if (getSource ().getType () == getResult (). getType ())
5617+ if (getSource ().getType () == getType ())
56165618 return getSource ();
56175619
5620+ VectorType resultType = getType ();
5621+
56185622 // Canceling shape casts.
56195623 if (auto otherOp = getSource ().getDefiningOp <ShapeCastOp>()) {
5620- if (getResult ().getType () == otherOp.getSource ().getType ())
5621- return otherOp.getSource ();
56225624
5623- // Only allows valid transitive folding.
5624- VectorType srcType = llvm::cast<VectorType>(otherOp.getSource ().getType ());
5625- VectorType resultType = llvm::cast<VectorType>(getResult ().getType ());
5625+ // Only allows valid transitive folding (expand/collapse dimensions).
5626+ VectorType srcType = otherOp.getSource ().getType ();
5627+ if (resultType == srcType)
5628+ return otherOp.getSource ();
56265629 if (srcType.getRank () < resultType.getRank ()) {
56275630 if (!isValidShapeCast (srcType.getShape (), resultType.getShape ()))
56285631 return {};
@@ -5632,43 +5635,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56325635 } else {
56335636 return {};
56345637 }
5635-
56365638 setOperand (otherOp.getSource ());
56375639 return getResult ();
56385640 }
56395641
56405642 // Cancelling broadcast and shape cast ops.
56415643 if (auto bcastOp = getSource ().getDefiningOp <BroadcastOp>()) {
5642- if (bcastOp.getSourceType () == getType () )
5644+ if (bcastOp.getSourceType () == resultType )
56435645 return bcastOp.getSource ();
56445646 }
56455647
5648+ // shape_cast(constant) -> constant
5649+ if (auto splatAttr =
5650+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ())) {
5651+ return DenseElementsAttr::get (resultType,
5652+ splatAttr.getSplatValue <Attribute>());
5653+ }
5654+
5655+ // shape_cast(poison) -> poison
5656+ if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource ())) {
5657+ return ub::PoisonAttr::get (getContext ());
5658+ }
5659+
56465660 return {};
56475661}
56485662
56495663namespace {
5650- // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
5651- class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
5652- public:
5653- using OpRewritePattern::OpRewritePattern;
5654-
5655- LogicalResult matchAndRewrite (ShapeCastOp shapeCastOp,
5656- PatternRewriter &rewriter) const override {
5657- auto constantOp =
5658- shapeCastOp.getSource ().getDefiningOp <arith::ConstantOp>();
5659- if (!constantOp)
5660- return failure ();
5661- // Only handle splat for now.
5662- auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue ());
5663- if (!dense)
5664- return failure ();
5665- auto newAttr =
5666- DenseElementsAttr::get (llvm::cast<VectorType>(shapeCastOp.getType ()),
5667- dense.getSplatValue <Attribute>());
5668- rewriter.replaceOpWithNewOp <arith::ConstantOp>(shapeCastOp, newAttr);
5669- return success ();
5670- }
5671- };
56725664
56735665// / Helper function that computes a new vector type based on the input vector
56745666// / type by removing the trailing one dims:
@@ -5828,8 +5820,9 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
58285820
58295821void ShapeCastOp::getCanonicalizationPatterns (RewritePatternSet &results,
58305822 MLIRContext *context) {
5831- results.add <ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
5832- ShapeCastBroadcastFolder>(context);
5823+ results
5824+ .add <ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
5825+ context);
58335826}
58345827
58355828// ===----------------------------------------------------------------------===//
0 commit comments