@@ -1846,11 +1846,9 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
18461846 return failure ();
18471847}
18481848
1849- // / Find the perfectly nested loops outside of given loop(included) sorted
1850- // / from outer to inner.
1851- // /
1852- // / E.g.
1853- // /
1849+ // / Check that the loop is perfectly nested.
1850+ // / The loops are expected to be ordered from outer most to inner most.
1851+ // / For example:
18541852// / ```
18551853// / %0 = scf.for()
18561854// / %1 = scf.for()
@@ -1860,37 +1858,7 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
18601858// / yield %2
18611859// / yield %1
18621860// / ```
1863- // /
1864- // / This function will return three perfectly nested loops: %0 + %1 + %2, when
1865- // / target inner loop is %2.
1866- static SmallVector<scf::ForOp>
1867- getPerfectlyNestedLoopsOutsideOf (scf::ForOp loop) {
1868- SmallVector<scf::ForOp> nestLoops = {loop};
1869- auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp ());
1870-
1871- // Check if it is the ForOp that yield the result of inner loop.
1872- auto isForOpYieldResultOfInnerLoop =
1873- [](scf::ForOp outerLoop) -> LogicalResult {
1874- Block *body = outerLoop.getBody ();
1875- if (!llvm::hasSingleElement (body->without_terminator ()))
1876- return failure ();
1877- auto yieldOp = cast<scf::YieldOp>(body->getTerminator ());
1878- auto innerForOp = dyn_cast<scf::ForOp>(body->front ());
1879- if (!innerForOp)
1880- return failure ();
1881- // All of innerForOp results should be yielded.
1882- return success (innerForOp->getNumResults () == yieldOp->getNumOperands ());
1883- };
1884-
1885- while (outerLoop && succeeded (isForOpYieldResultOfInnerLoop (outerLoop))) {
1886- nestLoops.push_back (outerLoop);
1887- outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp ());
1888- }
1889- // sorted from outer to inner
1890- return {nestLoops.rbegin (), nestLoops.rend ()};
1891- }
1892-
1893- // / Check that the loop is perfectly nested.
1861+ // / Here loops should be [%0, %1].
18941862static bool
18951863isPerfectlyNestedForLoops (MutableArrayRef<LoopLikeOpInterface> loops) {
18961864 assert (!loops.empty () && " unexpected empty loop nest" );
@@ -1911,21 +1879,21 @@ isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
19111879 }
19121880
19131881 for (auto [outerBBArg, innerIterArg] :
1914- llvm::zip (outerBBArgs, innerIterArgs)) {
1882+ llvm::zip_equal (outerBBArgs, innerIterArgs)) {
19151883 if (!llvm::hasSingleElement (outerBBArg.getUses ()) ||
19161884 innerIterArg != outerBBArg) {
19171885 return false ;
19181886 }
19191887 }
19201888
1921- auto outerYields =
1889+ ValueRange outerYields =
19221890 cast<scf::YieldOp>(outerFor.getBody ()->getTerminator ())->getOperands ();
1923- auto innerResults = innerFor.getResults ();
1891+ ValueRange innerResults = innerFor.getResults ();
19241892 if (outerYields.size () != innerResults.size ()) {
19251893 return false ;
19261894 }
19271895 for (auto [outerYield, innerResult] :
1928- llvm::zip (outerYields, innerResults)) {
1896+ llvm::zip_equal (outerYields, innerResults)) {
19291897 if (!llvm::hasSingleElement (innerResult.getUses ()) ||
19301898 outerYield != innerResult) {
19311899 return false ;
@@ -1935,10 +1903,12 @@ isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
19351903 return true ;
19361904}
19371905
1938- // / Fetch the untiled consumer of a scf.for's result which is yielded by a
1939- // / tensor.insert_slice. This function makes the following assumptions :
1940- // / 1. tensor.insert_slice has scf.yield as its only user.
1941- // / 2. scf.for's corresponding result has only one use.
1906+ // / Fetch the untiled consumer of the outermost scf.for's result which is
1907+ // / yielded by a tensor.insert_slice from the innermost scf.for. This function
1908+ // / makes the following assumptions :
1909+ // / 1. tensor.insert_slice has scf.yield as its only user.
1910+ // / 2. scf.for's corresponding result has only one use.
1911+ // / 3. The `loops` passed in are perfectly nested `scf.for` operations.
19421912static FailureOr<OpOperand *>
19431913getUntiledConsumerFromSlice (RewriterBase &rewriter,
19441914 tensor::InsertSliceOp candidateSliceOp,
@@ -1952,6 +1922,7 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
19521922 " expected slice to be within body of inner-most loop" );
19531923 }
19541924
1925+ // 2. Check that the loop is perfectly nested.
19551926 if (!isPerfectlyNestedForLoops (loops)) {
19561927 return rewriter.notifyMatchFailure (
19571928 candidateSliceOp, " expected passed loops to be perfectly nested." );
@@ -1960,7 +1931,8 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
19601931 if (failed (checkAssumptionForFusingConsumer (candidateSliceOp)))
19611932 return failure ();
19621933 Value sliceResult = candidateSliceOp.getResult ();
1963- // Step 1. Fetch the corresponding output.
1934+
1935+ // 3. Fetch the corresponding output.
19641936 OpOperand &yieldOpOperand = (*sliceResult.getUses ().begin ());
19651937 unsigned resultNumber = yieldOpOperand.getOperandNumber ();
19661938
@@ -2007,10 +1979,7 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
20071979static FailureOr<OpOperand *>
20081980getUntiledConsumerFromSlice (RewriterBase &rewriter, Operation *sliceOp,
20091981 MutableArrayRef<LoopLikeOpInterface> loops) {
2010- if (loops.empty ()) {
2011- return rewriter.notifyMatchFailure (sliceOp, " unexpected empty loops" );
2012- }
2013-
1982+ assert (!loops.empty () && " unexpected empty loops" );
20141983 if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
20151984 return getUntiledConsumerFromSlice (rewriter, insertSlice, loops);
20161985 } else if (auto parallelInsertSlice =
0 commit comments