@@ -2427,21 +2427,17 @@ mlir::scf::tileAndFuseConsumerOfSlices(
24272427
24282428 // Get the consumer of scf.for for the result yielded by
24292429 // tensor.insert_slice/parallel_insert_slice.
2430- SmallVector<OpOperand *> consumerOpOperands;
2431- Operation *consumerOp;
2432- {
2433- FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
2434- getUntiledConsumerOperandsFromSlices (rewriter, candidateSlices, loops);
2435- if (failed (maybeConsumerOpOperand)) {
2436- return rewriter.notifyMatchFailure (candidateSlices.front (),
2437- " could not fetch consumer to fuse" );
2438- }
2439- std::swap (consumerOpOperands, maybeConsumerOpOperand.value ());
2440- consumerOp = consumerOpOperands.front ()->getOwner ();
2430+ FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
2431+ getUntiledConsumerOperandsFromSlices (rewriter, candidateSlices, loops);
2432+ if (failed (maybeConsumerOpOperands)) {
2433+ return rewriter.notifyMatchFailure (candidateSlices.front (),
2434+ " could not fetch consumer to fuse" );
24412435 }
2436+ Operation *consumerOp = maybeConsumerOpOperands->front ()->getOwner ();
24422437
2443- return tileAndFuseConsumerOfSlicesImpl (
2444- rewriter, consumerOp, consumerOpOperands, candidateSlices, loops);
2438+ return tileAndFuseConsumerOfSlicesImpl (rewriter, consumerOp,
2439+ maybeConsumerOpOperands.value (),
2440+ candidateSlices, loops);
24452441}
24462442
24472443// / For a given `result` of a `forallOp` return the
@@ -2455,21 +2451,19 @@ getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
24552451 SmallVector<Operation *> combiningOps = forallOp.getCombiningOps (bbArg);
24562452 // If the number of combining ops is not 1, then this is unexpected. Return
24572453 // nullopt.
2458- if (combiningOps.size () != 1 ) {
2454+ if (combiningOps.size () != 1 )
24592455 return std::nullopt ;
2460- }
24612456 return combiningOps[0 ];
24622457}
24632458
24642459// / For a given result of the loop nest that is a tiled loop nest, return the
24652460// / insert slice-like op that is used for consumer fusion
2466- std::optional<Operation *>
2461+ static std::optional<Operation *>
24672462getProducingInsertSliceLikeOp (OpResult result,
24682463 ArrayRef<LoopLikeOpInterface> loops) {
24692464 assert (!loops.empty () && " Expected loops to be not empty" );
2470- LoopLikeOpInterface outermostLoop = loops.front ();
2471-
2472- if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation ())) {
2465+ LoopLikeOpInterface outerMostLoop = loops.front ();
2466+ if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation ())) {
24732467 assert (loops.size () == 1 &&
24742468 " expected only a single loop when tiling using scf.forall" );
24752469 return getProducingParallelInsertSlice (forallOp, result);
@@ -2485,7 +2479,7 @@ getProducingInsertSliceLikeOp(OpResult result,
24852479 if (!forOp)
24862480 return std::nullopt ;
24872481 auto yieldOp = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
2488- OpResult innerForResult =
2482+ auto innerForResult =
24892483 dyn_cast<OpResult>(yieldOp.getOperand (result.getResultNumber ()));
24902484 if (!innerForResult)
24912485 return std::nullopt ;
@@ -2507,27 +2501,26 @@ getProducingInsertSliceLikeOp(OpResult result,
25072501}
25082502
25092503FailureOr<scf::SCFFuseConsumerOfSliceResult>
2510- mlir::scf::tileAndFuseConsumer (RewriterBase &rewriter, Operation *user ,
2504+ mlir::scf::tileAndFuseConsumer (RewriterBase &rewriter, Operation *consumer ,
25112505 MutableArrayRef<LoopLikeOpInterface> loops) {
2512- // Only handle users that implement the `TilingInterface`.
2513- if (!isa<TilingInterface>(user)) {
2506+ if (!isa<TilingInterface>(consumer)) {
25142507 return rewriter.notifyMatchFailure (
2515- user , " unhandled user that does not implement TilingInterface" );
2508+ consumer , " unhandled consumer that does not implement TilingInterface" );
25162509 }
25172510
25182511 // Return if `loops` is empty, return an error for now. Caller is expected
25192512 // to handle this case.
25202513 if (loops.empty ()) {
25212514 return rewriter.notifyMatchFailure (
2522- user , " cannot call tile and fuse consumer with an empty loop nest" );
2515+ consumer , " cannot call tile and fuse consumer with an empty loop nest" );
25232516 }
25242517
25252518 LoopLikeOpInterface outermostLoop = loops.front ();
25262519
2527- // Collect the operands of the user that come from the outermost loop of the
2528- // loop nest.
2520+ // Collect the operands of the consumer that come from the outermost loop of
2521+ // the loop nest.
25292522 SmallVector<OpOperand *> consumerFusableOperands;
2530- for (OpOperand &opOperand : user ->getOpOperands ()) {
2523+ for (OpOperand &opOperand : consumer ->getOpOperands ()) {
25312524 if (opOperand.get ().getDefiningOp () == outermostLoop) {
25322525 consumerFusableOperands.push_back (&opOperand);
25332526 }
@@ -2549,13 +2542,13 @@ mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user,
25492542 getProducingInsertSliceLikeOp (cast<OpResult>(opOperand->get ()), loops);
25502543 if (!slice) {
25512544 return rewriter.notifyMatchFailure (
2552- user ,
2545+ consumer ,
25532546 " couldnt find producing insert-slice like operation for operand" );
25542547 }
25552548 candidateSlices.push_back (slice.value ());
25562549 }
25572550 return tileAndFuseConsumerOfSlicesImpl (
2558- rewriter, user , consumerFusableOperands, candidateSlices, loops);
2551+ rewriter, consumer , consumerFusableOperands, candidateSlices, loops);
25592552}
25602553
25612554// ===----------------------------------------------------------------------===//
0 commit comments