Skip to content

Commit 8d4a2d1

Browse files
Address comments.
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent bf8c1de commit 8d4a2d1

File tree

3 files changed

+38
-39
lines changed

3 files changed

+38
-39
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2427,21 +2427,17 @@ mlir::scf::tileAndFuseConsumerOfSlices(
24272427

24282428
// Get the consumer of scf.for for the result yielded by
24292429
// tensor.insert_slice/parallel_insert_slice.
2430-
SmallVector<OpOperand *> consumerOpOperands;
2431-
Operation *consumerOp;
2432-
{
2433-
FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
2434-
getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
2435-
if (failed(maybeConsumerOpOperand)) {
2436-
return rewriter.notifyMatchFailure(candidateSlices.front(),
2437-
"could not fetch consumer to fuse");
2438-
}
2439-
std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
2440-
consumerOp = consumerOpOperands.front()->getOwner();
2430+
FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
2431+
getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
2432+
if (failed(maybeConsumerOpOperands)) {
2433+
return rewriter.notifyMatchFailure(candidateSlices.front(),
2434+
"could not fetch consumer to fuse");
24412435
}
2436+
Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner();
24422437

2443-
return tileAndFuseConsumerOfSlicesImpl(
2444-
rewriter, consumerOp, consumerOpOperands, candidateSlices, loops);
2438+
return tileAndFuseConsumerOfSlicesImpl(rewriter, consumerOp,
2439+
maybeConsumerOpOperands.value(),
2440+
candidateSlices, loops);
24452441
}
24462442

24472443
/// For a given `result` of a `forallOp` return the
@@ -2455,21 +2451,19 @@ getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
24552451
SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
24562452
// If the number of combining ops is not 1, then this is unexpected. Return
24572453
// nullopt.
2458-
if (combiningOps.size() != 1) {
2454+
if (combiningOps.size() != 1)
24592455
return std::nullopt;
2460-
}
24612456
return combiningOps[0];
24622457
}
24632458

24642459
/// For a given result of the loop nest that is a tiled loop nest, return the
24652460
/// insert slice-like op that is used for consumer fusion
2466-
std::optional<Operation *>
2461+
static std::optional<Operation *>
24672462
getProducingInsertSliceLikeOp(OpResult result,
24682463
ArrayRef<LoopLikeOpInterface> loops) {
24692464
assert(!loops.empty() && "Expected loops to be not empty");
2470-
LoopLikeOpInterface outermostLoop = loops.front();
2471-
2472-
if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) {
2465+
LoopLikeOpInterface outerMostLoop = loops.front();
2466+
if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) {
24732467
assert(loops.size() == 1 &&
24742468
"expected only a single loop when tiling using scf.forall");
24752469
return getProducingParallelInsertSlice(forallOp, result);
@@ -2485,7 +2479,7 @@ getProducingInsertSliceLikeOp(OpResult result,
24852479
if (!forOp)
24862480
return std::nullopt;
24872481
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
2488-
OpResult innerForResult =
2482+
auto innerForResult =
24892483
dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
24902484
if (!innerForResult)
24912485
return std::nullopt;
@@ -2507,27 +2501,26 @@ getProducingInsertSliceLikeOp(OpResult result,
25072501
}
25082502

25092503
FailureOr<scf::SCFFuseConsumerOfSliceResult>
2510-
mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user,
2504+
mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
25112505
MutableArrayRef<LoopLikeOpInterface> loops) {
2512-
// Only handle users that implement the `TilingInterface`.
2513-
if (!isa<TilingInterface>(user)) {
2506+
if (!isa<TilingInterface>(consumer)) {
25142507
return rewriter.notifyMatchFailure(
2515-
user, "unhandled user that does not implement TilingInterface");
2508+
consumer, "unhandled consumer that does not implement TilingInterface");
25162509
}
25172510

25182511
// Return if `loops` is empty, return an error for now. Caller is expected
25192512
// to handle this case.
25202513
if (loops.empty()) {
25212514
return rewriter.notifyMatchFailure(
2522-
user, "cannot call tile and fuse consumer with an empty loop nest");
2515+
consumer, "cannot call tile and fuse consumer with an empty loop nest");
25232516
}
25242517

25252518
LoopLikeOpInterface outermostLoop = loops.front();
25262519

2527-
// Collect the operands of the user that come from the outermost loop of the
2528-
// loop nest.
2520+
// Collect the operands of the consumer that come from the outermost loop of
2521+
// the loop nest.
25292522
SmallVector<OpOperand *> consumerFusableOperands;
2530-
for (OpOperand &opOperand : user->getOpOperands()) {
2523+
for (OpOperand &opOperand : consumer->getOpOperands()) {
25312524
if (opOperand.get().getDefiningOp() == outermostLoop) {
25322525
consumerFusableOperands.push_back(&opOperand);
25332526
}
@@ -2549,13 +2542,13 @@ mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user,
25492542
getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
25502543
if (!slice) {
25512544
return rewriter.notifyMatchFailure(
2552-
user,
2545+
consumer,
25532546
"couldnt find producing insert-slice like operation for operand");
25542547
}
25552548
candidateSlices.push_back(slice.value());
25562549
}
25572550
return tileAndFuseConsumerOfSlicesImpl(
2558-
rewriter, user, consumerFusableOperands, candidateSlices, loops);
2551+
rewriter, consumer, consumerFusableOperands, candidateSlices, loops);
25592552
}
25602553

