Skip to content

Commit d2db005

Browse files
author
Alaa Ali
committed
fix code formatting issue
1 parent 7ff9429 commit d2db005

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,20 +1735,21 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
17351735
// 'shape.shape_of'. While its type is guaranteed to be compatible in well-
17361736
// formed IR, it may not be identical (dynamically vs statically shaped),
17371737
// in which case it needs to be cast first using 'tensor.cast'.
1738-
// Additionally, it may not have identical element type (i32 vs index)
1739-
// while it has identical shaped type (dynamic vs static), in which case it needs
1740-
// to be cast first using 'arith.index_cast'.
1741-
// Note: 'shape.shape_of' op result must be shape or extent tensor.
1738+
// Additionally, it may not have identical element type (i32 vs index)
1739+
// while it has identical shaped type (dynamic vs static), in which case it
1740+
// needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
1741+
// op result must be shape or extent tensor.
17421742
Value shape = tensorReshapeOp.getShape();
17431743

17441744
auto opTensorTy = cast<RankedTensorType>(op.getType());
17451745
auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
17461746

17471747
if (opTensorTy != shapeTensorTy) {
1748-
if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1749-
shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
1750-
else if (!isExtentTensorType(shapeTensorTy))
1751-
shape = rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
1748+
if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1749+
shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
1750+
else if (!isExtentTensorType(shapeTensorTy))
1751+
shape =
1752+
rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
17521753
}
17531754

17541755
rewriter.replaceOp(op, shape);

0 commit comments

Comments
 (0)