@@ -237,8 +237,8 @@ static bool areAllStaticLoopBounds(scf::ForallOp forallOp) {
237237
238238// / Find dimensions of the loop that are unit-trip count and drop them from the
239239// / distributed dimensions.
240- static LogicalResult dropUnitDistributedDims (RewriterBase &rewriter,
241- scf::ForallOp forallOp) {
240+ static FailureOr<scf::ForallOp>
241+ dropUnitDistributedDims (RewriterBase &rewriter, scf::ForallOp forallOp) {
242242 SmallVector<OpFoldResult> mixedLbs = forallOp.getMixedLowerBound ();
243243 SmallVector<OpFoldResult> mixedUbs = forallOp.getMixedUpperBound ();
244244 SmallVector<OpFoldResult> mixedSteps = forallOp.getMixedStep ();
@@ -261,7 +261,7 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
261261 }
262262 }
263263 if (droppedLoops.empty ()) {
264- return success () ;
264+ return forallOp ;
265265 }
266266
267267 OpBuilder::InsertionGuard g (rewriter);
@@ -303,7 +303,7 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
303303 rewriter.mergeBlocks (oldLoopBody, newLoopBody, argReplacements);
304304
305305 rewriter.replaceOp (forallOp, newForallOp.getResults ());
306- return success () ;
306+ return newForallOp ;
307307}
308308
309309// ===---------------------------------------------------------------------===//
@@ -314,8 +314,9 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
314314// Returns a list of new `tensor.extract_slice` ops with new fusion
315315// opportunities, as well as the new surrounding `scf.forall` (because consumer
316316// fusion replaces the loop).
317- static std::pair<std::queue<Operation *>, scf::ForallOp>
318- fuseConsumers (RewriterBase &rewriter, Operation *tiledOp) {
317+ static std::queue<Operation *>
318+ fuseConsumers (RewriterBase &rewriter, Operation *tiledOp,
319+ MutableArrayRef<LoopLikeOpInterface> loops) {
319320 auto addCandidateSlices =
320321 [](Operation *fusedOp,
321322 std::queue<tensor::ParallelInsertSliceOp> &candidates) {
@@ -333,15 +334,15 @@ fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
333334 addCandidateSlices (tiledOp, candidates);
334335
335336 std::queue<Operation *> newFusionOpportunities;
336- scf::ForallOp newLoop = tiledOp->getParentOfType <scf::ForallOp>();
337337 while (!candidates.empty ()) {
338338
339339 // Traverse the slices in BFS fashion.
340340 tensor::ParallelInsertSliceOp candidateSliceOp = candidates.front ();
341341 candidates.pop ();
342342
343343 FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
344- mlir::scf::tileAndFuseConsumerOfSlice (rewriter, candidateSliceOp);
344+ mlir::scf::tileAndFuseConsumerOfSlice (rewriter, candidateSliceOp,
345+ loops);
345346 if (failed (fusedResult)) {
346347 LLVM_DEBUG (llvm::dbgs () << " failed to fuse consumer of slice: "
347348 << candidateSliceOp << " \n " );
@@ -369,19 +370,15 @@ fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
369370 }
370371 }
371372 }
372- // Store the new loop for follow up producer fusion.
373- newLoop = tiledOp->getParentOfType <scf::ForallOp>();
374373 }
375374 }
376- return std::make_pair ( newFusionOpportunities, newLoop) ;
375+ return newFusionOpportunities;
377376}
378377
379378static void fuseProducersOfSlices (RewriterBase &rewriter,
380379 std::queue<Operation *> &worklist,
381380 scf::SCFTileAndFuseOptions &options,
382- scf::ForallOp forallOp) {
383- SmallVector<LoopLikeOpInterface> loops = {
384- cast<LoopLikeOpInterface>(&*forallOp)};
381+ MutableArrayRef<LoopLikeOpInterface> loops) {
385382 while (!worklist.empty ()) {
386383 auto candidateSlice = cast<tensor::ExtractSliceOp>(worklist.front ());
387384 worklist.pop ();
@@ -532,7 +529,6 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
532529
533530 // If the `tilableOp` is a `memref` op, then just tile the operation.
534531 SmallVector<LoopLikeOpInterface> tilingLoops;
535- Operation *rootTiledOp = nullptr ;
536532 if (tilableOp->getNumResults () == 0 ) {
537533 FailureOr<scf::SCFTilingResult> tilingResult =
538534 scf::tileUsingSCF (rewriter, tilableOp, tilingOptions);
@@ -554,7 +550,16 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
554550 rewriter.replaceAllUsesWith (origValue, replacement);
555551 }
556552 std::swap (tileAndFuseResult->loops , tilingLoops);
557- rootTiledOp = tileAndFuseResult->tiledAndFusedOps .front ();
553+ Operation *rootTiledOp = tileAndFuseResult->tiledAndFusedOps .front ();
554+ auto newFusionOpportunities =
555+ fuseConsumers (rewriter, rootTiledOp, tilingLoops);
556+
557+ // Because we restrict to at most a single tilable consumer for yielding
558+ // a replacement, no new fusion opportunities will yield a replacement,
559+ // meaning there is no need to run consumer fusion again afterwards.
560+ // TODO: run producer and consumer fusion in one worklist.
561+ fuseProducersOfSlices (rewriter, newFusionOpportunities, tileAndFuseOptions,
562+ tilingLoops);
558563 }
559564 if (!tilingLoops.empty ()) {
560565 if (tilingLoops.size () != 1 || !isa<scf::ForallOp>(tilingLoops[0 ])) {
@@ -563,35 +568,24 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
563568 return signalPassFailure ();
564569 }
565570
566- auto forallOp = cast<scf::ForallOp>(tilingLoops[0 ]);
567- if (failed (dropUnitDistributedDims (rewriter, forallOp))) {
568- forallOp.emitOpError (" failed to drop unit dimensions" );
571+ auto forallOp =
572+ dropUnitDistributedDims (rewriter, cast<scf::ForallOp>(tilingLoops[0 ]));
573+ if (failed (forallOp)) {
574+ tilingLoops[0 ]->emitOpError (" failed to drop unit dimensions" );
569575 return signalPassFailure ();
570576 }
571577
572- if (rootTiledOp) {
573- auto [newFusionOpportunities, newLoop] =
574- fuseConsumers (rewriter, rootTiledOp);
575-
576- // Because we restrict to at most a single tilable consumer for yielding
577- // a replacement, no new fusion opportunities will yield a replacement,
578- // meaning there is no need to run consumer fusion again afterwards.
579- // TODO: run producer and consumer fusion in one worklist.
580- fuseProducersOfSlices (rewriter, newFusionOpportunities,
581- tileAndFuseOptions, newLoop);
582- forallOp = newLoop;
583- }
584-
585578 // Reorder the workgroups if the strategy is set to `transpose`.
586579 // This just transposes the first two dimensions of the workgroup i.e., the
587580 // #iree.codegen.workgroup_id_x and #iree.codegen.workgroup_id_y.
588581 // Only reorders if the loop bounds are static.
589582 if (transposeWorkgroup) {
590- SmallVector<Attribute> mappingAttrs (forallOp.getMappingAttr ().getValue ());
583+ SmallVector<Attribute> mappingAttrs (
584+ forallOp->getMappingAttr ().getValue ());
591585 int64_t mappingSize = mappingAttrs.size ();
592- if (areAllStaticLoopBounds (forallOp) && mappingSize >= 2 ) {
586+ if (areAllStaticLoopBounds (* forallOp) && mappingSize >= 2 ) {
593587 std::swap (mappingAttrs[mappingSize - 1 ], mappingAttrs[mappingSize - 2 ]);
594- forallOp. setMappingAttr (ArrayAttr::get (context, mappingAttrs));
588+ forallOp-> setMappingAttr (ArrayAttr::get (context, mappingAttrs));
595589 }
596590 }
597591 }
0 commit comments