Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ struct SCFFuseConsumerOfSliceResult {
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops);

/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
Expand Down
152 changes: 105 additions & 47 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1890,25 +1890,81 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
return {nestLoops.rbegin(), nestLoops.rend()};
}

/// Check that the loop is perfectly nested.
static bool
isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
assert(!loops.empty() && "unexpected empty loop nest");
if (loops.size() == 1) {
return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
}
for (auto [outerLoop, innerLoop] :
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
if (!outerFor || !innerFor) {
return false;
}
auto outerBBArgs = outerFor.getRegionIterArgs();
auto innerIterArgs = innerFor.getInitArgs();
if (outerBBArgs.size() != innerIterArgs.size()) {
return false;
}

for (auto [outerBBArg, innerIterArg] :
llvm::zip(outerBBArgs, innerIterArgs)) {
if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
innerIterArg != outerBBArg) {
return false;
}
}

auto outerYields =
cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
auto innerResults = innerFor.getResults();
if (outerYields.size() != innerResults.size()) {
return false;
}
for (auto [outerYield, innerResult] :
llvm::zip(outerYields, innerResults)) {
if (!llvm::hasSingleElement(innerResult.getUses()) ||
outerYield != innerResult) {
return false;
}
}
}
return true;
}

/// Fetch the untiled consumer of a scf.for's result which is yielded by a
/// tensor.insert_slice. This function makes the following assumptions :
/// 1. tensor.insert_slice has scf.yield as its only user.
/// 2. scf.for's corresponding result has only one use.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
tensor::InsertSliceOp candidateSliceOp) {
tensor::InsertSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
assert(!loops.empty() && "unexpected loops to be empty");
// 1. Expect slice to be part of the body of the inner most loop.
Operation *containingOp = candidateSliceOp->getParentOp();
if (containingOp != loops.back()) {
return rewriter.notifyMatchFailure(
candidateSliceOp,
"expected slice to be within body of inner-most loop");
}

if (!isPerfectlyNestedForLoops(loops)) {
return rewriter.notifyMatchFailure(
candidateSliceOp, "expected passed loops to be perfectly nested.");
}

if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
return failure();
Value sliceResult = candidateSliceOp.getResult();
// Step 1. Fetch the corresponding output.
OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
unsigned resultNumber = yieldOpOperand.getOperandNumber();
// Step 2. Check containing op is scf.for.
Operation *containingOp = candidateSliceOp->getParentOp();
auto forOp = dyn_cast<scf::ForOp>(containingOp);
if (!forOp)
return failure();
scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();

scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());

return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
}
Expand All @@ -1917,35 +1973,49 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
/// by a tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
tensor::ParallelInsertSliceOp candidateSliceOp) {
// Step 1. Fetch the corresponding output
tensor::ParallelInsertSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
assert(!loops.empty() && "unexpected loops to be empty");
// 1. Check that the surrounding loop is a single scf.forall loop.
if (loops.size() != 1) {
return rewriter.notifyMatchFailure(
candidateSliceOp, "expected single surrounding scf.forall");
}
auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
if (!forallOp) {
return rewriter.notifyMatchFailure(
candidateSliceOp, "expected single surrounding scf.forall");
}

// 2. Fetch the corresponding output
Value sliceDest = candidateSliceOp.getDest();
auto iterArg = dyn_cast<BlockArgument>(sliceDest);
if (!iterArg)
return failure();
Operation *containingOp = iterArg.getOwner()->getParentOp();
if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
return failure();
// Step 2. Check that the containing op is scf.forall.
auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
if (!forallOp)
if (iterArg.getOwner()->getParentOp() != forallOp)
return failure();

unsigned resultNumber =
forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
.getResultNumber();

return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
}

/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
if (loops.empty()) {
return rewriter.notifyMatchFailure(sliceOp, "unexpected empty loops");
}

if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
return getUntiledConsumerFromSlice(rewriter, insertSlice);
return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
} else if (auto parallelInsertSlice =
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
} else {
return failure();
}
Expand All @@ -1954,18 +2024,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
Operation *candidateSliceOp) {
mlir::scf::tileAndFuseConsumerOfSlice(
RewriterBase &rewriter, Operation *candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
// Return if `loops` is empty, return an error for now. Caller is expected
// to handle this case.
if (loops.empty()) {
return candidateSliceOp->emitOpError(
"cannot call tile and fuse consumer with an empty loop nest");
}
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
candidateSliceOp))
return failure();

bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);

// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
FailureOr<OpOperand *> maybeConsumerOpOperand =
getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSliceOp,
"could not fetch consumer to fuse");
Expand All @@ -1981,25 +2056,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}

// There are two possible cases regarding `oldLoopOp` here:
// 1. single `scf.forall` or `scf.for`.
// 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
// top-level loop is the outer-most one of these nested loops.
LoopLikeOpInterface innerMostLoop =
candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
SmallVector<LoopLikeOpInterface> nestedLoops;
if (isInsertSliceOp) {
nestedLoops = llvm::map_to_vector(
getPerfectlyNestedLoopsOutsideOf(
cast<scf::ForOp>(innerMostLoop.getOperation())),
[](scf::ForOp forOp) {
return cast<LoopLikeOpInterface>(forOp.getOperation());
});
} else {
nestedLoops = {innerMostLoop};
}

LoopLikeOpInterface outerMostLoop = nestedLoops.front();
LoopLikeOpInterface outerMostLoop = loops.front();
LoopLikeOpInterface innerMostLoop = loops.back();

// Check assumption for loop with `reorderOperations` disabled.
if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
Expand Down Expand Up @@ -2165,7 +2223,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
return success();
};
// 14. Add new inits to [nested] loops.
if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
newYieldValuesFn))) {
return rewriter.notifyMatchFailure(tiledConsumerOp,
"unable to add new inits to nest loop");
Expand All @@ -2174,9 +2232,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// 15. Replace the result of scf loop and consumer op with new loop's
// results.

for (auto &&[oldResult, newResult] : llvm::zip(
consumerOp->getResults(),
nestedLoops.front()->getResults().take_back(newInits.size()))) {
for (auto &&[oldResult, newResult] :
llvm::zip(consumerOp->getResults(),
loops.front()->getResults().take_back(newInits.size()))) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}

Expand Down
Loading