-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][TilingInterface] Make tileAndFuseConsumerOfSlice take surrounding loops as an argument.
#132082
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][TilingInterface] Make tileAndFuseConsumerOfSlice take surrounding loops as an argument.
#132082
Conversation
surrounding loops as an argument. This gets the consumer fusion method in sync with the corresponding producer fusion method `tileAndFuseProducerOfSlice`. Not taking this as input required use of complicated analysis to retrieve the surrounding loops which are very fragile. Just like the producer fusion method, the loops need to be taken in as an argument, with typically the loops being created by the tiling methods. Some utilities are added to check that the loops passed in are perfectly nested (in the case of an `scf.for` loop nest. This is change 1 of N to simplify the implementation of tile and fuse consumers. Signed-off-by: MaheshRavishankar <[email protected]>
|
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: None (MaheshRavishankar) ChangesThis gets the consumer fusion method in sync with the corresponding producer fusion method Some utilities are added to check that the loops passed in are perfectly nested (in the case of an This is change 1 of N to simplify the implementation of tile and fuse consumers. Full diff: https://github.com/llvm/llvm-project/pull/132082.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index d2cddfe00ac78..33a43ce2ee7bb 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -328,7 +328,8 @@ struct SCFFuseConsumerOfSliceResult {
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops);
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index af87fb7a79d04..4fd10b0e30ab0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1890,25 +1890,81 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
return {nestLoops.rbegin(), nestLoops.rend()};
}
+/// Check that the loop is perfectly nested.
+static bool
+isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected empty loop nest");
+ if (loops.size() == 1) {
+ return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
+ }
+ for (auto [outerLoop, innerLoop] :
+ llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
+ auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
+ auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
+ if (!outerFor || !innerFor) {
+ return false;
+ }
+ auto outerBBArgs = outerFor.getRegionIterArgs();
+ auto innerIterArgs = innerFor.getInitArgs();
+ if (outerBBArgs.size() != innerIterArgs.size()) {
+ return false;
+ }
+
+ for (auto [outerBBArg, innerIterArg] :
+ llvm::zip(outerBBArgs, innerIterArgs)) {
+ if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
+ innerIterArg != outerBBArg) {
+ return false;
+ }
+ }
+
+ auto outerYields =
+ cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
+ auto innerResults = innerFor.getResults();
+ if (outerYields.size() != innerResults.size()) {
+ return false;
+ }
+ for (auto [outerYield, innerResult] :
+ llvm::zip(outerYields, innerResults)) {
+ if (!llvm::hasSingleElement(innerResult.getUses()) ||
+ outerYield != innerResult) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
/// tensor.insert_slice. This function makes the following assumptions :
/// 1. tensor.insert_slice has scf.yield as its only user.
/// 2. scf.for's corresponding result has only one use.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
- tensor::InsertSliceOp candidateSliceOp) {
+ tensor::InsertSliceOp candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected loops to be empty");
+ // 1. Expect slice to be part of the body of the inner most loop.
+ Operation *containingOp = candidateSliceOp->getParentOp();
+ if (containingOp != loops.back()) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp,
+ "expected slice to be within body of inner-most loop");
+ }
+
+ if (!isPerfectlyNestedForLoops(loops)) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "expected passed loops to be perfectly nested.");
+ }
+
if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
return failure();
Value sliceResult = candidateSliceOp.getResult();
// Step 1. Fetch the corresponding output.
OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
unsigned resultNumber = yieldOpOperand.getOperandNumber();
- // Step 2. Check containing op is scf.for.
- Operation *containingOp = candidateSliceOp->getParentOp();
- auto forOp = dyn_cast<scf::ForOp>(containingOp);
- if (!forOp)
- return failure();
- scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
+
+ scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
}
@@ -1917,35 +1973,49 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
/// by a tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
- tensor::ParallelInsertSliceOp candidateSliceOp) {
- // Step 1. Fetch the corresponding output
+ tensor::ParallelInsertSliceOp candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected loops to be empty");
+ // 1. Check that the surrounding loop is a single scf.forall loop.
+ if (loops.size() != 1) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "expected single surrounding scf.forall");
+ }
+ auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
+ if (!forallOp) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "expected single surrounding scf.forall");
+ }
+
+ // 2. Fetch the corresponding output
Value sliceDest = candidateSliceOp.getDest();
auto iterArg = dyn_cast<BlockArgument>(sliceDest);
if (!iterArg)
return failure();
- Operation *containingOp = iterArg.getOwner()->getParentOp();
- if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
- return failure();
- // Step 2. Check that the containing op is scf.forall.
- auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
- if (!forallOp)
+ if (iterArg.getOwner()->getParentOp() != forallOp)
return failure();
+
unsigned resultNumber =
forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
.getResultNumber();
- return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
+ return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
}
/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
+getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ if (loops.empty()) {
+ return rewriter.notifyMatchFailure(sliceOp, "unexpected empty loops");
+ }
+
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(rewriter, insertSlice);
+ return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
} else if (auto parallelInsertSlice =
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
+ return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
} else {
return failure();
}
@@ -1954,18 +2024,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
- Operation *candidateSliceOp) {
+mlir::scf::tileAndFuseConsumerOfSlice(
+ RewriterBase &rewriter, Operation *candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ // Return if `loops` is empty, return an error for now. Caller is expected
+ // to handle this case.
+ if (loops.empty()) {
+ return candidateSliceOp->emitOpError(
+ "cannot call tile and fuse consumer with an empty loop nest");
+ }
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
candidateSliceOp))
return failure();
- bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
-
// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
FailureOr<OpOperand *> maybeConsumerOpOperand =
- getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
+ getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSliceOp,
"could not fetch consumer to fuse");
@@ -1981,25 +2056,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}
- // There are two possible cases regarding `oldLoopOp` here:
- // 1. single `scf.forall` or `scf.for`.
- // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
- // top-level loop is the outer-most one of these nested loops.
- LoopLikeOpInterface innerMostLoop =
- candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
- SmallVector<LoopLikeOpInterface> nestedLoops;
- if (isInsertSliceOp) {
- nestedLoops = llvm::map_to_vector(
- getPerfectlyNestedLoopsOutsideOf(
- cast<scf::ForOp>(innerMostLoop.getOperation())),
- [](scf::ForOp forOp) {
- return cast<LoopLikeOpInterface>(forOp.getOperation());
- });
- } else {
- nestedLoops = {innerMostLoop};
- }
-
- LoopLikeOpInterface outerMostLoop = nestedLoops.front();
+ LoopLikeOpInterface outerMostLoop = loops.front();
+ LoopLikeOpInterface innerMostLoop = loops.back();
// Check assumption for loop with `reorderOperations` disabled.
if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
@@ -2165,7 +2223,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
return success();
};
// 14. Add new inits to [nested] loops.
- if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
+ if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
newYieldValuesFn))) {
return rewriter.notifyMatchFailure(tiledConsumerOp,
"unable to add new inits to nest loop");
@@ -2174,9 +2232,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// 15. Replace the result of scf loop and consumer op with new loop's
// results.
- for (auto &&[oldResult, newResult] : llvm::zip(
- consumerOp->getResults(),
- nestedLoops.front()->getResults().take_back(newInits.size()))) {
+ for (auto &&[oldResult, newResult] :
+ llvm::zip(consumerOp->getResults(),
+ loops.front()->getResults().take_back(newInits.size()))) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}
|
|
@llvm/pr-subscribers-mlir-scf Author: None (MaheshRavishankar) ChangesThis gets the consumer fusion method in sync with the corresponding producer fusion method Some utilities are added to check that the loops passed in are perfectly nested (in the case of an This is change 1 of N to simplify the implementation of tile and fuse consumers. Full diff: https://github.com/llvm/llvm-project/pull/132082.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index d2cddfe00ac78..33a43ce2ee7bb 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -328,7 +328,8 @@ struct SCFFuseConsumerOfSliceResult {
SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops);
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index af87fb7a79d04..4fd10b0e30ab0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1890,25 +1890,81 @@ getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop) {
return {nestLoops.rbegin(), nestLoops.rend()};
}
+/// Check that the loop is perfectly nested.
+static bool
+isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected empty loop nest");
+ if (loops.size() == 1) {
+ return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
+ }
+ for (auto [outerLoop, innerLoop] :
+ llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
+ auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
+ auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
+ if (!outerFor || !innerFor) {
+ return false;
+ }
+ auto outerBBArgs = outerFor.getRegionIterArgs();
+ auto innerIterArgs = innerFor.getInitArgs();
+ if (outerBBArgs.size() != innerIterArgs.size()) {
+ return false;
+ }
+
+ for (auto [outerBBArg, innerIterArg] :
+ llvm::zip(outerBBArgs, innerIterArgs)) {
+ if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
+ innerIterArg != outerBBArg) {
+ return false;
+ }
+ }
+
+ auto outerYields =
+ cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
+ auto innerResults = innerFor.getResults();
+ if (outerYields.size() != innerResults.size()) {
+ return false;
+ }
+ for (auto [outerYield, innerResult] :
+ llvm::zip(outerYields, innerResults)) {
+ if (!llvm::hasSingleElement(innerResult.getUses()) ||
+ outerYield != innerResult) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
/// tensor.insert_slice. This function makes the following assumptions :
/// 1. tensor.insert_slice has scf.yield as its only user.
/// 2. scf.for's corresponding result has only one use.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
- tensor::InsertSliceOp candidateSliceOp) {
+ tensor::InsertSliceOp candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected loops to be empty");
+ // 1. Expect slice to be part of the body of the inner most loop.
+ Operation *containingOp = candidateSliceOp->getParentOp();
+ if (containingOp != loops.back()) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp,
+ "expected slice to be within body of inner-most loop");
+ }
+
+ if (!isPerfectlyNestedForLoops(loops)) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "expected passed loops to be perfectly nested.");
+ }
+
if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
return failure();
Value sliceResult = candidateSliceOp.getResult();
// Step 1. Fetch the corresponding output.
OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
unsigned resultNumber = yieldOpOperand.getOperandNumber();
- // Step 2. Check containing op is scf.for.
- Operation *containingOp = candidateSliceOp->getParentOp();
- auto forOp = dyn_cast<scf::ForOp>(containingOp);
- if (!forOp)
- return failure();
- scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
+
+ scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
}
@@ -1917,35 +1973,49 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
/// by a tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
- tensor::ParallelInsertSliceOp candidateSliceOp) {
- // Step 1. Fetch the corresponding output
+ tensor::ParallelInsertSliceOp candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ assert(!loops.empty() && "unexpected loops to be empty");
+ // 1. Check that the surrounding loop is a single scf.forall loop.
+ if (loops.size() != 1) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "expected single surrounding scf.forall");
+ }
+ auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
+ if (!forallOp) {
+ return rewriter.notifyMatchFailure(
+ candidateSliceOp, "expected single surrounding scf.forall");
+ }
+
+ // 2. Fetch the corresponding output
Value sliceDest = candidateSliceOp.getDest();
auto iterArg = dyn_cast<BlockArgument>(sliceDest);
if (!iterArg)
return failure();
- Operation *containingOp = iterArg.getOwner()->getParentOp();
- if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
- return failure();
- // Step 2. Check that the containing op is scf.forall.
- auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
- if (!forallOp)
+ if (iterArg.getOwner()->getParentOp() != forallOp)
return failure();
+
unsigned resultNumber =
forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
.getResultNumber();
- return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
+ return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
}
/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice.
static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
+getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ if (loops.empty()) {
+ return rewriter.notifyMatchFailure(sliceOp, "unexpected empty loops");
+ }
+
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(rewriter, insertSlice);
+ return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
} else if (auto parallelInsertSlice =
dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
- return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
+ return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
} else {
return failure();
}
@@ -1954,18 +2024,23 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
- Operation *candidateSliceOp) {
+mlir::scf::tileAndFuseConsumerOfSlice(
+ RewriterBase &rewriter, Operation *candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ // Return if `loops` is empty, return an error for now. Caller is expected
+ // to handle this case.
+ if (loops.empty()) {
+ return candidateSliceOp->emitOpError(
+ "cannot call tile and fuse consumer with an empty loop nest");
+ }
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
candidateSliceOp))
return failure();
- bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
-
// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
FailureOr<OpOperand *> maybeConsumerOpOperand =
- getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
+ getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSliceOp,
"could not fetch consumer to fuse");
@@ -1981,25 +2056,8 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}
- // There are two possible cases regarding `oldLoopOp` here:
- // 1. single `scf.forall` or `scf.for`.
- // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
- // top-level loop is the outer-most one of these nested loops.
- LoopLikeOpInterface innerMostLoop =
- candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
- SmallVector<LoopLikeOpInterface> nestedLoops;
- if (isInsertSliceOp) {
- nestedLoops = llvm::map_to_vector(
- getPerfectlyNestedLoopsOutsideOf(
- cast<scf::ForOp>(innerMostLoop.getOperation())),
- [](scf::ForOp forOp) {
- return cast<LoopLikeOpInterface>(forOp.getOperation());
- });
- } else {
- nestedLoops = {innerMostLoop};
- }
-
- LoopLikeOpInterface outerMostLoop = nestedLoops.front();
+ LoopLikeOpInterface outerMostLoop = loops.front();
+ LoopLikeOpInterface innerMostLoop = loops.back();
// Check assumption for loop with `reorderOperations` disabled.
if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
@@ -2165,7 +2223,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
return success();
};
// 14. Add new inits to [nested] loops.
- if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
+ if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
newYieldValuesFn))) {
return rewriter.notifyMatchFailure(tiledConsumerOp,
"unable to add new inits to nest loop");
@@ -2174,9 +2232,9 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// 15. Replace the result of scf loop and consumer op with new loop's
// results.
- for (auto &&[oldResult, newResult] : llvm::zip(
- consumerOp->getResults(),
- nestedLoops.front()->getResults().take_back(newInits.size()))) {
+ for (auto &&[oldResult, newResult] :
+ llvm::zip(consumerOp->getResults(),
+ loops.front()->getResults().take_back(newInits.size()))) {
rewriter.replaceAllUsesWith(oldResult, newResult);
}
|
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Abhishek-Varma
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a nice improvement! A few comments.
|
@Abhishek-Varma I had to remove one of the tests you added cause it doesnt really fit what is checked for as a perfectly nested loop. I'd rather drop that test if possible. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Abhishek-Varma
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks!
Signed-off-by: MaheshRavishankar <[email protected]>
59306c6 to
9c0d426
Compare
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Yun-Fly
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for making this logic clearer!
This gets the consumer fusion method in sync with the corresponding producer fusion method
tileAndFuseProducerOfSlice. Not taking this as input required use of complicated analysis to retrieve the surrounding loops which are very fragile. Just like the producer fusion method, the loops need to be taken in as an argument, with typically the loops being created by the tiling methods.Some utilities are added to check that the loops passed in are perfectly nested (in the case of an
scf.forloop nest.This is change 1 of N to simplify the implementation of tile and fuse consumers.