@@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
10921092 for (auto [outerLoop, innerLoop] :
10931093 llvm::zip_equal (loops.drop_back (), loops.drop_front ())) {
10941094 // Again assume that all the outer loops are scf.for operations.
1095- auto outerForLoop = cast<scf::ForOp>(outerLoop);
1095+ auto outerForLoop = cast<scf::ForOp>(outerLoop. getOperation () );
10961096 auto outerLoopYield =
10971097 cast<scf::YieldOp>(outerForLoop.getBody ()->getTerminator ());
10981098 SmallVector<Value> newYields =
@@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
21842184 return clonedSlices;
21852185}
21862186
2187- // / Implementation of fusing consumer of a single slice by computing the
2188- // / slice of the consumer in-place for scf loop.
2189- FailureOr<scf::SCFFuseConsumerOfSliceResult>
2190- mlir::scf::tileAndFuseConsumerOfSlices (
2191- RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2192- MutableArrayRef<LoopLikeOpInterface> loops) {
2193- if (candidateSlices.empty ()) {
2194- return rewriter.notifyMatchFailure (
2195- rewriter.getUnknownLoc (),
2196- " no candidate slices provided for consumer fusion" );
2197- }
2198- // Return if `loops` is empty, return an error for now. Caller is expected
2199- // to handle this case.
2200- if (loops.empty ()) {
2201- return rewriter.notifyMatchFailure (
2202- candidateSlices.front (),
2203- " cannot call tile and fuse consumer with an empty loop nest" );
2204- }
2187+ static FailureOr<scf::SCFFuseConsumerOfSliceResult>
2188+ tileAndFuseConsumerOfSlicesImpl (RewriterBase &rewriter, Operation *consumerOp,
2189+ ArrayRef<OpOperand *> consumerOpOperands,
2190+ ArrayRef<Operation *> candidateSlices,
2191+ MutableArrayRef<LoopLikeOpInterface> loops) {
2192+ assert (!loops.empty () && " expected loops to be not empty" );
22052193
2206- if (!(llvm::all_of (candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2207- llvm::all_of (candidateSlices,
2208- llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2194+ // 1. Check assumption for loop with `reorderOperations` disabled.
2195+ if (failed (checkAssumptionForLoop (loops.front (), consumerOp, false ))) {
22092196 return rewriter.notifyMatchFailure (
2210- candidateSlices.front (),
2211- " candidates slices need to be all `tensor.extract_slice`s or "
2212- " `tensor.parallel_insert_slice`s" );
2213- }
2214-
2215- // 1. Get the consumer of scf.for for the result yielded by
2216- // tensor.insert_slice/parallel_insert_slice.
2217- SmallVector<OpOperand *> consumerOpOperands;
2218- Operation *consumerOp;
2219- {
2220- FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
2221- getUntiledConsumerOperandsFromSlices (rewriter, candidateSlices, loops);
2222- if (failed (maybeConsumerOpOperand)) {
2223- return rewriter.notifyMatchFailure (candidateSlices.front (),
2224- " could not fetch consumer to fuse" );
2225- }
2226- std::swap (consumerOpOperands, maybeConsumerOpOperand.value ());
2227- consumerOp = consumerOpOperands.front ()->getOwner ();
2197+ loops.front (), " the first user of loop should not dominate any define "
2198+ " of consumer operand(s)" );
22282199 }
22292200
22302201 LoopLikeOpInterface outerMostLoop = loops.front ();
22312202 LoopLikeOpInterface innerMostLoop = loops.back ();
22322203
2233- // Check assumption for loop with `reorderOperations` disabled.
2234- if (failed (checkAssumptionForLoop (outerMostLoop, consumerOp, false ))) {
2235- return rewriter.notifyMatchFailure (
2236- outerMostLoop, " the first user of loop should not dominate any define "
2237- " of consumer operand(s)" );
2238- }
2239-
22402204 OpBuilder::InsertionGuard g (rewriter);
2241-
22422205 // 2. Check consumer is not using scf loop's output as init.
22432206 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
22442207 if (!dstOp)
@@ -2428,11 +2391,171 @@ mlir::scf::tileAndFuseConsumerOfSlices(
24282391 llvm::map_to_vector (operandNumbers, [&](unsigned operandNum) {
24292392 return &tileAndFuseResult->tiledOps [0 ]->getOpOperand (operandNum);
24302393 });
2394+ auto consumerOpOperandsVec = llvm::to_vector (consumerOpOperands);
24312395 return scf::SCFFuseConsumerOfSliceResult{
2432- std::move (consumerOpOperands ), std::move (tiledAndFusedOpOperands),
2396+ std::move (consumerOpOperandsVec ), std::move (tiledAndFusedOpOperands),
24332397 std::move (tileAndFuseResult->tiledOps )};
24342398}
24352399
2400+ // / Fuse the consumer of the loop result(s) built by `candidateSlices` into `loops`: validate the inputs,
2401+ // / look up the consumer operands, then delegate to `tileAndFuseConsumerOfSlicesImpl`. Failed checks are reported via `notifyMatchFailure`.
2402+ FailureOr<scf::SCFFuseConsumerOfSliceResult>
2403+ mlir::scf::tileAndFuseConsumerOfSlices (
2404+ RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2405+ MutableArrayRef<LoopLikeOpInterface> loops) {
2406+ if (candidateSlices.empty ()) {
2407+ return rewriter.notifyMatchFailure (
2408+ rewriter.getUnknownLoc (),
2409+ " no candidate slices provided for consumer fusion" );
2410+ }
2411+ // An empty loop nest is reported as a match failure for now; the caller is
2412+ // expected to handle that case itself.
2413+ if (loops.empty ()) {
2414+ return rewriter.notifyMatchFailure (
2415+ candidateSlices.front (),
2416+ " cannot call tile and fuse consumer with an empty loop nest" );
2417+ }
2418+ // Candidates must be uniformly `tensor.insert_slice` ops or uniformly `tensor.parallel_insert_slice` ops; a mix is rejected.
2419+ if (!(llvm::all_of (candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2420+ llvm::all_of (candidateSlices,
2421+ llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2422+ return rewriter.notifyMatchFailure (
2423+ candidateSlices.front (),
2424+ " candidates slices need to be all `tensor.insert_slice`s or "
2425+ " `tensor.parallel_insert_slice`s" );
2426+ }
2427+
2428+ // Get the consumer operands for the loop result yielded by the
2429+ // tensor.insert_slice/parallel_insert_slice ops; the consumer op is the owner of the first operand.
2430+ SmallVector<OpOperand *> consumerOpOperands;
2431+ Operation *consumerOp;
2432+ {
2433+ FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
2434+ getUntiledConsumerOperandsFromSlices (rewriter, candidateSlices, loops);
2435+ if (failed (maybeConsumerOpOperand)) {
2436+ return rewriter.notifyMatchFailure (candidateSlices.front (),
2437+ " could not fetch consumer to fuse" );
2438+ }
2439+ std::swap (consumerOpOperands, maybeConsumerOpOperand.value ());
2440+ consumerOp = consumerOpOperands.front ()->getOwner ();
2441+ }
2442+
2443+ return tileAndFuseConsumerOfSlicesImpl (
2444+ rewriter, consumerOp, consumerOpOperands, candidateSlices, loops);
2445+ }
2446+
2447+ // / For a given `result` of a `forallOp` return the single combining op
2448+ // / (e.g. a `tensor.parallel_insert_slice`) that is used to construct this
2449+ // / result. Returns `std::nullopt` when `result` is not a result of `forallOp`.
2450+ static std::optional<Operation *>
2451+ getProducingParallelInsertSlice (scf::ForallOp forallOp, OpResult result) {
2452+ if (result.getOwner () != forallOp)
2453+ return std::nullopt ;
2454+ BlockArgument bbArg = forallOp.getTiedBlockArgument (result);
2455+ SmallVector<Operation *> combiningOps = forallOp.getCombiningOps (bbArg);
2456+ // Exactly one combining op is expected for the tied block argument; any
2457+ // other count is treated as unexpected and yields nullopt.
2458+ if (combiningOps.size () != 1 ) {
2459+ return std::nullopt ;
2460+ }
2461+ return combiningOps[0 ];
2462+ }
2463+
2464+ // / For a `result` of the outermost loop of the tiled loop nest `loops`, return the insert-slice-like op used for consumer fusion: the combining op of a single `scf.forall`, or the `tensor.insert_slice` yielded by the innermost `scf.for`; `std::nullopt` if the nest does not have that shape.
2465+ // / NOTE(review): not declared `static`, unlike `getProducingParallelInsertSlice` - confirm external linkage is intended.
2466+ std::optional<Operation *>
2467+ getProducingInsertSliceLikeOp (OpResult result,
2468+ ArrayRef<LoopLikeOpInterface> loops) {
2469+ assert (!loops.empty () && " Expected loops to be not empty" );
2470+ LoopLikeOpInterface outermostLoop = loops.front ();
2471+
2472+ if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation ())) {
2473+ assert (loops.size () == 1 &&
2474+ " expected only a single loop when tiling using scf.forall" );
2475+ return getProducingParallelInsertSlice (forallOp, result);
2476+ }
2477+ // Otherwise assume the loop nest is a nested `scf.for` created through tiling:
2478+ // follow `result` through the yield of each outer loop down to the innermost loop,
2479+ // then retrieve the `tensor.insert_slice` operation that the innermost yield returns.
2480+ while (loops.size () != 1 ) {
2481+ if (result.getOwner () != loops.front ())
2482+ return std::nullopt ;
2483+ auto forOp = dyn_cast<scf::ForOp>(loops.front ());
2484+ if (!forOp)
2485+ return std::nullopt ;
2486+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
2487+ OpResult innerForResult =
2488+ dyn_cast<OpResult>(yieldOp.getOperand (result.getResultNumber ()));
2489+ if (!innerForResult)
2490+ return std::nullopt ;
2491+ result = innerForResult;
2492+ loops = loops.drop_front ();
2493+ }
2494+ if (result.getOwner () != loops.front ())
2495+ return std::nullopt ;
2496+ auto forOp = dyn_cast<scf::ForOp>(loops.front ());
2497+ if (!forOp)
2498+ return std::nullopt ;
2499+ auto yieldOp = cast<scf::YieldOp>(forOp.getBody ()->getTerminator ());
2500+ auto insertSliceOp = yieldOp.getOperand (result.getResultNumber ())
2501+ .getDefiningOp <tensor::InsertSliceOp>();
2502+ if (!insertSliceOp)
2503+ return std::nullopt ;
2504+ return insertSliceOp;
2505+ }
2506+
2507+ FailureOr<scf::SCFFuseConsumerOfSliceResult>
2508+ mlir::scf::tileAndFuseConsumer (RewriterBase &rewriter, Operation *user,
2509+ MutableArrayRef<LoopLikeOpInterface> loops) {
2510+ // Bail out with a match failure on users that do not implement the `TilingInterface`.
2511+ if (!isa<TilingInterface>(user)) {
2512+ return rewriter.notifyMatchFailure (
2513+ user, " unhandled user that does not implement TilingInterface" );
2514+ }
2515+
2516+ // An empty loop nest is reported as a match failure for now; the caller is
2517+ // expected to handle that case itself.
2518+ if (loops.empty ()) {
2519+ return rewriter.notifyMatchFailure (
2520+ user, " cannot call tile and fuse consumer with an empty loop nest" );
2521+ }
2522+
2523+ LoopLikeOpInterface outermostLoop = loops.front ();
2524+
2525+ // Collect the operands of the user whose value is defined by the outermost
2526+ // loop of the loop nest; only these are candidates for fusion.
2527+ SmallVector<OpOperand *> consumerFusableOperands;
2528+ for (OpOperand &opOperand : user->getOpOperands ()) {
2529+ if (opOperand.get ().getDefiningOp () == outermostLoop) {
2530+ consumerFusableOperands.push_back (&opOperand);
2531+ }
2532+ }
2533+
2534+ // No operand comes from the loop nest, so there is nothing to fuse: return a result whose lists are all empty (not a failure).
2535+ if (consumerFusableOperands.empty ()) {
2536+ return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
2537+ SmallVector<OpOperand *>{},
2538+ SmallVector<Operation *>{}};
2539+ }
2540+
2541+ // For every fusable operand, find the insert-slice-like op in the loop nest
2542+ // that produces the loop result it reads; fail if any one cannot be found.
2543+ SmallVector<Operation *> candidateSlices;
2544+ candidateSlices.reserve (consumerFusableOperands.size ());
2545+ for (OpOperand *opOperand : consumerFusableOperands) {
2546+ std::optional<Operation *> slice =
2547+ getProducingInsertSliceLikeOp (cast<OpResult>(opOperand->get ()), loops);
2548+ if (!slice) {
2549+ return rewriter.notifyMatchFailure (
2550+ user,
2551+ " couldnt find producing insert-slice like operation for operand" );
2552+ }
2553+ candidateSlices.push_back (slice.value ());
2554+ }
2555+ return tileAndFuseConsumerOfSlicesImpl (
2556+ rewriter, user, consumerFusableOperands, candidateSlices, loops);
2557+ }
2558+
24362559// ===----------------------------------------------------------------------===//
24372560// lowerToLoopsUsingSCFForOp implementation.
24382561// ===----------------------------------------------------------------------===//
0 commit comments