diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 659eabd2e9388..4d6c5965c4fcc 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -4332,21 +4332,25 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { rewriter.create(loc, newSrcType, packOp.getSource()); } Value dest = packOp.getDest(); - if (destShape != packOp.getDestType().getShape()) { + RankedTensorType originalResultType = packOp.getDestType(); + bool needUpdateDestType = (destShape != originalResultType.getShape()); + if (needUpdateDestType) { auto newDestType = packOp.getDestType().clone(destShape); dest = rewriter.create(loc, newDestType, packOp.getDest()); } - auto clonedPackOp = cast(rewriter.clone(*packOp)); - Value res = clonedPackOp.getResult(); - rewriter.startOpModification(clonedPackOp); - clonedPackOp.getSourceMutable().assign(source); - clonedPackOp.getDestMutable().assign(dest); - res.setType(dest.getType()); - rewriter.finalizeOpModification(clonedPackOp); - - rewriter.replaceOpWithNewOp( - packOp, packOp.getResult().getType(), clonedPackOp); + rewriter.modifyOpInPlace(packOp, [&] { + packOp.getSourceMutable().assign(source); + packOp.getDestMutable().assign(dest); + packOp.getResult().setType(cast(dest.getType())); + }); + // Insert a cast if needed + if (needUpdateDestType) { + rewriter.setInsertionPointAfter(packOp); + auto castOp = + rewriter.create(loc, originalResultType, packOp); + rewriter.replaceAllUsesExcept(packOp, castOp, castOp); + } return success(); }