@@ -651,8 +651,7 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
651651 return failure ();
652652 }
653653
654- promote (rewriter, *this );
655- return success ();
654+ return promote (rewriter, *this );
656655}
657656
658657Block::BlockArgListType ForallOp::getRegionIterArgs () {
@@ -664,10 +663,23 @@ MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
664663}
665664
666665// / Promotes the loop body of a scf::ForallOp to its containing block.
667- void mlir::scf::promote (RewriterBase &rewriter, scf::ForallOp forallOp) {
666+ LogicalResult mlir::scf::promote (RewriterBase &rewriter, scf::ForallOp forallOp) {
668667 OpBuilder::InsertionGuard g (rewriter);
669668 scf::InParallelOp terminator = forallOp.getTerminator ();
670669
670+ // Make sure we can promote all parallel combining ops in terminator:
671+ for (auto &yieldingOp : terminator.getYieldingOps ()) {
672+ auto parallelCombiningOp =
673+ dyn_cast<ParallelCombiningOpInterface>(&yieldingOp);
674+ if (!parallelCombiningOp)
675+ return rewriter.notifyMatchFailure (
676+ forallOp, " terminator has non-parallel-combining op" );
677+ if (!parallelCombiningOp.canPromoteInParallelLoop (rewriter))
678+ return rewriter.notifyMatchFailure (
679+ forallOp, " parallel combining op cannot be promoted" );
680+ }
681+
682+
671683 // Replace block arguments with lower bounds (replacements for IVs) and
672684 // outputs.
673685 SmallVector<Value> bbArgReplacements = forallOp.getLowerBound (rewriter);
@@ -683,30 +695,29 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
683695 SmallVector<Value> results;
684696 results.reserve (forallOp.getResults ().size ());
685697 for (auto &yieldingOp : terminator.getYieldingOps ()) {
686- auto parallelInsertSliceOp =
687- dyn_cast<tensor::ParallelInsertSliceOp>( yieldingOp);
688- if (!parallelInsertSliceOp )
698+ auto parallelCombiningOp =
699+ dyn_cast<ParallelCombiningOpInterface>(& yieldingOp);
700+ if (!parallelCombiningOp )
689701 continue ;
690702
691- Value dst = parallelInsertSliceOp.getDest ();
692- Value src = parallelInsertSliceOp.getSource ();
693- if (llvm::isa<TensorType>(src.getType ())) {
694- results.push_back (tensor::InsertSliceOp::create (
695- rewriter, forallOp.getLoc (), dst.getType (), src, dst,
696- parallelInsertSliceOp.getOffsets (), parallelInsertSliceOp.getSizes (),
697- parallelInsertSliceOp.getStrides (),
698- parallelInsertSliceOp.getStaticOffsets (),
699- parallelInsertSliceOp.getStaticSizes (),
700- parallelInsertSliceOp.getStaticStrides ()));
701- } else {
702- llvm_unreachable (" unsupported terminator" );
703- }
703+ assert (parallelCombiningOp.canPromoteInParallelLoop (rewriter));
704+
705+ FailureOr<SmallVector<Value>> promotedValues =
706+ parallelCombiningOp.promoteInParallelLoop (rewriter);
707+ if (failed (promotedValues))
708+ return failure ();
709+
710+ results.append (promotedValues->begin (), promotedValues->end ());
704711 }
712+ if (results.size () != forallOp.getResults ().size ())
713+ return rewriter.notifyMatchFailure (
714+ forallOp, " failed to materialize replacements for all results" );
705715 rewriter.replaceAllUsesWith (forallOp.getResults (), results);
706716
707717 // Erase the old terminator and the loop.
708718 rewriter.eraseOp (terminator);
709719 rewriter.eraseOp (forallOp);
720+ return success ();
710721}
711722
712723LoopNest mlir::scf::buildLoopNest (
@@ -1789,7 +1800,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
17891800
17901801 // All of the loop dimensions perform a single iteration. Inline loop body.
17911802 if (newMixedLowerBounds.empty ()) {
1792- promote (rewriter, op);
1803+ if (failed (promote (rewriter, op)))
1804+ return failure ();
17931805 return success ();
17941806 }
17951807
0 commit comments