@@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
10921092 for (auto [outerLoop, innerLoop] :
10931093 llvm::zip_equal (loops.drop_back (), loops.drop_front ())) {
10941094 // Again assume that all the outer loops are scf.for operations.
1095- auto outerForLoop = cast<scf::ForOp>(outerLoop);
1095+ auto outerForLoop = cast<scf::ForOp>(outerLoop. getOperation () );
10961096 auto outerLoopYield =
10971097 cast<scf::YieldOp>(outerForLoop.getBody ()->getTerminator ());
10981098 SmallVector<Value> newYields =
@@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
21842184 return clonedSlices;
21852185}
21862186
2187- // / Implementation of fusing consumer of a single slice by computing the
2188- // / slice of the consumer in-place for scf loop.
2189- FailureOr<scf::SCFFuseConsumerOfSliceResult>
2190- mlir::scf::tileAndFuseConsumerOfSlices (
2191- RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2192- MutableArrayRef<LoopLikeOpInterface> loops) {
2193- if (candidateSlices.empty ()) {
2194- return rewriter.notifyMatchFailure (
2195- rewriter.getUnknownLoc (),
2196- " no candidate slices provided for consumer fusion" );
2197- }
2198- // Return if `loops` is empty, return an error for now. Caller is expected
2199- // to handle this case.
2200- if (loops.empty ()) {
2201- return rewriter.notifyMatchFailure (
2202- candidateSlices.front (),
2203- " cannot call tile and fuse consumer with an empty loop nest" );
2204- }
2187+ static FailureOr<scf::SCFFuseConsumerOfSliceResult>
2188+ tileAndFuseConsumerOfSlicesImpl (RewriterBase &rewriter, Operation *consumerOp,
2189+ ArrayRef<OpOperand *> consumerOpOperands,
2190+ ArrayRef<Operation *> candidateSlices,
2191+ MutableArrayRef<LoopLikeOpInterface> loops) {
2192+ assert (!loops.empty () && " expected loops to be not empty" );
22052193
2206- if (!(llvm::all_of (candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2207- llvm::all_of (candidateSlices,
2208- llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2194+ // 1. Check assumption for loop with `reorderOperations` disabled.
2195+ if (failed (checkAssumptionForLoop (loops.front (), consumerOp, false ))) {
22092196 return rewriter.notifyMatchFailure (
2210- candidateSlices.front (),
2211- " candidates slices need to be all `tensor.extract_slice`s or "
2212- " `tensor.parallel_insert_slice`s" );
2213- }
2214-
2215- // 1. Get the consumer of scf.for for the result yielded by
2216- // tensor.insert_slice/parallel_insert_slice.
2217- SmallVector<OpOperand *> consumerOpOperands;
2218- Operation *consumerOp;
2219- {
2220- FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
2221- getUntiledConsumerOperandsFromSlices (rewriter, candidateSlices, loops);
2222- if (failed (maybeConsumerOpOperand)) {
2223- return rewriter.notifyMatchFailure (candidateSlices.front (),
2224- " could not fetch consumer to fuse" );
2225- }
2226- std::swap (consumerOpOperands, maybeConsumerOpOperand.value ());
2227- consumerOp = consumerOpOperands.front ()->getOwner ();
2197+ loops.front (), " the first user of loop should not dominate any define "
2198+ " of consumer operand(s)" );
22282199 }
22292200
22302201 LoopLikeOpInterface outerMostLoop = loops.front ();
22312202 LoopLikeOpInterface innerMostLoop = loops.back ();
22322203
2233- // Check assumption for loop with `reorderOperations` disabled.
2234- if (failed (checkAssumptionForLoop (outerMostLoop, consumerOp, false ))) {
2235- return rewriter.notifyMatchFailure (
2236- outerMostLoop, " the first user of loop should not dominate any define "
2237- " of consumer operand(s)" );
2238- }
2239-
22402204 OpBuilder::InsertionGuard g (rewriter);
2241-
22422205 // 2. Check consumer is not using scf loop's output as init.
22432206 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
22442207 if (!dstOp)
@@ -2428,11 +2391,166 @@ mlir::scf::tileAndFuseConsumerOfSlices(
24282391 llvm::map_to_vector (operandNumbers, [&](unsigned operandNum) {
24292392 return &tileAndFuseResult->tiledOps [0 ]->getOpOperand (operandNum);
24302393 });
2394+ auto consumerOpOperandsVec = llvm::to_vector (consumerOpOperands);
24312395 return scf::SCFFuseConsumerOfSliceResult{
2432- std::move (consumerOpOperands ), std::move (tiledAndFusedOpOperands),
2396+ std::move (consumerOpOperandsVec ), std::move (tiledAndFusedOpOperands),
24332397 std::move (tileAndFuseResult->tiledOps )};
24342398}
24352399
2400+ // / Implementation of fusing consumer of a single slice by computing the
2401+ // / slice of the consumer in-place for scf loop.
2402+ FailureOr<scf::SCFFuseConsumerOfSliceResult>
2403+ mlir::scf::tileAndFuseConsumerOfSlices (
2404+ RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2405+ MutableArrayRef<LoopLikeOpInterface> loops) {
2406+ if (candidateSlices.empty ()) {
2407+ return rewriter.notifyMatchFailure (
2408+ rewriter.getUnknownLoc (),
2409+ " no candidate slices provided for consumer fusion" );
2410+ }
2411+ // Return if `loops` is empty, return an error for now. Caller is expected
2412+ // to handle this case.
2413+ if (loops.empty ()) {
2414+ return rewriter.notifyMatchFailure (
2415+ candidateSlices.front (),
2416+ " cannot call tile and fuse consumer with an empty loop nest" );
2417+ }
2418+
2419+ if (!(llvm::all_of (candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2420+ llvm::all_of (candidateSlices,
2421+ llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2422+ return rewriter.notifyMatchFailure (
2423+ candidateSlices.front (),
2424+ " candidates slices need to be all `tensor.extract_slice`s or "
2425+ " `tensor.parallel_insert_slice`s" );
2426+ }
2427+
2428+ // Get the consumer of scf.for for the result yielded by
2429+ // tensor.insert_slice/parallel_insert_slice.
2430+ FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
2431+ getUntiledConsumerOperandsFromSlices (rewriter, candidateSlices, loops);
2432+ if (failed (maybeConsumerOpOperands)) {
2433+ return rewriter.notifyMatchFailure (candidateSlices.front (),
2434+ " could not fetch consumer to fuse" );
2435+ }
2436+ Operation *consumerOp = maybeConsumerOpOperands->front ()->getOwner ();
2437+
2438+ return tileAndFuseConsumerOfSlicesImpl (rewriter, consumerOp,
2439+ maybeConsumerOpOperands.value (),
2440+ candidateSlices, loops);
2441+ }
2442+
2443+ // / For a given `result` of a `forallOp` return the
2444+ // / `tensor.parallel_insert_slice` op (or combining op) that is used to
2445+ // / construct this result.
2446+ static std::optional<Operation *>
2447+ getProducingParallelInsertSlice (scf::ForallOp forallOp, OpResult result) {
2448+ if (result.getOwner () != forallOp)
2449+ return std::nullopt ;
2450+ BlockArgument bbArg = forallOp.getTiedBlockArgument (result);
2451+ SmallVector<Operation *> combiningOps = forallOp.getCombiningOps (bbArg);
2452+ // If the number of combining ops is not 1, then this is unexpected. Return
2453+ // nullopt.
2454+ if (combiningOps.size () != 1 )
2455+ return std::nullopt ;
2456+ return combiningOps[0 ];
2457+ }
2458+
2459+ // / For a given result of the loop nest that is a tiled loop nest, return the
2460+ // / insert slice-like op that is used for consumer fusion
2461+ static std::optional<Operation *>
2462+ getProducingInsertSliceLikeOp (OpResult result,
2463+ ArrayRef<LoopLikeOpInterface> loops) {
2464+ assert (!loops.empty () && " Expected loops to be not empty" );
2465+ LoopLikeOpInterface outerMostLoop = loops.front ();
2466+ if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation ())) {
2467+ assert (loops.size () == 1 &&
2468+ " expected only a single loop when tiling using scf.forall" );
2469+ return getProducingParallelInsertSlice (forallOp, result);
2470+ }
2471+ // Assume that the loop nest is a nested `scf.for` that is created through
2472+ // tiling and retrieve the `tensor.insert_slice` operation used to construct
2473+ // the result.
2474+ while (loops.size () != 1 ) {
2475+ LoopLikeOpInterface loop = loops.front ();
2476+ if (result.getOwner () != loop)
2477+ return std::nullopt ;
2478+ auto forOp = dyn_cast<scf::ForOp>(loop.getOperation ());
2479+ if (!forOp)
2480+ return std::nullopt ;
2481+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
2482+ auto innerForResult =
2483+ dyn_cast<OpResult>(yieldOp.getOperand (result.getResultNumber ()));
2484+ if (!innerForResult)
2485+ return std::nullopt ;
2486+ result = innerForResult;
2487+ loops = loops.drop_front ();
2488+ }
2489+ LoopLikeOpInterface loop = loops.front ();
2490+ if (result.getOwner () != loop)
2491+ return std::nullopt ;
2492+ auto forOp = dyn_cast<scf::ForOp>(loop.getOperation ());
2493+ if (!forOp)
2494+ return std::nullopt ;
2495+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
2496+ auto insertSliceOp = yieldOp.getOperand (result.getResultNumber ())
2497+ .getDefiningOp <tensor::InsertSliceOp>();
2498+ if (!insertSliceOp)
2499+ return std::nullopt ;
2500+ return insertSliceOp;
2501+ }
2502+
2503+ FailureOr<scf::SCFFuseConsumerOfSliceResult>
2504+ mlir::scf::tileAndFuseConsumer (RewriterBase &rewriter, Operation *consumer,
2505+ MutableArrayRef<LoopLikeOpInterface> loops) {
2506+ if (!isa<TilingInterface>(consumer)) {
2507+ return rewriter.notifyMatchFailure (
2508+ consumer, " unhandled consumer that does not implement TilingInterface" );
2509+ }
2510+
2511+ // Return if `loops` is empty, return an error for now. Caller is expected
2512+ // to handle this case.
2513+ if (loops.empty ()) {
2514+ return rewriter.notifyMatchFailure (
2515+ consumer, " cannot call tile and fuse consumer with an empty loop nest" );
2516+ }
2517+
2518+ LoopLikeOpInterface outermostLoop = loops.front ();
2519+
2520+ // Collect the operands of the consumer that come from the outermost loop of
2521+ // the loop nest.
2522+ SmallVector<OpOperand *> consumerFusableOperands;
2523+ for (OpOperand &opOperand : consumer->getOpOperands ()) {
2524+ if (opOperand.get ().getDefiningOp () == outermostLoop) {
2525+ consumerFusableOperands.push_back (&opOperand);
2526+ }
2527+ }
2528+
2529+ // Nothing to fuse. Just return an empty set.
2530+ if (consumerFusableOperands.empty ()) {
2531+ return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
2532+ SmallVector<OpOperand *>{},
2533+ SmallVector<Operation *>{}};
2534+ }
2535+
2536+ // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices
2537+ // for fusion.
2538+ SmallVector<Operation *> candidateSlices;
2539+ candidateSlices.reserve (consumerFusableOperands.size ());
2540+ for (OpOperand *opOperand : consumerFusableOperands) {
2541+ std::optional<Operation *> slice =
2542+ getProducingInsertSliceLikeOp (cast<OpResult>(opOperand->get ()), loops);
2543+ if (!slice) {
2544+ return rewriter.notifyMatchFailure (
2545+ consumer,
2546+ " couldnt find producing insert-slice like operation for operand" );
2547+ }
2548+ candidateSlices.push_back (slice.value ());
2549+ }
2550+ return tileAndFuseConsumerOfSlicesImpl (
2551+ rewriter, consumer, consumerFusableOperands, candidateSlices, loops);
2552+ }
2553+
24362554// ===----------------------------------------------------------------------===//
24372555// lowerToLoopsUsingSCFForOp implementation.
24382556// ===----------------------------------------------------------------------===//
0 commit comments