25612554
//===----------------------------------------------------------------------===//

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,20 +170,18 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
170170
// TestFuseConsumerOp
171171
//===----------------------------------------------------------------------===//
172172

173-
/// Apply fusing of consumer transformation to all payload ops and store both
174-
/// the original consumer operation as well as the fused consumer operation.
173+
/// Fuse the consumer and store both the original consumer operation as well as
174+
/// the fused consumer operation.
175175
static LogicalResult
176176
applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
177177
Operation *consumer,
178178
MutableArrayRef<LoopLikeOpInterface> loops,
179179
TransformResults &transformResults) {
180180
SmallVector<Operation *> fusedConsumerOps;
181-
182181
rewriter.setInsertionPoint(consumer);
183182

184183
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
185184
scf::tileAndFuseConsumer(rewriter, consumer, loops);
186-
187185
if (failed(fuseConsumerResults))
188186
return consumer->emitOpError("failed to fuse consumer of slice");
189187

@@ -192,7 +190,6 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
192190
fuseConsumerResults->tiledAndFusedConsumerOperands) {
193191
fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner());
194192
}
195-
196193
transformResults.set(transformOp->getOpResult(0), fusedConsumerOps);
197194
for (auto [index, loop] : llvm::enumerate(loops)) {
198195
transformResults.set(transformOp->getOpResult(index + 1), {loop});

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,13 @@ def TestFuseConsumerUsingSliceOp : Op<Transform_Dialect, "test.fuse_consumer_usi
5555
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
5656
ReportTrackingListenerFailuresOpTrait]> {
5757
let description = [{
58-
Fuses the consumer of the operation pointed to by the target handle
59-
using the options provided as attributes.
58+
For the `insert_slice`-like operations (that are typically generated through tiling),
59+
within the loop nests passed in as `loops` (that are typically generated through tiling),
60+
find the consumer that these slices map to (have to be the same consumer) and fuse
61+
the consumer into the loop.
62+
63+
Returns a handle to the original consumer operation and the consumer operation after
64+
fusion.
6065
}];
6166

6267
let arguments = (ins
@@ -78,8 +83,12 @@ def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
7883
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
7984
ReportTrackingListenerFailuresOpTrait]> {
8085
let description = [{
81-
Fuses the consumer of the operation pointed to by the target handle
82-
using the options provided as attributes.
86+
For the `consumer` that uses the result of the outer-most loop of a loop nest passed in
87+
as `loops` (that are typically generated through tiling), fuse the consumer into the
88+
loop.
89+
90+
Returns a handle to the consumer operation after fusion and the loops that might be
91+
modified.
8392
}];
8493

8594
let arguments = (ins

0 commit comments

Comments
 (0)