@@ -1735,20 +1735,21 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
17351735 // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
17361736 // formed IR, it may not be identical (dynamically vs statically shaped),
17371737 // in which case it needs to be cast first using 'tensor.cast'.
1738- // Additionally, it may not have identical element type (i32 vs index)
1739- // while it has identical shaped type (dynamic vs static), in which case it needs
1740- // to be cast first using 'arith.index_cast'.
1741- // Note: 'shape.shape_of' op result must be shape or extent tensor.
1738+ // Additionally, it may not have identical element type (i32 vs index)
1739+ // while it has identical shaped type (dynamic vs static), in which case it
1740+ // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
1741+ // op result must be shape or extent tensor.
17421742 Value shape = tensorReshapeOp.getShape ();
17431743
17441744 auto opTensorTy = cast<RankedTensorType>(op.getType ());
17451745 auto shapeTensorTy = cast<RankedTensorType>(shape.getType ());
17461746
17471747 if (opTensorTy != shapeTensorTy) {
1748- if (opTensorTy.getElementType () == shapeTensorTy.getElementType ())
1749- shape = rewriter.create <tensor::CastOp>(op.getLoc (), opTensorTy, shape);
1750- else if (!isExtentTensorType (shapeTensorTy))
1751- shape = rewriter.create <arith::IndexCastOp>(op.getLoc (), opTensorTy, shape);
1748+ if (opTensorTy.getElementType () == shapeTensorTy.getElementType ())
1749+ shape = rewriter.create <tensor::CastOp>(op.getLoc (), opTensorTy, shape);
1750+ else if (!isExtentTensorType (shapeTensorTy))
1751+ shape =
1752+ rewriter.create <arith::IndexCastOp>(op.getLoc (), opTensorTy, shape);
17521753 }
17531754
17541755 rewriter.replaceOp (op, shape);
0 commit comments