@@ -1846,11 +1846,9 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
18461846 return failure ();
18471847}
18481848
1849- // / Find the perfectly nested loops outside of given loop(included) sorted
1850- // / from outer to inner.
1851- // /
1852- // / E.g.
1853- // /
1849+ // / Check that the loop is perfectly nested.
1850+ // / The loops are expected to be ordered from outer most to inner most.
1851+ // / For example:
18541852// / ```
18551853// / %0 = scf.for()
18561854// / %1 = scf.for()
@@ -1860,55 +1858,85 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
18601858// / yield %2
18611859// / yield %1
18621860// / ```
1863- // /
1864- // / This function will return three perfectly nested loops: %0 + %1 + %2, when
1865- // / target inner loop is %2.
1866- static SmallVector<scf::ForOp>
1867- getPerfectlyNestedLoopsOutsideOf (scf::ForOp loop) {
1868- SmallVector<scf::ForOp> nestLoops = {loop};
1869- auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp ());
1870-
1871- // Check if it is the ForOp that yield the result of inner loop.
1872- auto isForOpYieldResultOfInnerLoop =
1873- [](scf::ForOp outerLoop) -> LogicalResult {
1874- Block *body = outerLoop.getBody ();
1875- if (!llvm::hasSingleElement (body->without_terminator ()))
1876- return failure ();
1877- auto yieldOp = cast<scf::YieldOp>(body->getTerminator ());
1878- auto innerForOp = dyn_cast<scf::ForOp>(body->front ());
1879- if (!innerForOp)
1880- return failure ();
1881- // All of innerForOp results should be yielded.
1882- return success (innerForOp->getNumResults () == yieldOp->getNumOperands ());
1883- };
1861+ // / Here loops should be [%0, %1].
1862+ static bool
1863+ isPerfectlyNestedForLoops (MutableArrayRef<LoopLikeOpInterface> loops) {
1864+ assert (!loops.empty () && " unexpected empty loop nest" );
1865+ if (loops.size () == 1 ) {
1866+ return isa_and_nonnull<scf::ForOp>(loops.front ().getOperation ());
1867+ }
1868+ for (auto [outerLoop, innerLoop] :
1869+ llvm::zip_equal (loops.drop_back (), loops.drop_front ())) {
1870+ auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation ());
1871+ auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation ());
1872+ if (!outerFor || !innerFor) {
1873+ return false ;
1874+ }
1875+ auto outerBBArgs = outerFor.getRegionIterArgs ();
1876+ auto innerIterArgs = innerFor.getInitArgs ();
1877+ if (outerBBArgs.size () != innerIterArgs.size ()) {
1878+ return false ;
1879+ }
1880+
1881+ for (auto [outerBBArg, innerIterArg] :
1882+ llvm::zip_equal (outerBBArgs, innerIterArgs)) {
1883+ if (!llvm::hasSingleElement (outerBBArg.getUses ()) ||
1884+ innerIterArg != outerBBArg) {
1885+ return false ;
1886+ }
1887+ }
18841888
1885- while (outerLoop && succeeded (isForOpYieldResultOfInnerLoop (outerLoop))) {
1886- nestLoops.push_back (outerLoop);
1887- outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp ());
1889+ ValueRange outerYields =
1890+ cast<scf::YieldOp>(outerFor.getBody ()->getTerminator ())->getOperands ();
1891+ ValueRange innerResults = innerFor.getResults ();
1892+ if (outerYields.size () != innerResults.size ()) {
1893+ return false ;
1894+ }
1895+ for (auto [outerYield, innerResult] :
1896+ llvm::zip_equal (outerYields, innerResults)) {
1897+ if (!llvm::hasSingleElement (innerResult.getUses ()) ||
1898+ outerYield != innerResult) {
1899+ return false ;
1900+ }
1901+ }
18881902 }
1889- // sorted from outer to inner
1890- return {nestLoops.rbegin (), nestLoops.rend ()};
1903+ return true ;
18911904}
18921905
1893- // / Fetch the untiled consumer of a scf.for's result which is yielded by a
1894- // / tensor.insert_slice. This function makes the following assumptions :
1895- // / 1. tensor.insert_slice has scf.yield as its only user.
1896- // / 2. scf.for's corresponding result has only one use.
1906+ // / Fetch the untiled consumer of the outermost scf.for's result which is
1907+ // / yielded by a tensor.insert_slice from the innermost scf.for. This function
1908+ // / makes the following assumptions :
1909+ // / 1. tensor.insert_slice has scf.yield as its only user.
1910+ // / 2. scf.for's corresponding result has only one use.
1911+ // / 3. The `loops` passed in are perfectly nested `scf.for` operations.
18971912static FailureOr<OpOperand *>
18981913getUntiledConsumerFromSlice (RewriterBase &rewriter,
1899- tensor::InsertSliceOp candidateSliceOp) {
1914+ tensor::InsertSliceOp candidateSliceOp,
1915+ MutableArrayRef<LoopLikeOpInterface> loops) {
1916+ assert (!loops.empty () && " unexpected loops to be empty" );
1917+ // 1. Expect slice to be part of the body of the inner most loop.
1918+ Operation *containingOp = candidateSliceOp->getParentOp ();
1919+ if (containingOp != loops.back ()) {
1920+ return rewriter.notifyMatchFailure (
1921+ candidateSliceOp,
1922+ " expected slice to be within body of inner-most loop" );
1923+ }
1924+
1925+ // 2. Check that the loop is perfectly nested.
1926+ if (!isPerfectlyNestedForLoops (loops)) {
1927+ return rewriter.notifyMatchFailure (
1928+ candidateSliceOp, " expected passed loops to be perfectly nested." );
1929+ }
1930+
19001931 if (failed (checkAssumptionForFusingConsumer (candidateSliceOp)))
19011932 return failure ();
19021933 Value sliceResult = candidateSliceOp.getResult ();
1903- // Step 1. Fetch the corresponding output.
1934+
1935+ // 3. Fetch the corresponding output.
19041936 OpOperand &yieldOpOperand = (*sliceResult.getUses ().begin ());
19051937 unsigned resultNumber = yieldOpOperand.getOperandNumber ();
1906- // Step 2. Check containing op is scf.for.
1907- Operation *containingOp = candidateSliceOp->getParentOp ();
1908- auto forOp = dyn_cast<scf::ForOp>(containingOp);
1909- if (!forOp)
1910- return failure ();
1911- scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf (forOp).front ();
1938+
1939+ scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front ().getOperation ());
19121940
19131941 return getConsumerFromLoopUses (rewriter, topLevelForOp, resultNumber);
19141942}
@@ -1917,35 +1945,46 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
19171945// / by a tensor.parallel_insert_slice.
19181946static FailureOr<OpOperand *>
19191947getUntiledConsumerFromSlice (RewriterBase &rewriter,
1920- tensor::ParallelInsertSliceOp candidateSliceOp) {
1921- // Step 1. Fetch the corresponding output
1948+ tensor::ParallelInsertSliceOp candidateSliceOp,
1949+ MutableArrayRef<LoopLikeOpInterface> loops) {
1950+ assert (!loops.empty () && " unexpected loops to be empty" );
1951+ // 1. Check that the surrounding loop is a single scf.forall loop.
1952+ if (loops.size () != 1 ) {
1953+ return rewriter.notifyMatchFailure (
1954+ candidateSliceOp, " expected single surrounding scf.forall" );
1955+ }
1956+ auto forallOp = dyn_cast<scf::ForallOp>(loops.front ().getOperation ());
1957+ if (!forallOp) {
1958+ return rewriter.notifyMatchFailure (
1959+ candidateSliceOp, " expected single surrounding scf.forall" );
1960+ }
1961+
1962+ // 2. Fetch the corresponding output
19221963 Value sliceDest = candidateSliceOp.getDest ();
19231964 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
19241965 if (!iterArg)
19251966 return failure ();
1926- Operation *containingOp = iterArg.getOwner ()->getParentOp ();
1927- if (containingOp != candidateSliceOp->getParentOp ()->getParentOp ())
1928- return failure ();
1929- // Step 2. Check that the containing op is scf.forall.
1930- auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1931- if (!forallOp)
1967+ if (iterArg.getOwner ()->getParentOp () != forallOp)
19321968 return failure ();
1969+
19331970 unsigned resultNumber =
19341971 forallOp.getTiedOpResult (forallOp.getTiedOpOperand (iterArg))
19351972 .getResultNumber ();
19361973
1937- return getConsumerFromLoopUses (rewriter, containingOp , resultNumber);
1974+ return getConsumerFromLoopUses (rewriter, forallOp , resultNumber);
19381975}
19391976
19401977// / A utility to fetch an untiled consumer of
19411978// / tensor.insert_slice/tensor.parallel_insert_slice.
19421979static FailureOr<OpOperand *>
1943- getUntiledConsumerFromSlice (RewriterBase &rewriter, Operation *sliceOp) {
1980+ getUntiledConsumerFromSlice (RewriterBase &rewriter, Operation *sliceOp,
1981+ MutableArrayRef<LoopLikeOpInterface> loops) {
1982+ assert (!loops.empty () && " unexpected empty loops" );
19441983 if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1945- return getUntiledConsumerFromSlice (rewriter, insertSlice);
1984+ return getUntiledConsumerFromSlice (rewriter, insertSlice, loops );
19461985 } else if (auto parallelInsertSlice =
19471986 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1948- return getUntiledConsumerFromSlice (rewriter, parallelInsertSlice);
1987+ return getUntiledConsumerFromSlice (rewriter, parallelInsertSlice, loops );
19491988 } else {
19501989 return failure ();
19511990 }
@@ -1954,18 +1993,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
19541993// / Implementation of fusing consumer of a single slice by computing the
19551994// / slice of the consumer in-place for scf loop.
19561995FailureOr<scf::SCFFuseConsumerOfSliceResult>
1957- mlir::scf::tileAndFuseConsumerOfSlice (RewriterBase &rewriter,
1958- Operation *candidateSliceOp) {
1996+ mlir::scf::tileAndFuseConsumerOfSlice (
1997+ RewriterBase &rewriter, Operation *candidateSliceOp,
1998+ MutableArrayRef<LoopLikeOpInterface> loops) {
1999+ // Return if `loops` is empty, return an error for now. Caller is expected
2000+ // to handle this case.
2001+ if (loops.empty ()) {
2002+ return candidateSliceOp->emitOpError (
2003+ " cannot call tile and fuse consumer with an empty loop nest" );
2004+ }
19592005 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
19602006 candidateSliceOp))
19612007 return failure ();
19622008
1963- bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1964-
19652009 // 1. Get the consumer of scf.for for the result yielded by
19662010 // tensor.insert_slice/parallel_insert_slice.
19672011 FailureOr<OpOperand *> maybeConsumerOpOperand =
1968- getUntiledConsumerFromSlice (rewriter, candidateSliceOp);
2012+ getUntiledConsumerFromSlice (rewriter, candidateSliceOp, loops );
19692013 if (failed (maybeConsumerOpOperand)) {
19702014 return rewriter.notifyMatchFailure (candidateSliceOp,
19712015 " could not fetch consumer to fuse" );
@@ -1981,25 +2025,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
19812025 consumerOp, " consumer op's operand doesn't seem to be an OpResult" );
19822026 }
19832027
1984- // There are two possible cases regarding `oldLoopOp` here:
1985- // 1. single `scf.forall` or `scf.for`.
1986- // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1987- // top-level loop is the outer-most one of these nested loops.
1988- LoopLikeOpInterface innerMostLoop =
1989- candidateSliceOp->getParentOfType <LoopLikeOpInterface>();
1990- SmallVector<LoopLikeOpInterface> nestedLoops;
1991- if (isInsertSliceOp) {
1992- nestedLoops = llvm::map_to_vector (
1993- getPerfectlyNestedLoopsOutsideOf (
1994- cast<scf::ForOp>(innerMostLoop.getOperation ())),
1995- [](scf::ForOp forOp) {
1996- return cast<LoopLikeOpInterface>(forOp.getOperation ());
1997- });
1998- } else {
1999- nestedLoops = {innerMostLoop};
2000- }
2001-
2002- LoopLikeOpInterface outerMostLoop = nestedLoops.front ();
2028+ LoopLikeOpInterface outerMostLoop = loops.front ();
2029+ LoopLikeOpInterface innerMostLoop = loops.back ();
20032030
20042031 // Check assumption for loop with `reorderOperations` disabled.
20052032 if (failed (checkAssumptionForLoop (outerMostLoop, consumerOp, false ))) {
@@ -2165,7 +2192,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
21652192 return success ();
21662193 };
21672194 // 14. Add new inits to [nested] loops.
2168- if (failed (addInitOperandsToLoopNest (rewriter, nestedLoops , newInits,
2195+ if (failed (addInitOperandsToLoopNest (rewriter, loops , newInits,
21692196 newYieldValuesFn))) {
21702197 return rewriter.notifyMatchFailure (tiledConsumerOp,
21712198 " unable to add new inits to nest loop" );
@@ -2174,9 +2201,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
21742201 // 15. Replace the result of scf loop and consumer op with new loop's
21752202 // results.
21762203
2177- for (auto &&[oldResult, newResult] : llvm::zip (
2178- consumerOp->getResults (),
2179- nestedLoops .front ()->getResults ().take_back (newInits.size ()))) {
2204+ for (auto &&[oldResult, newResult] :
2205+ llvm::zip ( consumerOp->getResults (),
2206+ loops .front ()->getResults ().take_back (newInits.size ()))) {
21802207 rewriter.replaceAllUsesWith (oldResult, newResult);
21812208 }
21822209
0 commit comments