@@ -818,27 +818,23 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
818818 // Search the producer slices accessed within the containing operation.
819819 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
820820 // evolve into an interface.
821+ if (bbArg.getUsers ().empty ()) {
822+ diag.attachNote (containingOp->getLoc ())
823+ << " could not find fusion opportunity for bbArg: " << bbArg;
824+ return {};
825+ }
821826 auto itBBArgUsers = llvm::find_if (bbArg.getUsers (), [&](Operation *user) {
822827 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
823828 return sliceOp && containingOp->isProperAncestor (sliceOp);
824829 });
825-
826- // Find a fusion opportunity.
830+ OpBuilder::InsertionGuard guard (rewriter);
831+ tensor::ExtractSliceOp sliceOpToTile;
827832 if (itBBArgUsers == bbArg.getUsers ().end ()) {
828- diag.attachNote (containingOp->getLoc ())
829- << " could not find fusion opportunity for bbArg: " << bbArg;
830- return {};
833+ rewriter.setInsertionPoint (&bbArg.getOwner ()->front ());
834+ } else {
835+ sliceOpToTile = llvm::cast<tensor::ExtractSliceOp>(*itBBArgUsers);
836+ rewriter.setInsertionPoint (sliceOpToTile);
831837 }
832- auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
833-
834- // Try to fuse the producer in-place.
835- OpBuilder::InsertionGuard guard (rewriter);
836- rewriter.setInsertionPoint (sliceOpToTile);
837-
838- // Replace the use in the tileableProducer before tiling: clone, replace and
839- // then tile.
840- int64_t resultNumber = cast<OpResult>(pUse->get ()).getResultNumber ();
841- LLVM_DEBUG (DBGS () << " resultNumber: " << resultNumber << " \n " );
842838
843839 // Gather destination tensors.
844840 SmallVector<Value> destinationTensors;
@@ -850,14 +846,38 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
850846 return {};
851847 }
852848
849+ // Replace the use in the tileableProducer before tiling: clone, replace and
850+ // then tile.
851+ SmallVector<Operation *> oldBbArgUsers (bbArg.getUsers ());
852+ int64_t resultNumber = cast<OpResult>(pUse->get ()).getResultNumber ();
853+ LLVM_DEBUG (DBGS () << " resultNumber: " << resultNumber << " \n " );
853854 IRMapping bvm;
854855 bvm.map (destinationTensors[resultNumber], bbArg);
855856 auto tileableProducerClone =
856857 cast<TilingInterface>(rewriter.clone (*tileableProducer, bvm));
857- auto scopeGuard =
858- llvm::make_scope_exit ([&]() { rewriter.eraseOp (tileableProducerClone); });
858+
859+ // If there was no extract_slice user, then no need to tile.
860+ if (!sliceOpToTile) {
861+ LLVM_DEBUG (DBGS () << " No extract_slice user. No need to tile cloned op.\n " );
862+ // Replace the old uses of bbArg with the cloned op, except for any parallel
863+ // insert ops.
864+ rewriter.replaceUsesWithIf (
865+ bbArg, tileableProducerClone->getResult (resultNumber),
866+ [&](OpOperand &operand) {
867+ return !isa<tensor::ParallelInsertSliceOp>(operand.getOwner ()) &&
868+ operand.getOwner () != tileableProducerClone.getOperation ();
869+ });
870+ // Replace the use in containingOp.
871+ rewriter.modifyOpInPlace (containingOp, [&]() {
872+ containingOp->setOperand (pUse->getOperandNumber (),
873+ destinationTensors.front ());
874+ });
875+ return {tileableProducerClone};
876+ }
859877
860878 // Tile the producer.
879+ auto scopeGuard =
880+ llvm::make_scope_exit ([&]() { rewriter.eraseOp (tileableProducerClone); });
861881 FailureOr<TilingResult> tileAndFuseResult =
862882 tileableProducerClone.generateResultTileValue (
863883 rewriter, resultNumber, sliceOpToTile.getMixedOffsets (),
0 commit comments