@@ -1800,16 +1800,16 @@ static LogicalResult checkAssumptionForLoop(RewriterBase &rewriter,
18001800 };
18011801 getBackwardSlice (operand, &slice, options);
18021802 if (!slice.empty ()) {
1803+ // If consumerOp has one producer, which is also the user of loopOp.
1804+ // E.g.
1805+ // ```
1806+ // %0 = %loopOp
1807+ // %1 = consumerOp1 ins(%0)
1808+ // %2 = consumerOp2 ins(%0, %1)
1809+ // ```
1810+ // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1811+ // consumerOp1 has already been fused into loopOp before.
18031812 if (includeLoopOp) {
1804- // If consumerOp has one producer, which is also the user of loopOp.
1805- // E.g.
1806- // ```
1807- // %0 = %loopOp
1808- // %1 = consumerOp1 ins(%0)
1809- // %2 = consumerOp2 ins(%0, %1)
1810- // ```
1811- // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1812- // consumerOp1 has already been fused into loopOp before.
18131813 return rewriter.notifyMatchFailure (
18141814 consumerOp, " could not fuse consumer due to inevitable use-def "
18151815 " chain violation" );
@@ -1843,6 +1843,24 @@ static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
18431843 }
18441844}
18451845
1846+ // / A utility to move the given operand to the end of use list.
1847+ static void moveOperandToEndOfUseList (OpOperand *operand) {
1848+ Value::use_range uses = operand->get ().getUses ();
1849+ size_t numberUses = std::distance (uses.begin (), uses.end ());
1850+ if (numberUses == 1 )
1851+ return ;
1852+ auto iter = llvm::find (uses, *operand);
1853+ if (iter == uses.end ())
1854+ return ;
1855+ unsigned index = std::distance (uses.begin (), iter);
1856+ SmallVector<unsigned > indices =
1857+ llvm::to_vector (llvm::seq<unsigned >(numberUses));
1858+ indices.push_back (indices[index]);
1859+ indices.erase (indices.begin () + index);
1860+ operand->get ().shuffleUseList (indices);
1861+ return ;
1862+ }
1863+
18461864// / Implementation of fusing consumer of a single slice by computing the
18471865// / slice of the consumer in-place for scf loop.
18481866FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1897,6 +1915,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
18971915 Operation *firstUserOfLoop = nullptr ;
18981916 if (failed (checkAssumptionForLoop (rewriter, outerMostLoop, consumerOp,
18991917 &firstUserOfLoop))) {
1918+ // Prepare for next consumer.
1919+ moveOperandToEndOfUseList (consumerOpOperand);
19001920 return rewriter.notifyMatchFailure (
19011921 outerMostLoop,
19021922 " containing loop op should either yield just one value or "
0 commit comments