@@ -5617,18 +5617,20 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56175617 if (getSource ().getType () == getType ())
56185618 return getSource ();
56195619
5620+ VectorType resultType = getType ();
5621+
56205622 // Canceling shape casts.
56215623 if (auto otherOp = getSource ().getDefiningOp <ShapeCastOp>()) {
56225624
56235625 // Only allows valid transitive folding (expand/collapse dimensions).
56245626 VectorType srcType = otherOp.getSource ().getType ();
5625- if (getType () == srcType)
5627+ if (resultType == srcType)
56265628 return otherOp.getSource ();
5627- if (srcType.getRank () < getType () .getRank ()) {
5628- if (!isValidShapeCast (srcType.getShape (), getType () .getShape ()))
5629+ if (srcType.getRank () < resultType .getRank ()) {
5630+ if (!isValidShapeCast (srcType.getShape (), resultType .getShape ()))
56295631 return {};
5630- } else if (srcType.getRank () > getType () .getRank ()) {
5631- if (!isValidShapeCast (getType () .getShape (), srcType.getShape ()))
5632+ } else if (srcType.getRank () > resultType .getRank ()) {
5633+ if (!isValidShapeCast (resultType .getShape (), srcType.getShape ()))
56325634 return {};
56335635 } else {
56345636 return {};
@@ -5639,14 +5641,14 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56395641
56405642 // Cancelling broadcast and shape cast ops.
56415643 if (auto bcastOp = getSource ().getDefiningOp <BroadcastOp>()) {
5642- if (bcastOp.getSourceType () == getType () )
5644+ if (bcastOp.getSourceType () == resultType )
56435645 return bcastOp.getSource ();
56445646 }
56455647
56465648 // shape_cast(constant) -> constant
56475649 if (auto splatAttr =
56485650 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource ())) {
5649- return DenseElementsAttr::get (getType () ,
5651+ return DenseElementsAttr::get (resultType ,
56505652 splatAttr.getSplatValue <Attribute>());
56515653 }
56525654
0 commit comments