Skip to content

Commit 9c0d426

Browse files
Address comments.
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 66fc419 commit 9c0d426

File tree

5 files changed

+81
-218
lines changed

5 files changed

+81
-218
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 18 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1846,11 +1846,9 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
18461846
return failure();
18471847
}
18481848

1849-
/// Find the perfectly nested loops outside of given loop(included) sorted
1850-
/// from outer to inner.
1851-
///
1852-
/// E.g.
1853-
///
1849+
/// Check that the loop is perfectly nested.
1850+
/// The loops are expected to be ordered from outermost to innermost.
1851+
/// For example:
18541852
/// ```
18551853
/// %0 = scf.for()
18561854
/// %1 = scf.for()
@@ -1860,37 +1858,7 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
18601858
/// yield %2
18611859
/// yield %1
18621860
/// ```
1863-
///
1864-
/// This function will return three perfectly nested loops: %0 + %1 + %2, when
1865-
/// target inner loop is %2.
1866-
static SmallVector<scf::ForOp>
1867-
getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
1868-
SmallVector<scf::ForOp> nestLoops = {loop};
1869-
auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1870-
1871-
// Check if it is the ForOp that yield the result of inner loop.
1872-
auto isForOpYieldResultOfInnerLoop =
1873-
[](scf::ForOp outerLoop) -> LogicalResult {
1874-
Block *body = outerLoop.getBody();
1875-
if (!llvm::hasSingleElement(body->without_terminator()))
1876-
return failure();
1877-
auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1878-
auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1879-
if (!innerForOp)
1880-
return failure();
1881-
// All of innerForOp results should be yielded.
1882-
return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1883-
};
1884-
1885-
while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1886-
nestLoops.push_back(outerLoop);
1887-
outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1888-
}
1889-
// sorted from outer to inner
1890-
return {nestLoops.rbegin(), nestLoops.rend()};
1891-
}
1892-
1893-
/// Check that the loop is perfectly nested.
1861+
/// Here loops should be [%0, %1].
18941862
static bool
18951863
isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
18961864
assert(!loops.empty() && "unexpected empty loop nest");
@@ -1911,21 +1879,21 @@ isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
19111879
}
19121880

19131881
for (auto [outerBBArg, innerIterArg] :
1914-
llvm::zip(outerBBArgs, innerIterArgs)) {
1882+
llvm::zip_equal(outerBBArgs, innerIterArgs)) {
19151883
if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
19161884
innerIterArg != outerBBArg) {
19171885
return false;
19181886
}
19191887
}
19201888

1921-
auto outerYields =
1889+
ValueRange outerYields =
19221890
cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1923-
auto innerResults = innerFor.getResults();
1891+
ValueRange innerResults = innerFor.getResults();
19241892
if (outerYields.size() != innerResults.size()) {
19251893
return false;
19261894
}
19271895
for (auto [outerYield, innerResult] :
1928-
llvm::zip(outerYields, innerResults)) {
1896+
llvm::zip_equal(outerYields, innerResults)) {
19291897
if (!llvm::hasSingleElement(innerResult.getUses()) ||
19301898
outerYield != innerResult) {
19311899
return false;
@@ -1935,10 +1903,12 @@ isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
19351903
return true;
19361904
}
19371905

1938-
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
1939-
/// tensor.insert_slice. This function makes the following assumptions :
1940-
/// 1. tensor.insert_slice has scf.yield as its only user.
1941-
/// 2. scf.for's corresponding result has only one use.
1906+
/// Fetch the untiled consumer of the outermost scf.for's result which is
1907+
/// yielded by a tensor.insert_slice from the innermost scf.for. This function
1908+
/// makes the following assumptions:
1909+
/// 1. tensor.insert_slice has scf.yield as its only user.
1910+
/// 2. scf.for's corresponding result has only one use.
1911+
/// 3. The `loops` passed in are perfectly nested `scf.for` operations.
19421912
static FailureOr<OpOperand *>
19431913
getUntiledConsumerFromSlice(RewriterBase &rewriter,
19441914
tensor::InsertSliceOp candidateSliceOp,
@@ -1952,6 +1922,7 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
19521922
"expected slice to be within body of inner-most loop");
19531923
}
19541924

1925+
// 2. Check that the loop is perfectly nested.
19551926
if (!isPerfectlyNestedForLoops(loops)) {
19561927
return rewriter.notifyMatchFailure(
19571928
candidateSliceOp, "expected passed loops to be perfectly nested.");
@@ -1960,7 +1931,8 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
19601931
if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
19611932
return failure();
19621933
Value sliceResult = candidateSliceOp.getResult();
1963-
// Step 1. Fetch the corresponding output.
1934+
1935+
// 3. Fetch the corresponding output.
19641936
OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
19651937
unsigned resultNumber = yieldOpOperand.getOperandNumber();
19661938

@@ -2007,10 +1979,7 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
20071979
static FailureOr<OpOperand *>
20081980
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
20091981
MutableArrayRef<LoopLikeOpInterface> loops) {
2010-
if (loops.empty()) {
2011-
return rewriter.notifyMatchFailure(sliceOp, "unexpected empty loops");
2012-
}
2013-
1982+
assert(!loops.empty() && "unexpected empty loops");
20141983
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
20151984
return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
20161985
} else if (auto parallelInsertSlice =

mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ module {
170170
// Fuse the consumer operation into the tiled loop.
171171
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
172172
: (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
173-
transform.test.fuse_consumer %slice_op
174-
: (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op)
173+
transform.test.fuse_consumer %slice_op in (%forall_op)
174+
: (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
175175
transform.yield
176176
}
177177
}

0 commit comments

Comments
 (0)