Skip to content

Commit 66fc419

Browse files
[mlir][TilingInterface] Make tileAndFuseConsumerOfSlice take
surrounding loops as an argument. This gets the consumer fusion method in sync with the corresponding producer fusion method `tileAndFuseProducerOfSlice`. Not taking this as input required use of complicated analysis to retrieve the surrounding loops which are very fragile. Just like the producer fusion method, the loops need to be taken in as an argument, with typically the loops being created by the tiling methods. Some utilities are added to check that the loops passed in are perfectly nested (in the case of an `scf.for` loop nest. This is change 1 of N to simplify the implementation of tile and fuse consumers. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent da1c19a commit 66fc419

File tree

2 files changed

+107
-48
lines changed

2 files changed

+107
-48
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,8 @@ struct SCFFuseConsumerOfSliceResult {
328328
SmallVector<Operation *> tiledOps;
329329
};
330330
FailureOr<scf::SCFFuseConsumerOfSliceResult>
331-
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
331+
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
332+
MutableArrayRef<LoopLikeOpInterface> loops);
332333

333334
/// Method to lower an `op` that implements the `TilingInterface` to
334335
/// loops/scalars.

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 105 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,25 +1890,81 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
18901890
return {nestLoops.rbegin(), nestLoops.rend()};
18911891
}
18921892

1893+
/// Check that the loop is perfectly nested.
1894+
static bool
1895+
isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
1896+
assert(!loops.empty() && "unexpected empty loop nest");
1897+
if (loops.size() == 1) {
1898+
return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1899+
}
1900+
for (auto [outerLoop, innerLoop] :
1901+
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1902+
auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1903+
auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1904+
if (!outerFor || !innerFor) {
1905+
return false;
1906+
}
1907+
auto outerBBArgs = outerFor.getRegionIterArgs();
1908+
auto innerIterArgs = innerFor.getInitArgs();
1909+
if (outerBBArgs.size() != innerIterArgs.size()) {
1910+
return false;
1911+
}
1912+
1913+
for (auto [outerBBArg, innerIterArg] :
1914+
llvm::zip(outerBBArgs, innerIterArgs)) {
1915+
if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1916+
innerIterArg != outerBBArg) {
1917+
return false;
1918+
}
1919+
}
1920+
1921+
auto outerYields =
1922+
cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1923+
auto innerResults = innerFor.getResults();
1924+
if (outerYields.size() != innerResults.size()) {
1925+
return false;
1926+
}
1927+
for (auto [outerYield, innerResult] :
1928+
llvm::zip(outerYields, innerResults)) {
1929+
if (!llvm::hasSingleElement(innerResult.getUses()) ||
1930+
outerYield != innerResult) {
1931+
return false;
1932+
}
1933+
}
1934+
}
1935+
return true;
1936+
}
1937+
18931938
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
18941939
/// tensor.insert_slice. This function makes the following assumptions :
18951940
/// 1. tensor.insert_slice has scf.yield as its only user.
18961941
/// 2. scf.for's corresponding result has only one use.
18971942
static FailureOr<OpOperand *>
18981943
getUntiledConsumerFromSlice(RewriterBase &rewriter,
1899-
tensor::InsertSliceOp candidateSliceOp) {
1944+
tensor::InsertSliceOp candidateSliceOp,
1945+
MutableArrayRef<LoopLikeOpInterface> loops) {
1946+
assert(!loops.empty() && "unexpected loops to be empty");
1947+
// 1. Expect slice to be part of the body of the inner most loop.
1948+
Operation *containingOp = candidateSliceOp->getParentOp();
1949+
if (containingOp != loops.back()) {
1950+
return rewriter.notifyMatchFailure(
1951+
candidateSliceOp,
1952+
"expected slice to be within body of inner-most loop");
1953+
}
1954+
1955+
if (!isPerfectlyNestedForLoops(loops)) {
1956+
return rewriter.notifyMatchFailure(
1957+
candidateSliceOp, "expected passed loops to be perfectly nested.");
1958+
}
1959+
19001960
if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
19011961
return failure();
19021962
Value sliceResult = candidateSliceOp.getResult();
19031963
// Step 1. Fetch the corresponding output.
19041964
OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
19051965
unsigned resultNumber = yieldOpOperand.getOperandNumber();
1906-
// Step 2. Check containing op is scf.for.
1907-
Operation *containingOp = candidateSliceOp->getParentOp();
1908-
auto forOp = dyn_cast<scf::ForOp>(containingOp);
1909-
if (!forOp)
1910-
return failure();
1911-
scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
1966+
1967+
scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
19121968

19131969
return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
19141970
}
@@ -1917,35 +1973,49 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
19171973
/// by a tensor.parallel_insert_slice.
19181974
static FailureOr<OpOperand *>
19191975
getUntiledConsumerFromSlice(RewriterBase &rewriter,
1920-
tensor::ParallelInsertSliceOp candidateSliceOp) {
1921-
// Step 1. Fetch the corresponding output
1976+
tensor::ParallelInsertSliceOp candidateSliceOp,
1977+
MutableArrayRef<LoopLikeOpInterface> loops) {
1978+
assert(!loops.empty() && "unexpected loops to be empty");
1979+
// 1. Check that the surrounding loop is a single scf.forall loop.
1980+
if (loops.size() != 1) {
1981+
return rewriter.notifyMatchFailure(
1982+
candidateSliceOp, "expected single surrounding scf.forall");
1983+
}
1984+
auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
1985+
if (!forallOp) {
1986+
return rewriter.notifyMatchFailure(
1987+
candidateSliceOp, "expected single surrounding scf.forall");
1988+
}
1989+
1990+
// 2. Fetch the corresponding output
19221991
Value sliceDest = candidateSliceOp.getDest();
19231992
auto iterArg = dyn_cast<BlockArgument>(sliceDest);
19241993
if (!iterArg)
19251994
return failure();
1926-
Operation *containingOp = iterArg.getOwner()->getParentOp();
1927-
if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1928-
return failure();
1929-
// Step 2. Check that the containing op is scf.forall.
1930-
auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1931-
if (!forallOp)
1995+
if (iterArg.getOwner()->getParentOp() != forallOp)
19321996
return failure();
1997+
19331998
unsigned resultNumber =
19341999
forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
19352000
.getResultNumber();
19362001

1937-
return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
2002+
return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
19382003
}
19392004

