Skip to content

Commit 0cbb6e7

Browse files
authored
[mlir][scf] Expose isPerfectlyNestedForLoops (#152115)
The function `isPerfectlyNestedForLoops` is useful on its own and so I'm exposing it for downstream use.
1 parent 25bedd0 commit 0cbb6e7

File tree

3 files changed

+46
-57
lines changed

3 files changed

+46
-57
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,14 @@ scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
213213
FailureOr<scf::ForallOp> normalizeForallOp(RewriterBase &rewriter,
214214
scf::ForallOp forallOp);
215215

216+
/// Check if the provided loops are perfectly nested for-loops. Perfect nesting
217+
/// means:
218+
/// 1. All loops are scf.for operations
219+
/// 2. Each outer loop's region iter args match the inner loop's init args
220+
/// 3. Each outer loop's yields match the inner loop's results
221+
/// 4. Each region iter arg and result has exactly one use
222+
bool isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops);
223+
216224
} // namespace mlir
217225

218226
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,63 +1916,6 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
19161916
return failure();
19171917
}
19181918

1919-
/// Check that the loop is perfectly nested.
1920-
/// The loops are expected to be ordered from outer most to inner most.
1921-
/// For example:
1922-
/// ```
1923-
/// %0 = scf.for()
1924-
/// %1 = scf.for()
1925-
/// %2 = scf.for()
1926-
/// %3 = ...
1927-
/// yield %3
1928-
/// yield %2
1929-
/// yield %1
1930-
/// ```
1931-
/// Here loops should be [%0, %1].
1932-
static bool
1933-
isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
1934-
assert(!loops.empty() && "unexpected empty loop nest");
1935-
if (loops.size() == 1) {
1936-
return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1937-
}
1938-
for (auto [outerLoop, innerLoop] :
1939-
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1940-
auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1941-
auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1942-
if (!outerFor || !innerFor) {
1943-
return false;
1944-
}
1945-
auto outerBBArgs = outerFor.getRegionIterArgs();
1946-
auto innerIterArgs = innerFor.getInitArgs();
1947-
if (outerBBArgs.size() != innerIterArgs.size()) {
1948-
return false;
1949-
}
1950-
1951-
for (auto [outerBBArg, innerIterArg] :
1952-
llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1953-
if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1954-
innerIterArg != outerBBArg) {
1955-
return false;
1956-
}
1957-
}
1958-
1959-
ValueRange outerYields =
1960-
cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1961-
ValueRange innerResults = innerFor.getResults();
1962-
if (outerYields.size() != innerResults.size()) {
1963-
return false;
1964-
}
1965-
for (auto [outerYield, innerResult] :
1966-
llvm::zip_equal(outerYields, innerResults)) {
1967-
if (!llvm::hasSingleElement(innerResult.getUses()) ||
1968-
outerYield != innerResult) {
1969-
return false;
1970-
}
1971-
}
1972-
}
1973-
return true;
1974-
}
1975-
19761919
/// Fetch the untiled consumer of the outermost scf.for's result which is
19771920
/// yielded by a tensor.insert_slice from the innermost scf.for. This function
19781921
/// makes the following assumptions :

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,3 +1512,41 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
15121512
rewriter.replaceOp(forallOp, normalizedForallOp);
15131513
return normalizedForallOp;
15141514
}
1515+
1516+
bool mlir::isPerfectlyNestedForLoops(
1517+
MutableArrayRef<LoopLikeOpInterface> loops) {
1518+
assert(!loops.empty() && "unexpected empty loop nest");
1519+
if (loops.size() == 1)
1520+
return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1521+
for (auto [outerLoop, innerLoop] :
1522+
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1523+
auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1524+
auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1525+
if (!outerFor || !innerFor)
1526+
return false;
1527+
auto outerBBArgs = outerFor.getRegionIterArgs();
1528+
auto innerIterArgs = innerFor.getInitArgs();
1529+
if (outerBBArgs.size() != innerIterArgs.size())
1530+
return false;
1531+
1532+
for (auto [outerBBArg, innerIterArg] :
1533+
llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1534+
if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1535+
innerIterArg != outerBBArg)
1536+
return false;
1537+
}
1538+
1539+
ValueRange outerYields =
1540+
cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1541+
ValueRange innerResults = innerFor.getResults();
1542+
if (outerYields.size() != innerResults.size())
1543+
return false;
1544+
for (auto [outerYield, innerResult] :
1545+
llvm::zip_equal(outerYields, innerResults)) {
1546+
if (!llvm::hasSingleElement(innerResult.getUses()) ||
1547+
outerYield != innerResult)
1548+
return false;
1549+
}
1550+
}
1551+
return true;
1552+
}

0 commit comments

Comments
 (0)