1212
1313#include " mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
1414
15+ #include " mlir/Analysis/SliceAnalysis.h"
16+ #include " mlir/Analysis/TopologicalSortUtils.h"
1517#include " mlir/Dialect/Affine/IR/AffineOps.h"
1618#include " mlir/Dialect/Arith/IR/Arith.h"
1719#include " mlir/Dialect/Arith/Utils/Utils.h"
@@ -1580,33 +1582,163 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
15801582 return success ();
15811583}
15821584
1583- // / Fetches the OpOperand of the only user (and use) of the value `val` which
1584- // / implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
1585- // / failure otherwise.
1586- static FailureOr<OpOperand *> getConsumerFromUses (Value val,
1587- Block *containingOpBlock) {
1588- // Check that the value has exactly one use which isn't a scf.yield or a
1589- // tensor.parallel_insert_slice op.
1590- OpOperand *operand = nullptr ;
1591- for (OpOperand &opOperand : val.getUses ()) {
1592- Operation *consumerOp = opOperand.getOwner ();
1593- if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
1594- continue ;
1595- if (operand)
1596- return failure ();
1597- // TODO: We have to init result of consumer before scf.for, use
1598- // DestinationStyleOpInterface to get result shape from init for now.
1599- // Add support for other op such as op has InferTypeOpInterface.
1600- if (!isa<TilingInterface>(consumerOp) ||
1601- !isa<DestinationStyleOpInterface>(consumerOp))
1585+ // / A utility to get the first user (in block order) of the given loopOp. Fails if
1586+ // / loopOp is not loop-like or any user is in another block; null if no users.
1587+ static FailureOr<Operation *> getFirstUserOfLoop (Operation *loopOp) {
1588+ if (!isa<LoopLikeOpInterface>(loopOp))
1589+ return failure ();
1590+ Operation *firstUserOfLoop = nullptr ;
1591+ for (Operation *userOp : loopOp->getUsers ()) {
1592+ // `ParallelInsertSlice` located inside `InParallelOp` does not share a parent
1593+ // block with any other type of operation. Thus, just redirect to its
1594+ // parent `InParallelOp`. E.g.
1595+ //
1596+ // ```
1597+ // %1 = scf.for {
1598+ // ...
1599+ // }
1600+ // %2 = consumerOp ins(%1, ...)
1601+ // scf.forall.in_parallel {
1602+ // tensor.parallel_insert_slice %1
1603+ // }
1604+ // ```
1605+ // where `InParallelOp`, but not `ParallelInsertSlice`, stays in the
1606+ // same block as `consumerOp`.
1607+ if (isa<tensor::ParallelInsertSliceOp>(userOp))
1608+ userOp = userOp->getParentOfType <scf::InParallelOp>();
1609+
1610+ if (loopOp->getBlock () != userOp->getBlock ())
16021611 return failure ();
1603- if (containingOpBlock != consumerOp->getBlock ())
1612+
1613+ if (!firstUserOfLoop || userOp->isBeforeInBlock (firstUserOfLoop))
1614+ firstUserOfLoop = userOp;
1615+ }
1616+ return firstUserOfLoop;
1617+ }
1618+
1619+ // / This utility currently checks whether the first userOp of loop is NOT before
1620+ // / the last defining op of the consumer operands. This is needed because we move the
1621+ // / whole loop structure right before the `firstUserOfLoop`. This utility thus
1622+ // / helps ensure that no invalid IR is formed, i.e. no backward slice of
1623+ // / consumerOp is dominated by the `firstUserOfLoop`. For example:
1624+ // /
1625+ // / ```
1626+ // / %0 = scf.for() {
1627+ // / ...
1628+ // / }
1629+ // / ...
1630+ // / %1 = firstUserOfLoop(%0)
1631+ // / ...
1632+ // / %2 = lastDefOfConsumerOperand
1633+ // / ...
1634+ // / %3 = consumerOp(%2)
1635+ // / ```
1636+ // /
1637+ // / If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would
1638+ // / be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a.
1639+ // / use-def chain violation:
1640+ // /
1641+ // / ```
1642+ // / %0:2 = scf.for() {
1643+ // / // use before define error
1644+ // / %3 = tiledConsumerOp(%2)
1645+ // / }
1646+ // / %1 = firstUserOfLoop(%0)
1647+ // / ...
1648+ // / %2 = lastDefOfConsumerOperand
1649+ // / ```
1650+ // /
1651+ // / @param loopOp: loop operation
1652+ // / @param consumerOp: consumer operation
1653+ // / @param reorderOperations: whether a non-empty backward slice is tolerated,
1654+ // / i.e. whether the caller may reorder the ops defining `consumerOp` operands.
1655+ // / @return: backward slice of consumerOp operands, excluding ops that properly
1656+ // / dominate `firstUserOfLoop`; failure if non-empty and loopOp hit or reordering off.
1657+ static FailureOr<llvm::SetVector<Operation *>>
1658+ checkAssumptionForLoop (Operation *loopOp, Operation *consumerOp,
1659+ bool reorderOperations) {
1660+ FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop (loopOp);
1661+ if (failed (firstUserOfLoop))
1662+ return failure ();
1663+
1664+ BackwardSliceOptions options;
1665+ DominanceInfo dominanceInfo;
1666+ options.inclusive = true ;
1667+ options.omitBlockArguments = true ;
1668+ bool includeLoopOp = false ;
1669+ options.filter = [&](Operation *op) {
1670+ if (op == loopOp) {
1671+ includeLoopOp = true ;
1672+ return false ;
1673+ }
1674+ // Cut off the slice so that it does not include any operation that already
1675+ // properly dominates firstUserOfLoop.
1676+ return !dominanceInfo.properlyDominates (op, *firstUserOfLoop);
1677+ };
1678+ llvm::SetVector<Operation *> slice;
1679+ for (auto operand : consumerOp->getOperands ()) {
1680+ getBackwardSlice (operand, &slice, options);
1681+ }
1682+
1683+ if (!slice.empty ()) {
1684+ // The consumerOp may have a producer which is itself a user of loopOp.
1685+ // E.g.
1686+ // ```
1687+ // %0 = %loopOp
1688+ // %1 = consumerOp1 ins(%0)
1689+ // %2 = consumerOp2 ins(%0, %1)
1690+ // ```
1691+ // We cannot fuse consumerOp2 into loopOp due to the use-def chain, unless
1692+ // consumerOp1 has already been fused into loopOp before.
1693+ if (includeLoopOp || !reorderOperations)
16041694 return failure ();
1605- operand = &opOperand;
16061695 }
16071696
1608- if (operand)
1609- return operand;
1697+ return slice;
1698+ }
1699+
1700+ // / Fetches the OpOperand of the first valid user (and use) of the value `val`
1701+ // / which implements `TilingInterface` and `DestinationStyleOpInterface`.
1702+ // / Returns failure otherwise.
1703+ static FailureOr<OpOperand *> getConsumerFromLoopUses (RewriterBase &rewriter,
1704+ Operation *loopOp,
1705+ unsigned resultNumber) {
1706+ if (!isa<LoopLikeOpInterface>(loopOp))
1707+ return failure ();
1708+ Value val = loopOp->getResult (resultNumber);
1709+ Block *loopBlock = loopOp->getBlock ();
1710+ for (OpOperand &opOperand : val.getUses ()) {
1711+ Operation *consumerOp = opOperand.getOwner ();
1712+ // Step 1. Check if the user is tilable.
1713+ if (!isa<TilingInterface, DestinationStyleOpInterface>(consumerOp)) {
1714+ // TODO: We have to init result of consumer before scf.for, use
1715+ // DestinationStyleOpInterface to get result shape from init for now. Add
1716+ // support for other op such as op has InferTypeOpInterface.
1717+ continue ;
1718+ }
1719+ // Step 2. Check if user stay in the same block.
1720+ if (loopBlock != consumerOp->getBlock ())
1721+ continue ;
1722+ // Step 3. Check if user has succeeding user. Otherwise, it usually
1723+ // represents already tiled.
1724+ if (consumerOp->use_empty ())
1725+ continue ;
1726+ // Step 4. Check assumption for loop with `reorderOperations` enabled.
1727+ FailureOr<llvm::SetVector<Operation *>> slice =
1728+ checkAssumptionForLoop (loopOp, consumerOp, true );
1729+ if (failed (slice))
1730+ continue ;
1731+ // Step 5. If backward sice is not empty, move them before firstUserOfLoop.
1732+ if (!slice->empty ()) {
1733+ mlir::topologicalSort (*slice);
1734+ FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop (loopOp);
1735+ assert (succeeded (firstUserOfLoop) && " First user of loop is not found" );
1736+ for (auto op : *slice) {
1737+ rewriter.moveOpBefore (op, *firstUserOfLoop);
1738+ }
1739+ }
1740+ return &opOperand;
1741+ }
16101742 return failure ();
16111743}
16121744
@@ -1659,7 +1791,8 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
16591791// / 1. tensor.insert_slice has scf.yield as its only user.
16601792// / 2. scf.for's corresponding result has only one use.
16611793static FailureOr<OpOperand *>
1662- getUntiledConsumerFromSlice (tensor::InsertSliceOp candidateSliceOp) {
1794+ getUntiledConsumerFromSlice (RewriterBase &rewriter,
1795+ tensor::InsertSliceOp candidateSliceOp) {
16631796 if (failed (checkAssumptionForFusingConsumer (candidateSliceOp)))
16641797 return failure ();
16651798 Value sliceResult = candidateSliceOp.getResult ();
@@ -1672,15 +1805,15 @@ getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
16721805 if (!forOp)
16731806 return failure ();
16741807 scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf (forOp).front ();
1675- Value resultingValue = topLevelForOp->getResult (resultNumber);
16761808
1677- return getConsumerFromUses (resultingValue , topLevelForOp-> getBlock () );
1809+ return getConsumerFromLoopUses (rewriter , topLevelForOp, resultNumber );
16781810}
16791811
16801812// / Fetch the first untiled consumer of a scf.forall's result which is yielded
16811813// / by a tensor.parallel_insert_slice.
16821814static FailureOr<OpOperand *>
1683- getUntiledConsumerFromSlice (tensor::ParallelInsertSliceOp candidateSliceOp) {
1815+ getUntiledConsumerFromSlice (RewriterBase &rewriter,
1816+ tensor::ParallelInsertSliceOp candidateSliceOp) {
16841817 // Step 1. Fetch the corresponding output
16851818 Value sliceDest = candidateSliceOp.getDest ();
16861819 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
@@ -1693,45 +1826,22 @@ getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
16931826 auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
16941827 if (!forallOp)
16951828 return failure ();
1696- Value resultingValue =
1697- forallOp.getTiedOpResult (forallOp.getTiedOpOperand (iterArg));
1698-
1699- return getConsumerFromUses (resultingValue, containingOp->getBlock ());
1700- }
1829+ unsigned resultNumber =
1830+ forallOp.getTiedOpResult (forallOp.getTiedOpOperand (iterArg))
1831+ .getResultNumber ();
17011832
1702- // / This utility currently checks whether the loop either :-
1703- // / 1. Yields exactly one result.
1704- // / 2. Has consumer op as its first user and other users to be in the same
1705- // / containing block as that of consumer op's. Currently we clone the loop op
1706- // / right before the consumer op in order to maintain a valid def-use chain.
1707- // / This utility thus helps ensuring that no invalid IR is formed due to the
1708- // / same.
1709- static LogicalResult checkAssumptionForLoop (Operation *loopOp,
1710- Operation *consumerOp) {
1711- // Check if the loop op yields one result.
1712- if (loopOp->getNumResults () == 1 )
1713- return success ();
1714- // Check if the consumerOp is the first user of the loopOp and if other users
1715- // are in the same containing block as that of consumer op's.
1716- Block *parentBlock = consumerOp->getBlock ();
1717- for (Operation *userOp : loopOp->getUsers ()) {
1718- if (userOp == consumerOp)
1719- continue ;
1720- if (parentBlock != userOp->getBlock () ||
1721- !consumerOp->isBeforeInBlock (userOp))
1722- return failure ();
1723- }
1724- return success ();
1833+ return getConsumerFromLoopUses (rewriter, containingOp, resultNumber);
17251834}
17261835
17271836// / A utility to fetch an untiled consumer of
17281837// / tensor.insert_slice/tensor.parallel_insert_slice.
1729- static FailureOr<OpOperand *> getUntiledConsumerFromSlice (Operation *sliceOp) {
1838+ static FailureOr<OpOperand *>
1839+ getUntiledConsumerFromSlice (RewriterBase &rewriter, Operation *sliceOp) {
17301840 if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1731- return getUntiledConsumerFromSlice (insertSlice);
1841+ return getUntiledConsumerFromSlice (rewriter, insertSlice);
17321842 } else if (auto parallelInsertSlice =
17331843 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1734- return getUntiledConsumerFromSlice (parallelInsertSlice);
1844+ return getUntiledConsumerFromSlice (rewriter, parallelInsertSlice);
17351845 } else {
17361846 return failure ();
17371847 }
@@ -1751,7 +1861,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
17511861 // 1. Get the consumer of scf.for for the result yielded by
17521862 // tensor.insert_slice/parallel_insert_slice.
17531863 FailureOr<OpOperand *> maybeConsumerOpOperand =
1754- getUntiledConsumerFromSlice (candidateSliceOp);
1864+ getUntiledConsumerFromSlice (rewriter, candidateSliceOp);
17551865 if (failed (maybeConsumerOpOperand)) {
17561866 return rewriter.notifyMatchFailure (candidateSliceOp,
17571867 " could not fetch consumer to fuse" );
@@ -1787,11 +1897,11 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
17871897
17881898 LoopLikeOpInterface outerMostLoop = nestedLoops.front ();
17891899
1790- if (failed (checkAssumptionForLoop (outerMostLoop, consumerOp))) {
1900+ // Check assumption for loop with `reorderOperations` disabled.
1901+ if (failed (checkAssumptionForLoop (outerMostLoop, consumerOp, false ))) {
17911902 return rewriter.notifyMatchFailure (
1792- outerMostLoop,
1793- " containing loop op should either yield just one value or "
1794- " have the consumer op as its first user" );
1903+ outerMostLoop, " the first user of loop should not dominate any define "
1904+ " of consumer operand(s)" );
17951905 }
17961906
17971907 OpBuilder::InsertionGuard g (rewriter);
@@ -1812,9 +1922,14 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
18121922
18131923 Location loc = outerMostLoop->getLoc ();
18141924
1815- // 3. Move the whole loop structure right before consumer Op, the dominance
1816- // should be already ensured by `checkAssumptionForLoop`.
1817- rewriter.moveOpBefore (outerMostLoop, consumerOp);
1925+ // 3. Move the whole loop structure right before firstUserOfLoop, the
1926+ // dominance should be already ensured by `checkAssumptionForLoop`.
1927+ FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop (outerMostLoop);
1928+ if (failed (firstUserOfLoop)) {
1929+ return rewriter.notifyMatchFailure (
1930+ outerMostLoop, " could not find the first user of outer most loop" );
1931+ }
1932+ rewriter.moveOpBefore (outerMostLoop, *firstUserOfLoop);
18181933
18191934 // 4. Set insertion point before terminator op of the loop and create a new
18201935 // tensor.insert_slice. In the scf.for case this is a clone of the
0 commit comments