19402005
/// A utility to fetch an untiled consumer of
19412006
/// tensor.insert_slice/tensor.parallel_insert_slice.
19422007
static FailureOr<OpOperand *>
1943-
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
2008+
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
2009+
MutableArrayRef<LoopLikeOpInterface> loops) {
2010+
if (loops.empty()) {
2011+
return rewriter.notifyMatchFailure(sliceOp, "unexpected empty loops");
2012+
}
2013+
19442014
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1945-
return getUntiledConsumerFromSlice(rewriter, insertSlice);
2015+
return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
19462016
} else if (auto parallelInsertSlice =
19472017
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1948-
return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
2018+
return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
19492019
} else {
19502020
return failure();
19512021
}
@@ -1954,18 +2024,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
19542024
/// Implementation of fusing consumer of a single slice by computing the
19552025
/// slice of the consumer in-place for scf loop.
19562026
FailureOr<scf::SCFFuseConsumerOfSliceResult>
1957-
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1958-
Operation *candidateSliceOp) {
2027+
mlir::scf::tileAndFuseConsumerOfSlice(
2028+
RewriterBase &rewriter, Operation *candidateSliceOp,
2029+
MutableArrayRef<LoopLikeOpInterface> loops) {
2030+
// Return if `loops` is empty, return an error for now. Caller is expected
2031+
// to handle this case.
2032+
if (loops.empty()) {
2033+
return candidateSliceOp->emitOpError(
2034+
"cannot call tile and fuse consumer with an empty loop nest");
2035+
}
19592036
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
19602037
candidateSliceOp))
19612038
return failure();
19622039

1963-
bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1964-
19652040
// 1. Get the consumer of scf.for for the result yielded by
19662041
// tensor.insert_slice/parallel_insert_slice.
19672042
FailureOr<OpOperand *> maybeConsumerOpOperand =
1968-
getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
2043+
getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
19692044
if (failed(maybeConsumerOpOperand)) {
19702045
return rewriter.notifyMatchFailure(candidateSliceOp,
19712046
"could not fetch consumer to fuse");
@@ -1981,25 +2056,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
19812056
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
19822057
}
19832058

1984-
// There are two possible cases regarding `oldLoopOp` here:
1985-
// 1. single `scf.forall` or `scf.for`.
1986-
// 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1987-
// top-level loop is the outer-most one of these nested loops.
1988-
LoopLikeOpInterface innerMostLoop =
1989-
candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
1990-
SmallVector<LoopLikeOpInterface> nestedLoops;
1991-
if (isInsertSliceOp) {
1992-
nestedLoops = llvm::map_to_vector(
1993-
getPerfectlyNestedLoopsOutsideOf(
1994-
cast<scf::ForOp>(innerMostLoop.getOperation())),
1995-
[](scf::ForOp forOp) {
1996-
return cast<LoopLikeOpInterface>(forOp.getOperation());
1997-
});
1998-
} else {
1999-
nestedLoops = {innerMostLoop};
2000-
}
2001-
2002-
LoopLikeOpInterface outerMostLoop = nestedLoops.front();
2059+
LoopLikeOpInterface outerMostLoop = loops.front();
2060+
LoopLikeOpInterface innerMostLoop = loops.back();
20032061

20042062
// Check assumption for loop with `reorderOperations` disabled.
20052063
if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
@@ -2165,7 +2223,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
21652223
return success();
21662224
};
21672225
// 14. Add new inits to [nested] loops.
2168-
if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
2226+
if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
21692227
newYieldValuesFn))) {
21702228
return rewriter.notifyMatchFailure(tiledConsumerOp,
21712229
"unable to add new inits to nest loop");
@@ -2174,9 +2232,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
21742232
// 15. Replace the result of scf loop and consumer op with new loop's
21752233
// results.
21762234

2177-
for (auto &&[oldResult, newResult] : llvm::zip(
2178-
consumerOp->getResults(),
2179-
nestedLoops.front()->getResults().take_back(newInits.size()))) {
2235+
for (auto &&[oldResult, newResult] :
2236+
llvm::zip(consumerOp->getResults(),
2237+
loops.front()->getResults().take_back(newInits.size()))) {
21802238
rewriter.replaceAllUsesWith(oldResult, newResult);
21812239
}
21822240

0 commit comments

Comments
 (0)