Commit dbeda4f

[mlir][SCF] Add scf::tileAndFuseConsumer that tiles a consumer into a given tiled loop nest. (#167634)
The existing `scf::tileAndFuseConsumerOfSlices` takes a list of slices (and the loops they are part of), tries to find the consumer of these slices (all slices are expected to feed the same consumer), and then tiles the consumer into the loop nest using the `TilingInterface`.

A more natural way of doing consumer fusion is to start from the consumer and look for operands that are produced by the loop nest passed in as `loops` (presumably these loops are generated by tiling, but that is not a requirement for consumer fusion). From the consumer you can find the slices of the operands that are accessed within the loop, which you can then use to tile and fuse the consumer (using `TilingInterface`). This handles more naturally the case where multiple operands of the consumer come from the loop nest.

The `scf::tileAndFuseConsumerOfSlices` method was implemented as a mirror of `scf::tileAndFuseProducerOfSlice`. For the latter, the source of the slice has a single producer, which makes the slice a natural way of specifying producer fusion. For consumers, however, a result might have multiple users, resulting in multiple candidates for fusion, and a fusion candidate might use multiple results from the tiled loop nest. This means using slices (`tensor.insert_slice`/`tensor.parallel_insert_slice`) as a hook for consumer fusion turns out to be quite hard to navigate. Using the consumer directly avoids all those pain points.

In time, `scf::tileAndFuseConsumerOfSlices` should be deprecated in favor of `scf::tileAndFuseConsumer`. A lot of tech debt has accumulated in `scf::tileAndFuseConsumerOfSlices` that needs to be cleaned up. While that gets cleaned up, and the required functionality is moved to `scf::tileAndFuseConsumer`, the old path is still maintained.

The tests for `scf::tileAndFuseConsumerOfSlices` are copied from `tile-and-fuse-consumer.mlir` to `tile-and-fuse-consumer-using-slices.mlir`. All the tests that were in the original file now use the `tileAndFuseConsumer` method. The test op `test.tile_and_fuse_consumer` is modified to call `scf::tileAndFuseConsumer`, while a new op `test.tile_and_fuse_consumer_of_slice` is used to keep the old path tested while it is deprecated.

---------

Signed-off-by: MaheshRavishankar <[email protected]>
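The consumer-first lookup described in the commit message can be illustrated with a small standalone sketch. It does not use the MLIR API: `ToyOp`, `ToyValue`, and `collectFusableOperands` are hypothetical stand-ins, and the only thing modeled is the step of scanning the consumer's operands for values defined by the outermost loop of the nest.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-ins for MLIR concepts; none of these types exist in MLIR.
struct ToyOp;

// A value is identified by the op that defines it and a result index.
struct ToyValue {
  const ToyOp *definingOp;
  std::size_t resultNumber;
};

struct ToyOp {
  std::vector<ToyValue> operands;
};

// Return the indices of the `consumer` operands that are results of
// `outermostLoop`. An empty vector means there is nothing to fuse.
std::vector<std::size_t> collectFusableOperands(const ToyOp &consumer,
                                                const ToyOp &outermostLoop) {
  std::vector<std::size_t> fusable;
  for (std::size_t i = 0; i < consumer.operands.size(); ++i) {
    if (consumer.operands[i].definingOp == &outermostLoop)
      fusable.push_back(i);
  }
  return fusable;
}
```

Because the search starts from the consumer, a consumer that reads two different results of the same loop nest simply yields two fusable operands; no agreement check between separately supplied slices is needed in this model.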
1 parent 1136239 commit dbeda4f

8 files changed: +1634 -253 lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 5 additions & 0 deletions
@@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [
                                 getNumDynamicControlOperands() + getRank());
     }

+    BlockArgument getTiedBlockArgument(OpResult opResult) {
+      assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult");
+      return getBody()->getArgument(getRank() + opResult.getResultNumber());
+    }
+
     ::mlir::Value getInductionVar(int64_t idx) {
       return getInductionVars()[idx];
     }
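The index arithmetic in `getTiedBlockArgument` relies on the body block listing the induction variables first, followed by one argument per loop result. A minimal standalone sketch of that arithmetic follows; `tiedBlockArgumentIndex` and its `numResults` parameter are hypothetical and are not part of the MLIR API.

```cpp
#include <cassert>
#include <cstddef>

// Toy model: the body block of a forall-like loop with `rank` induction
// variables lists the induction variables first, then one argument per
// loop result. The argument tied to result `resultNumber` is therefore
// found at position `rank + resultNumber`.
std::size_t tiedBlockArgumentIndex(std::size_t rank, std::size_t resultNumber,
                                   std::size_t numResults) {
  assert(resultNumber < numResults && "invalid result number");
  return rank + resultNumber;
}
```

For a rank-2 loop with three results, the arguments tied to results 0 and 2 sit at positions 2 and 4 in this model.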

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 12 additions & 0 deletions
@@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
 /// tiled in a manner that is consistent for all the passed slices. Note that
 /// the method replaces the uses of `candidateSlices` with the tiled and fused
 /// consumer value but does not delete the slice operations.
+/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion
+/// is to take the consumer operation, and find the slices to use for fusion
+/// by walking its operands to the `loops` and then into the body to get the
+/// slices used for fusion.
 struct SCFFuseConsumerOfSliceResult {
   // Original untiled consumer operands.
   SmallVector<OpOperand *> origConsumerOperands;
@@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
                             ArrayRef<Operation *> candidateSlices,
                             MutableArrayRef<LoopLikeOpInterface> loops);

+/// Fuse the `consumer` operation into the loop nest provided by `loops`.
+/// The transformation looks for operands in the `consumer` that are defined
+/// by the outermost loop of the loop nest in `loops`. The nested loop is
+/// expected to have the structure of the loops generated through tiling.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
+                    MutableArrayRef<LoopLikeOpInterface> loops);
+
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
 FailureOr<SmallVector<scf::ForOp>>

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 167 additions & 49 deletions
@@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
   for (auto [outerLoop, innerLoop] :
        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
     // Again assume that all the outer loops are scf.for operations.
-    auto outerForLoop = cast<scf::ForOp>(outerLoop);
+    auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
     auto outerLoopYield =
         cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
     SmallVector<Value> newYields =
@@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
   return clonedSlices;
 }

-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place for scf loop.
-FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlices(
-    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
-  if (candidateSlices.empty()) {
-    return rewriter.notifyMatchFailure(
-        rewriter.getUnknownLoc(),
-        "no candidate slices provided for consumer fusion");
-  }
-  // Return if `loops` is empty, return an error for now. Caller is expected
-  // to handle this case.
-  if (loops.empty()) {
-    return rewriter.notifyMatchFailure(
-        candidateSlices.front(),
-        "cannot call tile and fuse consumer with an empty loop nest");
-  }
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp,
+                                ArrayRef<OpOperand *> consumerOpOperands,
+                                ArrayRef<Operation *> candidateSlices,
+                                MutableArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "expected loops to be not empty");

-  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
-        llvm::all_of(candidateSlices,
-                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+  // 1. Check assumption for loop with `reorderOperations` disabled.
+  if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) {
     return rewriter.notifyMatchFailure(
-        candidateSlices.front(),
-        "candidates slices need to be all `tensor.extract_slice`s or "
-        "`tensor.parallel_insert_slice`s");
-  }
-
-  // 1. Get the consumer of scf.for for the result yielded by
-  // tensor.insert_slice/parallel_insert_slice.
-  SmallVector<OpOperand *> consumerOpOperands;
-  Operation *consumerOp;
-  {
-    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
-        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
-    if (failed(maybeConsumerOpOperand)) {
-      return rewriter.notifyMatchFailure(candidateSlices.front(),
-                                         "could not fetch consumer to fuse");
-    }
-    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
-    consumerOp = consumerOpOperands.front()->getOwner();
+        loops.front(), "the first user of loop should not dominate any define "
+                       "of consumer operand(s)");
   }

   LoopLikeOpInterface outerMostLoop = loops.front();
   LoopLikeOpInterface innerMostLoop = loops.back();

-  // Check assumption for loop with `reorderOperations` disabled.
-  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
-    return rewriter.notifyMatchFailure(
-        outerMostLoop, "the first user of loop should not dominate any define "
-                       "of consumer operand(s)");
-  }
-
   OpBuilder::InsertionGuard g(rewriter);
-
   // 2. Check consumer is not using scf loop's output as init.
   auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
   if (!dstOp)
@@ -2428,11 +2391,166 @@ mlir::scf::tileAndFuseConsumerOfSlices(
       llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
         return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
       });
+  auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
   return scf::SCFFuseConsumerOfSliceResult{
-      std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+      std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
       std::move(tileAndFuseResult->tiledOps)};
 }

+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (candidateSlices.empty()) {
+    return rewriter.notifyMatchFailure(
+        rewriter.getUnknownLoc(),
+        "no candidate slices provided for consumer fusion");
+  }
+  // Return if `loops` is empty, return an error for now. Caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "cannot call tile and fuse consumer with an empty loop nest");
+  }
+
+  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
+        llvm::all_of(candidateSlices,
+                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "candidates slices need to be all `tensor.extract_slice`s or "
+        "`tensor.parallel_insert_slice`s");
+  }
+
+  // Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice/parallel_insert_slice.
+  FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
+      getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+  if (failed(maybeConsumerOpOperands)) {
+    return rewriter.notifyMatchFailure(candidateSlices.front(),
+                                       "could not fetch consumer to fuse");
+  }
+  Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner();
+
+  return tileAndFuseConsumerOfSlicesImpl(rewriter, consumerOp,
+                                         maybeConsumerOpOperands.value(),
+                                         candidateSlices, loops);
+}
+
+/// For a given `result` of a `forallOp` return the
+/// `tensor.parallel_insert_slice` op (or combining op) that is used to
+/// construct this result.
+static std::optional<Operation *>
+getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
+  if (result.getOwner() != forallOp)
+    return std::nullopt;
+  BlockArgument bbArg = forallOp.getTiedBlockArgument(result);
+  SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
+  // If the number of combining ops is not 1, then this is unexpected. Return
+  // nullopt.
+  if (combiningOps.size() != 1)
+    return std::nullopt;
+  return combiningOps[0];
+}
+
+/// For a given result of the loop nest that is a tiled loop nest, return the
+/// insert slice-like op that is used for consumer fusion
+static std::optional<Operation *>
+getProducingInsertSliceLikeOp(OpResult result,
+                              ArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "Expected loops to be not empty");
+  LoopLikeOpInterface outerMostLoop = loops.front();
+  if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) {
+    assert(loops.size() == 1 &&
+           "expected only a single loop when tiling using scf.forall");
+    return getProducingParallelInsertSlice(forallOp, result);
+  }
+  // Assume that the loop nest is a nested `scf.for` that is created through
+  // tiling and retrieve the `tensor.insert_slice` operation used to construct
+  // the result.
+  while (loops.size() != 1) {
+    LoopLikeOpInterface loop = loops.front();
+    if (result.getOwner() != loop)
+      return std::nullopt;
+    auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
+    if (!forOp)
+      return std::nullopt;
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    auto innerForResult =
+        dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
+    if (!innerForResult)
+      return std::nullopt;
+    result = innerForResult;
+    loops = loops.drop_front();
+  }
+  LoopLikeOpInterface loop = loops.front();
+  if (result.getOwner() != loop)
+    return std::nullopt;
+  auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
+  if (!forOp)
+    return std::nullopt;
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+  auto insertSliceOp = yieldOp.getOperand(result.getResultNumber())
+                           .getDefiningOp<tensor::InsertSliceOp>();
+  if (!insertSliceOp)
+    return std::nullopt;
+  return insertSliceOp;
+}
+
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
+                               MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (!isa<TilingInterface>(consumer)) {
+    return rewriter.notifyMatchFailure(
+        consumer, "unhandled consumer that does not implement TilingInterface");
+  }
+
+  // Return if `loops` is empty, return an error for now. Caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(
+        consumer, "cannot call tile and fuse consumer with an empty loop nest");
+  }
+
+  LoopLikeOpInterface outermostLoop = loops.front();
+
+  // Collect the operands of the consumer that come from the outermost loop of
+  // the loop nest.
+  SmallVector<OpOperand *> consumerFusableOperands;
+  for (OpOperand &opOperand : consumer->getOpOperands()) {
+    if (opOperand.get().getDefiningOp() == outermostLoop) {
+      consumerFusableOperands.push_back(&opOperand);
+    }
+  }
+
+  // Nothing to fuse. Just return an empty set.
+  if (consumerFusableOperands.empty()) {
+    return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
+                                                   SmallVector<OpOperand *>{},
+                                                   SmallVector<Operation *>{}};
+  }
+
+  // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices
+  // for fusion.
+  SmallVector<Operation *> candidateSlices;
+  candidateSlices.reserve(consumerFusableOperands.size());
+  for (OpOperand *opOperand : consumerFusableOperands) {
+    std::optional<Operation *> slice =
+        getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
+    if (!slice) {
+      return rewriter.notifyMatchFailure(
+          consumer,
+          "couldnt find producing insert-slice like operation for operand");
+    }
+    candidateSlices.push_back(slice.value());
+  }
+  return tileAndFuseConsumerOfSlicesImpl(
+      rewriter, consumer, consumerFusableOperands, candidateSlices, loops);
+}
+
 //===----------------------------------------------------------------------===//
 // lowerToLoopsUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
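The `scf.for` branch of `getProducingInsertSliceLikeOp` traces a loop result through the yield of each nested loop until it reaches the op yielded by the innermost loop. The following standalone sketch models that walk without the MLIR API; `ToyOp`, `Yielded`, `Kind`, and `findProducingInsertSlice` are hypothetical names, and the `scf.forall` branch is not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Toy stand-ins; these are not MLIR types.
enum class Kind { ForLoop, InsertSlice, Other };

struct ToyOp;

// What a loop yields for one result: the op defining the yielded value and
// which result of that op it is.
struct Yielded {
  const ToyOp *definingOp;
  std::size_t resultNumber;
};

struct ToyOp {
  Kind kind;
  std::vector<Yielded> yields; // only meaningful for ForLoop
};

// Walk from result `resultNumber` of loops.front() through the yields of
// each nested loop. Return the insert-slice op yielded by the innermost
// loop, or nullopt if the nest does not have the expected tiled structure.
std::optional<const ToyOp *>
findProducingInsertSlice(const std::vector<const ToyOp *> &loops,
                         std::size_t resultNumber) {
  assert(!loops.empty() && "expected loops to be not empty");
  const ToyOp *owner = loops.front();
  for (std::size_t depth = 0; depth < loops.size(); ++depth) {
    const ToyOp *loop = loops[depth];
    // The value being traced must be a result of the current loop.
    if (owner != loop || loop->kind != Kind::ForLoop ||
        resultNumber >= loop->yields.size())
      return std::nullopt;
    const Yielded &yielded = loop->yields[resultNumber];
    if (!yielded.definingOp)
      return std::nullopt;
    if (depth + 1 == loops.size()) {
      // Innermost loop: the yielded value must come from an insert slice.
      if (yielded.definingOp->kind != Kind::InsertSlice)
        return std::nullopt;
      return yielded.definingOp;
    }
    // Otherwise keep tracing through the next inner loop.
    owner = yielded.definingOp;
    resultNumber = yielded.resultNumber;
  }
  return std::nullopt;
}
```

In this model, any break in the chain (an outer loop yielding something that is not a result of the next inner loop, or an innermost yield that is not an insert slice) yields `std::nullopt`, which corresponds to the match failure reported by `tileAndFuseConsumer` in the diff.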

mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir

Lines changed: 2 additions & 2 deletions
@@ -170,7 +170,7 @@ module {
       // Fuse the consumer operation into the tiled loop.
       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
         : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
-      transform.test.fuse_consumer %slice_op in (%forall_op)
+      transform.test.fuse_consumer_using_slice %slice_op in (%forall_op)
         : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
       transform.yield
     }
@@ -231,7 +231,7 @@ module {
       // Fuse the consumer operation into the tiled loop.
       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
         : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
-      // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+      // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice
       // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer
       // to fuse" error.
       transform.yield
