
@joker-eph
Collaborator

Modifying the op in place is more efficient: it avoids cloning the op and then immediately erasing the original.
It also inserts a cast on the result only when the destination type has changed.
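The effect of the change can be sketched with a toy model that only counts IR mutations. `ToyPackOp`, `RewriteStats`, `cloneThenReplace`, and `modifyInPlace` are hypothetical names for illustration. They loosely mirror the old and new code paths in the diff but are not MLIR APIs, and types are plain strings here. Under these assumptions, the old path always creates two ops (the clone and a cast) and erases one. The in-place path creates nothing when only the source is refined, and creates one cast when the destination type also changes.

```cpp
#include <cassert>
#include <string>

// Hypothetical toy model, not MLIR's API: a pack op whose result type
// always mirrors the type of its destination operand.
struct ToyPackOp {
  std::string sourceType;
  std::string destType;
  std::string resultType;
};

// Tallies the IR mutations a rewrite performs.
struct RewriteStats {
  int opsCreated = 0;
  int opsErased = 0;
  bool castInserted = false;
};

// Old path: clone the op, retarget the clone, then replace the original
// with a cast of the clone. The cast is created even when the result
// type did not change.
RewriteStats cloneThenReplace(const ToyPackOp &op, const std::string &newSrc,
                              const std::string &newDest) {
  RewriteStats stats;
  ToyPackOp clone = op;  // the clone of the original op
  stats.opsCreated++;
  clone.sourceType = newSrc;
  clone.destType = newDest;
  clone.resultType = newDest;
  stats.opsCreated++;  // cast from the clone's result back to op.resultType
  stats.castInserted = true;
  stats.opsErased++;  // the original op is erased when it is replaced
  return stats;
}

// New path: mutate the op in place and add a cast back to the original
// result type only if that type actually changed.
RewriteStats modifyInPlace(ToyPackOp &op, const std::string &newSrc,
                           const std::string &newDest) {
  RewriteStats stats;
  const std::string originalResultType = op.resultType;
  op.sourceType = newSrc;
  op.destType = newDest;
  op.resultType = newDest;
  if (originalResultType != op.resultType) {
    stats.opsCreated++;  // cast so existing users still see the old type
    stats.castInserted = true;
  }
  return stats;
}
```

For example, refining only the source from `tensor<?x?xf32>` to `tensor<16x16xf32>` makes `modifyInPlace` report zero created and zero erased ops, while `cloneThenReplace` reports two created and one erased for the same input.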

@llvmbot
Member

llvmbot commented Oct 8, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/111593.diff

1 File Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+13-10)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 659eabd2e93880..0ac0899def21b5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,21 +4332,24 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
+    Type originalResultType = dest.getType();
     if (destShape != packOp.getDestType().getShape()) {
       auto newDestType = packOp.getDestType().clone(destShape);
       dest =
           rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
     }
-    auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
-    Value res = clonedPackOp.getResult();
-    rewriter.startOpModification(clonedPackOp);
-    clonedPackOp.getSourceMutable().assign(source);
-    clonedPackOp.getDestMutable().assign(dest);
-    res.setType(dest.getType());
-    rewriter.finalizeOpModification(clonedPackOp);
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        packOp, packOp.getResult().getType(), clonedPackOp);
+    rewriter.modifyOpInPlace(packOp, [&] {
+      packOp.getSourceMutable().assign(source);
+      packOp.getDestMutable().assign(dest);
+      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+    });
+    // Insert a cast if needed
+    if (originalResultType != dest.getType()) {
+      rewriter.setInsertionPointAfter(packOp);
+      auto castOp =
+          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+    }
     return success();
   }
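When a cast is inserted, every use of the pack result must be redirected to the cast except the cast's own operand; otherwise the cast would consume its own result. The toy use-def model below is a hypothetical sketch of what the excluded-user argument in `rewriter.replaceAllUsesExcept(packOp, castOp, castOp)` protects against. `ToyOp`, `toyReplaceAllUsesExcept`, `hasSelfUse`, and `buildExample` are illustrative names, not MLIR APIs, and values are modeled as plain integer ids.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy use-def model (not MLIR's API): each op has a name,
// the ids of the values it consumes, and the id of the value it defines.
struct ToyOp {
  std::string name;
  std::vector<int> operands;
  int result;
};

// Redirects every use of value `from` to value `to`, skipping the
// operands of the op named `exceptedUser`.
void toyReplaceAllUsesExcept(std::vector<ToyOp> &ops, int from, int to,
                             const std::string &exceptedUser) {
  for (ToyOp &op : ops) {
    if (op.name == exceptedUser)
      continue;
    for (int &operand : op.operands)
      if (operand == from)
        operand = to;
  }
}

// Returns true if some op consumes its own result (an invalid cycle).
bool hasSelfUse(const std::vector<ToyOp> &ops) {
  for (const ToyOp &op : ops)
    for (int operand : op.operands)
      if (operand == op.result)
        return true;
  return false;
}

// Builds: %0 = pack(...), %1 = cast(%0), %2 = consumer(%0)
std::vector<ToyOp> buildExample() {
  return {
      {"pack", {}, 0},
      {"cast", {0}, 1},
      {"consumer", {0}, 2},
  };
}
```

Excluding the cast keeps it reading the pack result (`%0`) while the consumer switches to the cast (`%1`). Passing an empty excepted-user name instead rewrites the cast's operand too, which leaves the cast consuming its own result.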
 

Contributor

@hanhanW hanhanW left a comment


Both versions look good to me, thanks!

@joker-eph joker-eph merged commit 275a2b0 into llvm:main Oct 9, 2024
6 of 9 checks passed
@joker-eph joker-eph deleted the in-place branch October 9, 2024 07:44
@pashu123
Member

pashu123 commented Oct 9, 2024

Thanks, @joker-eph, for the cleanup.


4 participants