4242#include " llvm/ADT/SmallVector.h"
4343#include " llvm/ADT/StringSet.h"
4444#include " llvm/ADT/TypeSwitch.h"
45+ #include " llvm/Support/Casting.h"
4546
4647#include < cassert>
4748#include < cstdint>
@@ -5611,28 +5612,27 @@ LogicalResult ShapeCastOp::verify() {
56115612}
56125613
56135614OpFoldResult ShapeCastOp::fold (FoldAdaptor adaptor) {
5615+
56145616 // No-op shape cast.
5615- if (getSource ().getType () == getResult (). getType ())
5617+ if (getSource ().getType () == getType ())
56165618 return getSource ();
56175619
56185620 // Canceling shape casts.
56195621 if (auto otherOp = getSource ().getDefiningOp <ShapeCastOp>()) {
5620- if (getResult ().getType () == otherOp.getSource ().getType ())
5621- return otherOp.getSource ();
56225622
5623- // Only allows valid transitive folding.
5624- VectorType srcType = llvm::cast<VectorType>(otherOp.getSource ().getType ());
5625- VectorType resultType = llvm::cast<VectorType>(getResult ().getType ());
5626- if (srcType.getRank () < resultType.getRank ()) {
5627- if (!isValidShapeCast (srcType.getShape (), resultType.getShape ()))
5623+ // Only allows valid transitive folding (expand/collapse dimensions).
5624+ VectorType srcType = otherOp.getSource ().getType ();
5625+ if (getType () == srcType)
5626+ return otherOp.getSource ();
5627+ if (srcType.getRank () < getType ().getRank ()) {
5628+ if (!isValidShapeCast (srcType.getShape (), getType ().getShape ()))
56285629 return {};
5629- } else if (srcType.getRank () > resultType .getRank ()) {
5630- if (!isValidShapeCast (resultType .getShape (), srcType.getShape ()))
5630+ } else if (srcType.getRank () > getType () .getRank ()) {
5631+ if (!isValidShapeCast (getType () .getShape (), srcType.getShape ()))
56315632 return {};
56325633 } else {
56335634 return {};
56345635 }
5635-
56365636 setOperand (otherOp.getSource ());
56375637 return getResult ();
56385638 }
@@ -5643,17 +5643,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56435643 return bcastOp.getSource ();
56445644 }
56455645
5646- // Replace shape_cast(arith.constant) with arith.constant. Currently only
5647- // handles splat constants.
5648- if (auto constantOp = getSource ().getDefiningOp <arith::ConstantOp>()) {
5649- if (auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue ())) {
5650- return DenseElementsAttr::get (cast<VectorType>(getType ()),
5651- dense.getSplatValue <Attribute>());
5652- }
5646+ // shape_cast(constant) -> constant
5647+ if (auto splatAttr =
5648+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ())) {
5649+ return DenseElementsAttr::get (getType (),
5650+ splatAttr.getSplatValue <Attribute>());
56535651 }
56545652
5655- // Replace shape_cast(poison) with poison.
5656- if (getSource (). getDefiningOp <ub::PoisonOp>( )) {
5653+ // shape_cast(poison) -> poison
5654+ if (llvm::dyn_cast_if_present <ub::PoisonAttr>(adaptor. getSource () )) {
56575655 return ub::PoisonAttr::get (getContext ());
56585656 }
56595657
0 commit comments