@@ -1890,25 +1890,81 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
18901890 return {nestLoops.rbegin (), nestLoops.rend ()};
18911891}
18921892
1893+ // / Check that the loop is perfectly nested.
1894+ static bool
1895+ isPerfectlyNestedForLoops (MutableArrayRef<LoopLikeOpInterface> loops) {
1896+ assert (!loops.empty () && " unexpected empty loop nest" );
1897+ if (loops.size () == 1 ) {
1898+ return isa_and_nonnull<scf::ForOp>(loops.front ().getOperation ());
1899+ }
1900+ for (auto [outerLoop, innerLoop] :
1901+ llvm::zip_equal (loops.drop_back (), loops.drop_front ())) {
1902+ auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation ());
1903+ auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation ());
1904+ if (!outerFor || !innerFor) {
1905+ return false ;
1906+ }
1907+ auto outerBBArgs = outerFor.getRegionIterArgs ();
1908+ auto innerIterArgs = innerFor.getInitArgs ();
1909+ if (outerBBArgs.size () != innerIterArgs.size ()) {
1910+ return false ;
1911+ }
1912+
1913+ for (auto [outerBBArg, innerIterArg] :
1914+ llvm::zip (outerBBArgs, innerIterArgs)) {
1915+ if (!llvm::hasSingleElement (outerBBArg.getUses ()) ||
1916+ innerIterArg != outerBBArg) {
1917+ return false ;
1918+ }
1919+ }
1920+
1921+ auto outerYields =
1922+ cast<scf::YieldOp>(outerFor.getBody ()->getTerminator ())->getOperands ();
1923+ auto innerResults = innerFor.getResults ();
1924+ if (outerYields.size () != innerResults.size ()) {
1925+ return false ;
1926+ }
1927+ for (auto [outerYield, innerResult] :
1928+ llvm::zip (outerYields, innerResults)) {
1929+ if (!llvm::hasSingleElement (innerResult.getUses ()) ||
1930+ outerYield != innerResult) {
1931+ return false ;
1932+ }
1933+ }
1934+ }
1935+ return true ;
1936+ }
1937+
18931938// / Fetch the untiled consumer of a scf.for's result which is yielded by a
18941939// / tensor.insert_slice. This function makes the following assumptions :
18951940// / 1. tensor.insert_slice has scf.yield as its only user.
18961941// / 2. scf.for's corresponding result has only one use.
18971942static FailureOr<OpOperand *>
18981943getUntiledConsumerFromSlice (RewriterBase &rewriter,
1899- tensor::InsertSliceOp candidateSliceOp) {
1944+ tensor::InsertSliceOp candidateSliceOp,
1945+ MutableArrayRef<LoopLikeOpInterface> loops) {
1946+ assert (!loops.empty () && " unexpected loops to be empty" );
1947+ // 1. Expect slice to be part of the body of the inner most loop.
1948+ Operation *containingOp = candidateSliceOp->getParentOp ();
1949+ if (containingOp != loops.back ()) {
1950+ return rewriter.notifyMatchFailure (
1951+ candidateSliceOp,
1952+ " expected slice to be within body of inner-most loop" );
1953+ }
1954+
1955+ if (!isPerfectlyNestedForLoops (loops)) {
1956+ return rewriter.notifyMatchFailure (
1957+ candidateSliceOp, " expected passed loops to be perfectly nested." );
1958+ }
1959+
19001960 if (failed (checkAssumptionForFusingConsumer (candidateSliceOp)))
19011961 return failure ();
19021962 Value sliceResult = candidateSliceOp.getResult ();
19031963 // Step 1. Fetch the corresponding output.
19041964 OpOperand &yieldOpOperand = (*sliceResult.getUses ().begin ());
19051965 unsigned resultNumber = yieldOpOperand.getOperandNumber ();
1906- // Step 2. Check containing op is scf.for.
1907- Operation *containingOp = candidateSliceOp->getParentOp ();
1908- auto forOp = dyn_cast<scf::ForOp>(containingOp);
1909- if (!forOp)
1910- return failure ();
1911- scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf (forOp).front ();
1966+
1967+ scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front ().getOperation ());
19121968
19131969 return getConsumerFromLoopUses (rewriter, topLevelForOp, resultNumber);
19141970}
@@ -1917,35 +1973,49 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
19171973// / by a tensor.parallel_insert_slice.
19181974static FailureOr<OpOperand *>
19191975getUntiledConsumerFromSlice (RewriterBase &rewriter,
1920- tensor::ParallelInsertSliceOp candidateSliceOp) {
1921- // Step 1. Fetch the corresponding output
1976+ tensor::ParallelInsertSliceOp candidateSliceOp,
1977+ MutableArrayRef<LoopLikeOpInterface> loops) {
1978+ assert (!loops.empty () && " unexpected loops to be empty" );
1979+ // 1. Check that the surrounding loop is a single scf.forall loop.
1980+ if (loops.size () != 1 ) {
1981+ return rewriter.notifyMatchFailure (
1982+ candidateSliceOp, " expected single surrounding scf.forall" );
1983+ }
1984+ auto forallOp = dyn_cast<scf::ForallOp>(loops.front ().getOperation ());
1985+ if (!forallOp) {
1986+ return rewriter.notifyMatchFailure (
1987+ candidateSliceOp, " expected single surrounding scf.forall" );
1988+ }
1989+
1990+ // 2. Fetch the corresponding output
19221991 Value sliceDest = candidateSliceOp.getDest ();
19231992 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
19241993 if (!iterArg)
19251994 return failure ();
1926- Operation *containingOp = iterArg.getOwner ()->getParentOp ();
1927- if (containingOp != candidateSliceOp->getParentOp ()->getParentOp ())
1928- return failure ();
1929- // Step 2. Check that the containing op is scf.forall.
1930- auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1931- if (!forallOp)
1995+ if (iterArg.getOwner ()->getParentOp () != forallOp)
19321996 return failure ();
1997+
19331998 unsigned resultNumber =
19341999 forallOp.getTiedOpResult (forallOp.getTiedOpOperand (iterArg))
19352000 .getResultNumber ();
19362001
1937- return getConsumerFromLoopUses (rewriter, containingOp , resultNumber);
2002+ return getConsumerFromLoopUses (rewriter, forallOp , resultNumber);
19382003}
19392004
19402005// / A utility to fetch an untiled consumer of
19412006// / tensor.insert_slice/tensor.parallel_insert_slice.
19422007static FailureOr<OpOperand *>
1943- getUntiledConsumerFromSlice (RewriterBase &rewriter, Operation *sliceOp) {
2008+ getUntiledConsumerFromSlice (RewriterBase &rewriter, Operation *sliceOp,
2009+ MutableArrayRef<LoopLikeOpInterface> loops) {
2010+ if (loops.empty ()) {
2011+ return rewriter.notifyMatchFailure (sliceOp, " unexpected empty loops" );
2012+ }
2013+
19442014 if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1945- return getUntiledConsumerFromSlice (rewriter, insertSlice);
2015+ return getUntiledConsumerFromSlice (rewriter, insertSlice, loops );
19462016 } else if (auto parallelInsertSlice =
19472017 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1948- return getUntiledConsumerFromSlice (rewriter, parallelInsertSlice);
2018+ return getUntiledConsumerFromSlice (rewriter, parallelInsertSlice, loops );
19492019 } else {
19502020 return failure ();
19512021 }
@@ -1954,18 +2024,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
19542024// / Implementation of fusing consumer of a single slice by computing the
19552025// / slice of the consumer in-place for scf loop.
19562026FailureOr<scf::SCFFuseConsumerOfSliceResult>
1957- mlir::scf::tileAndFuseConsumerOfSlice (RewriterBase &rewriter,
1958- Operation *candidateSliceOp) {
2027+ mlir::scf::tileAndFuseConsumerOfSlice (
2028+ RewriterBase &rewriter, Operation *candidateSliceOp,
2029+ MutableArrayRef<LoopLikeOpInterface> loops) {
2030+ // Return if `loops` is empty, return an error for now. Caller is expected
2031+ // to handle this case.
2032+ if (loops.empty ()) {
2033+ return candidateSliceOp->emitOpError (
2034+ " cannot call tile and fuse consumer with an empty loop nest" );
2035+ }
19592036 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
19602037 candidateSliceOp))
19612038 return failure ();
19622039
1963- bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1964-
19652040 // 1. Get the consumer of scf.for for the result yielded by
19662041 // tensor.insert_slice/parallel_insert_slice.
19672042 FailureOr<OpOperand *> maybeConsumerOpOperand =
1968- getUntiledConsumerFromSlice (rewriter, candidateSliceOp);
2043+ getUntiledConsumerFromSlice (rewriter, candidateSliceOp, loops );
19692044 if (failed (maybeConsumerOpOperand)) {
19702045 return rewriter.notifyMatchFailure (candidateSliceOp,
19712046 " could not fetch consumer to fuse" );
@@ -1981,25 +2056,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
19812056 consumerOp, " consumer op's operand doesn't seem to be an OpResult" );
19822057 }
19832058
1984- // There are two possible cases regarding `oldLoopOp` here:
1985- // 1. single `scf.forall` or `scf.for`.
1986- // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1987- // top-level loop is the outer-most one of these nested loops.
1988- LoopLikeOpInterface innerMostLoop =
1989- candidateSliceOp->getParentOfType <LoopLikeOpInterface>();
1990- SmallVector<LoopLikeOpInterface> nestedLoops;
1991- if (isInsertSliceOp) {
1992- nestedLoops = llvm::map_to_vector (
1993- getPerfectlyNestedLoopsOutsideOf (
1994- cast<scf::ForOp>(innerMostLoop.getOperation ())),
1995- [](scf::ForOp forOp) {
1996- return cast<LoopLikeOpInterface>(forOp.getOperation ());
1997- });
1998- } else {
1999- nestedLoops = {innerMostLoop};
2000- }
2001-
2002- LoopLikeOpInterface outerMostLoop = nestedLoops.front ();
2059+ LoopLikeOpInterface outerMostLoop = loops.front ();
2060+ LoopLikeOpInterface innerMostLoop = loops.back ();
20032061
20042062 // Check assumption for loop with `reorderOperations` disabled.
20052063 if (failed (checkAssumptionForLoop (outerMostLoop, consumerOp, false ))) {
@@ -2165,7 +2223,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
21652223 return success ();
21662224 };
21672225 // 14. Add new inits to [nested] loops.
2168- if (failed (addInitOperandsToLoopNest (rewriter, nestedLoops , newInits,
2226+ if (failed (addInitOperandsToLoopNest (rewriter, loops , newInits,
21692227 newYieldValuesFn))) {
21702228 return rewriter.notifyMatchFailure (tiledConsumerOp,
21712229 " unable to add new inits to nest loop" );
@@ -2174,9 +2232,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
21742232 // 15. Replace the result of scf loop and consumer op with new loop's
21752233 // results.
21762234
2177- for (auto &&[oldResult, newResult] : llvm::zip (
2178- consumerOp->getResults (),
2179- nestedLoops .front ()->getResults ().take_back (newInits.size ()))) {
2235+ for (auto &&[oldResult, newResult] :
2236+ llvm::zip ( consumerOp->getResults (),
2237+ loops .front ()->getResults ().take_back (newInits.size ()))) {
21802238 rewriter.replaceAllUsesWith (oldResult, newResult);
21812239 }
21822240
0 commit comments