1212
1313#include " mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
1414
15+ #include " mlir/Analysis/SliceAnalysis.h"
16+ #include " mlir/Analysis/TopologicalSortUtils.h"
1517#include " mlir/Dialect/Affine/IR/AffineOps.h"
1618#include " mlir/Dialect/Arith/IR/Arith.h"
1719#include " mlir/Dialect/Arith/Utils/Utils.h"
@@ -1702,7 +1704,7 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
17021704
17031705// / This utility currently checks whether the first userOp of loop is NOT before
17041706// / the last defineOp of consumer. Currently we need to move the loop op right
1705- // / before a certain op in order to maintain a valid def- use chain. This utility
1707+ // / before a certain op in order to maintain a valid use-def chain. This utility
17061708// / thus helps ensure that no invalid IR is formed. E.g.
17071709// /
17081710// / ```
@@ -1718,20 +1720,22 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
17181720// / ```
17191721// /
17201722// / If the `firstUserOfLoop` is before `lastDefOfConsumer`, then it would be
1721- // / invalid to move the loop op right before the `firstUserOfLoop`:
1723+ // / invalid to move the loop op right before the `firstUserOfLoop`, i.e. a
1724+ // / use-def chain violation:
17221725// /
17231726// / ```
17241727// / %0:2 = scf.for() {
1728+ // / // use before define error
17251729// / %3 = tiledConsumerOp(%2)
17261730// / }
17271731// / %1 = firstUserOfLoop(%0)
17281732// / ...
17291733// / %2 = lastDefOfConsumer
17301734// / ```
17311735// /
1732- // / To address this issue, this utility would double-check there is no user of
1733- // / `firstUserOfLoop` before `lastDefOfConsumer`. If so, move `firstUserOfLoop`
1734- // / after `lastDefOfConsumer`. Then, it turns out valid as follow:
1736+ // / To address this issue, this utility tries to move `lastDefOfConsumer`
1737+ // / before `firstUserOfLoop` when intrusive mode is enabled. The IR then
1738+ // / becomes valid, as follows:
17351739// /
17361740// / ```
17371741// / %2 = lastDefOfConsumer
@@ -1741,81 +1745,87 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
17411745// / %1 = firstUserOfLoop(%0)
17421746// / ```
17431747// /
1744- // / Besides, `consumerOp` should not be the user of `firstUserOfLoop`.
1745- // /
17461748// / @param loopOp: loop operation
17471749// / @param consumerOp: consumer operation
1748- // / @param toMoveLoopOpBefore: the operation we move the looOp right before
1749- static LogicalResult checkAssumptionForLoop (Operation *loopOp,
1750+ // / @param firstUserOfLoop: output; set to the first user of loopOp, which is
1751+ // / the op that the loopOp is moved right before
1752+ // / @param intrusive: if true, allows moving the backward slice of the ops
1753+ // / defining consumerOp's operands before firstUserOfLoop. Defaults to true.
1754+ // / If an explicit memory barrier is required, please turn it off.
1755+ static LogicalResult checkAssumptionForLoop (RewriterBase &rewriter,
1756+ Operation *loopOp,
17501757 Operation *consumerOp,
1751- Operation **toMoveLoopOpBefore) {
1758+ Operation **firstUserOfLoop,
1759+ bool intrusive = true ) {
17521760 Block *parentBlock = consumerOp->getBlock ();
1753- // loopOp and consumerOp should stay in the same block.
1761+ // 1. Check if loopOp and consumerOp stay in the same block.
17541762 if (loopOp->getBlock () != parentBlock)
17551763 return failure ();
17561764
1757- *toMoveLoopOpBefore = nullptr ;
1758- do {
1759- Operation *firstUserOfLoop = consumerOp, *lastDefOfConsumer = loopOp;
1760- // Find the first user of loopOp
1761- for (Operation *userOp : loopOp->getUsers ()) {
1762- if (userOp == consumerOp)
1763- continue ;
1764- // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1765- // block with any other types of operation. Thus, just redirecting to its
1766- // parent `InParallelOp`.
1767- if (isa<tensor::ParallelInsertSliceOp>(userOp))
1768- userOp = userOp->getParentOfType <scf::InParallelOp>();
1765+ *firstUserOfLoop = consumerOp;
1766+ // 2. Find the first user of loopOp.
1767+ for (Operation *userOp : loopOp->getUsers ()) {
1768+ if (userOp == consumerOp)
1769+ continue ;
1770+ // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1771+ // block with any other types of operation. Thus, just redirecting to its
1772+ // parent `InParallelOp`.
1773+ if (isa<tensor::ParallelInsertSliceOp>(userOp))
1774+ userOp = userOp->getParentOfType <scf::InParallelOp>();
17691775
1770- if (parentBlock != userOp->getBlock ())
1771- return failure ();
1776+ if (parentBlock != userOp->getBlock ())
1777+ return failure ();
17721778
1773- if (userOp->isBeforeInBlock (firstUserOfLoop))
1774- firstUserOfLoop = userOp;
1775- }
1779+ if (userOp->isBeforeInBlock (* firstUserOfLoop))
1780+ * firstUserOfLoop = userOp;
1781+ }
17761782
1777- // Find the last define of consumer
1778- for (Value operand : consumerOp->getOperands ()) {
1779- // If the operand is `BlockArgument`, auto skip.
1780- if (isa<BlockArgument>(operand))
1781- continue ;
1782- auto defineOp = operand.getDefiningOp ();
1783- if (!defineOp)
1784- return failure ();
1785- // If defineOp is not in the same block with loopOp, it must dominate the
1786- // loopOp as well. I.e.
1787- // ```
1788- // %a = ...
1789- // {
1790- // %looOp = scf.for
1791- // %b = consumerOp ins(%loopOp, %a)
1792- // }
1793- // ```
1794- if (defineOp == loopOp || parentBlock != defineOp->getBlock ())
1795- continue ;
1796- if (lastDefOfConsumer->isBeforeInBlock (defineOp))
1797- lastDefOfConsumer = defineOp;
1798- }
1799- if (firstUserOfLoop->isBeforeInBlock (lastDefOfConsumer)) {
1800- // Try to move if possible
1801- if (llvm::all_of (firstUserOfLoop->getUsers (),
1802- [&lastDefOfConsumer, &parentBlock](Operation *userOp) {
1803- return userOp->getBlock () == parentBlock &&
1804- lastDefOfConsumer->isBeforeInBlock (userOp);
1805- })) {
1806- // Safely moving
1807- firstUserOfLoop->moveAfter (lastDefOfConsumer);
1808- } else {
1809- return failure ();
1783+ // 3. Find backward slice of defOfConsumer.
1784+ BackwardSliceOptions options;
1785+ DominanceInfo dominanceInfo;
1786+ options.inclusive = true ;
1787+ options.omitBlockArguments = true ;
1788+
1789+ for (auto operand : consumerOp->getOperands ()) {
1790+ llvm::SetVector<Operation *> slice;
1791+ bool includeLoopOp = false ;
1792+ options.filter = [&](Operation *op) {
1793+ if (op == loopOp) {
1794+ includeLoopOp = true ;
1795+ return false ;
1796+ }
1797+ // Cut off the slice to not include any operation that already dominates
1798+ // firstUserOfLoop.
1799+ return !dominanceInfo.properlyDominates (op, *firstUserOfLoop);
1800+ };
1801+ getBackwardSlice (operand, &slice, options);
1802+ if (!slice.empty ()) {
1803+ if (includeLoopOp) {
1804+ // If consumerOp has one producer, which is also the user of loopOp.
1805+ // E.g.
1806+ // ```
1807+ // %0 = %loopOp
1808+ // %1 = consumerOp1 ins(%0)
1809+ // %2 = consumerOp2 ins(%0, %1)
1810+ // ```
1811+ // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1812+ // consumerOp1 has already been fused into loopOp before.
1813+ return rewriter.notifyMatchFailure (
1814+ consumerOp, " could not fuse consumer due to inevitable use-def "
1815+ " chain violation" );
1816+ }
1817+ if (!intrusive) {
1818+ // Please turn on intrusive mode, otherwise just bail out.
1819+ return rewriter.notifyMatchFailure (consumerOp,
1820+ " intrusive mode is not allowed" );
1821+ }
1822+ mlir::topologicalSort (slice);
1823+ // 4. Move all computed slice before firstUserOfLoop.
1824+ for (auto op : slice) {
1825+ rewriter.moveOpBefore (op, *firstUserOfLoop);
18101826 }
1811- } else {
1812- // Check consumerOp is not the user of firstUserOfLoop
1813- if (firstUserOfLoop == lastDefOfConsumer)
1814- return failure ();
1815- // Set InsertPoint
1816- *toMoveLoopOpBefore = firstUserOfLoop;
18171827 }
1818- } while (!(*toMoveLoopOpBefore));
1828+ }
18191829
18201830 return success ();
18211831}
@@ -1884,9 +1894,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
18841894 LoopLikeOpInterface outerMostLoop = nestedLoops.front ();
18851895
18861896 // Find suitable insertPointOp to move the whole loop structure later.
1887- Operation *toMoveLoopOpBefore = nullptr ;
1888- if (failed (checkAssumptionForLoop (outerMostLoop, consumerOp,
1889- &toMoveLoopOpBefore ))) {
1897+ Operation *firstUserOfLoop = nullptr ;
1898+ if (failed (checkAssumptionForLoop (rewriter, outerMostLoop, consumerOp,
1899+ &firstUserOfLoop ))) {
18901900 return rewriter.notifyMatchFailure (
18911901 outerMostLoop,
18921902 " containing loop op should either yield just one value or "
@@ -1913,7 +1923,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
19131923
19141924 // 3. Move the whole loop structure right before insertPoint, the dominance
19151925 // should be already ensured by `checkAssumptionForLoop`.
1916- rewriter.moveOpBefore (outerMostLoop, toMoveLoopOpBefore );
1926+ rewriter.moveOpBefore (outerMostLoop, firstUserOfLoop );
19171927
19181928 // 4. Set insertion point before terminator op of the loop and create a new
19191929 // tensor.insert_slice. In the scf.for case this is a clone of the
0 commit comments