@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
9494 RewriterBase &rewriter);
9595 // / Emits the epilogue, this creates `maxStage - 1` part which will contain
9696 // / operations from stages [i; maxStage], where i is the part index.
97- void emitEpilogue (RewriterBase &rewriter,
98- llvm::SmallVector<Value> &returnValues);
97+ LogicalResult emitEpilogue (RewriterBase &rewriter,
98+ llvm::SmallVector<Value> &returnValues);
9999};
100100
101101bool LoopPipelinerInternal::initializeLoopInfo (
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
133133 LDBG (" --no epilogue or predicate set -> BAIL" );
134134 return false ;
135135 }
136- if (dynamicLoop && peelEpilogue) {
137- LDBG (" --dynamic loop doesn't support epilogue yet -> BAIL" );
138- return false ;
139- }
140136 std::vector<std::pair<Operation *, unsigned >> schedule;
141137 options.getScheduleFn (forOp, schedule);
142138 if (schedule.empty ()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
313309 });
314310 int predicateIdx = i - stages[op];
315311 if (predicates[predicateIdx]) {
312+ OpBuilder::InsertionGuard insertGuard (rewriter);
316313 newOp = predicateFn (rewriter, newOp, predicates[predicateIdx]);
317314 assert (newOp && " failed to predicate op." );
318315 }
319- rewriter.setInsertionPointAfter (newOp);
320316 if (annotateFn)
321317 annotateFn (newOp, PipeliningOption::PipelinerPart::Prologue, i);
322318 for (unsigned destId : llvm::seq (unsigned (0 ), op->getNumResults ())) {
@@ -561,14 +557,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
561557 }
562558
563559 if (predicates[useStage]) {
560+ OpBuilder::InsertionGuard insertGuard (rewriter);
564561 newOp = predicateFn (rewriter, newOp, predicates[useStage]);
565562 if (!newOp)
566563 return failure ();
567564 // Remap the results to the new predicated one.
568565 for (auto values : llvm::zip (op->getResults (), newOp->getResults ()))
569566 mapping.map (std::get<0 >(values), std::get<1 >(values));
570567 }
571- rewriter.setInsertionPointAfter (newOp);
572568 if (annotateFn)
573569 annotateFn (newOp, PipeliningOption::PipelinerPart::Kernel, 0 );
574570 }
@@ -640,70 +636,113 @@ LogicalResult LoopPipelinerInternal::createKernel(
640636 return success ();
641637}
642638
643- void LoopPipelinerInternal::emitEpilogue (
644- RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
639+ LogicalResult
640+ LoopPipelinerInternal::emitEpilogue (RewriterBase &rewriter,
641+ llvm::SmallVector<Value> &returnValues) {
642+ Location loc = forOp.getLoc ();
645643 // Emit different versions of the induction variable. They will be
646644 // removed by dead code if not used.
645+
646+ // bounds_range = ub - lb
647+ // total_iterations = (bounds_range + step - 1) / step
648+ Type t = lb.getType ();
649+ Value minus1 =
650+ rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -1 ));
651+ Value boundsRange = rewriter.create <arith::SubIOp>(loc, ub, lb);
652+ Value rangeIncr = rewriter.create <arith::AddIOp>(loc, boundsRange, step);
653+ Value rangeDecr = rewriter.create <arith::AddIOp>(loc, rangeIncr, minus1);
654+ Value totalIterations = rewriter.create <arith::DivUIOp>(loc, rangeDecr, step);
655+
656+ SmallVector<Value> predicates (maxStage + 1 );
647657 for (int64_t i = 0 ; i < maxStage; i++) {
648- Location loc = forOp.getLoc ();
649- Type t = lb.getType ();
650- Value minusOne =
651- rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -1 ));
652- // number of iterations = ((ub - 1) - lb) / step
653- Value totalNumIteration = rewriter.create <arith::DivUIOp>(
654- loc,
655- rewriter.create <arith::SubIOp>(
656- loc, rewriter.create <arith::AddIOp>(loc, ub, minusOne), lb),
657- step);
658- // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
658+ // iterI = total_iters - 1 - i
659+ // May go negative...
659660 Value minusI =
660661 rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -i));
662+ Value iterI = rewriter.create <arith::AddIOp>(
663+ loc, rewriter.create <arith::AddIOp>(loc, totalIterations, minus1),
664+ minusI);
665+ // newLastIter = lb + step * iterI
661666 Value newlastIter = rewriter.create <arith::AddIOp>(
662- loc, lb,
663- rewriter.create <arith::MulIOp>(
664- loc, step,
665- rewriter.create <arith::AddIOp>(loc, totalNumIteration, minusI)));
667+ loc, lb, rewriter.create <arith::MulIOp>(loc, step, iterI));
668+
666669 setValueMapping (forOp.getInductionVar (), newlastIter, maxStage - i);
670+
671+ if (dynamicLoop) {
672+ // pred = iterI >= lb
673+ predicates[i + 1 ] = rewriter.create <arith::CmpIOp>(
674+ loc, arith::CmpIPredicate::sge, iterI, lb);
675+ }
667676 }
677+
668678 // Emit `maxStage - 1` epilogue part that includes operations from stages
669679 // [i; maxStage].
670680 for (int64_t i = 1 ; i <= maxStage; i++) {
681+ SmallVector<std::pair<Value, unsigned >> returnMap (returnValues.size ());
671682 for (Operation *op : opOrder) {
672683 if (stages[op] < i)
673684 continue ;
685+ unsigned currentVersion = maxStage - stages[op] + i;
686+ unsigned nextVersion = currentVersion + 1 ;
674687 Operation *newOp =
675688 cloneAndUpdateOperands (rewriter, op, [&](OpOperand *newOperand) {
676689 auto it = valueMapping.find (newOperand->get ());
677690 if (it != valueMapping.end ()) {
678- Value replacement = it->second [maxStage - stages[op] + i ];
691+ Value replacement = it->second [currentVersion ];
679692 newOperand->set (replacement);
680693 }
681694 });
695+ if (dynamicLoop) {
696+ OpBuilder::InsertionGuard insertGuard (rewriter);
697+ newOp = predicateFn (rewriter, newOp, predicates[currentVersion]);
698+ if (!newOp)
699+ return failure ();
700+ }
682701 if (annotateFn)
683702 annotateFn (newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1 );
684- for (unsigned destId : llvm::seq (unsigned (0 ), op->getNumResults ())) {
685- setValueMapping (op->getResult (destId), newOp->getResult (destId),
686- maxStage - stages[op] + i);
703+
704+ for (auto [opRes, newRes] :
705+ llvm::zip (op->getResults (), newOp->getResults ())) {
706+ setValueMapping (opRes, newRes, currentVersion);
687707 // If the value is a loop carried dependency update the loop argument
688708 // mapping and keep track of the last version to replace the original
689709 // forOp uses.
690710 for (OpOperand &operand :
691711 forOp.getBody ()->getTerminator ()->getOpOperands ()) {
692- if (operand.get () != op-> getResult (destId) )
712+ if (operand.get () != opRes )
693713 continue ;
694- unsigned version = maxStage - stages[op] + i + 1 ;
695714 // If the version is greater than maxStage it means it maps to the
696715 // original forOp returned value.
697- if (version > maxStage) {
698- returnValues[operand.getOperandNumber ()] = newOp->getResult (destId);
699- continue ;
700- }
701- setValueMapping (forOp.getRegionIterArgs ()[operand.getOperandNumber ()],
702- newOp->getResult (destId), version);
716+ unsigned ri = operand.getOperandNumber ();
717+ returnValues[ri] = newRes;
718+ Value mapVal = forOp.getRegionIterArgs ()[ri];
719+ returnMap[ri] = std::make_pair (mapVal, currentVersion);
720+ if (nextVersion <= maxStage)
721+ setValueMapping (mapVal, newRes, nextVersion);
722+ }
723+ }
724+ }
725+ if (dynamicLoop) {
726+ // Select return values from this stage (live outs) based on predication.
727+ // If the stage is valid select the peeled value, else use previous stage
728+ // value.
729+ for (auto pair : llvm::enumerate (returnValues)) {
730+ unsigned ri = pair.index ();
731+ auto [mapVal, currentVersion] = returnMap[ri];
732+ if (mapVal) {
733+ unsigned nextVersion = currentVersion + 1 ;
734+ Value pred = predicates[currentVersion];
735+ Value prevValue = valueMapping[mapVal][currentVersion];
736+ auto selOp = rewriter.create <arith::SelectOp>(loc, pred, pair.value (),
737+ prevValue);
738+ returnValues[ri] = selOp;
739+ if (nextVersion <= maxStage)
740+ setValueMapping (mapVal, selOp, nextVersion);
703741 }
704742 }
705743 }
706744 }
745+ return success ();
707746}
708747
709748void LoopPipelinerInternal::setValueMapping (Value key, Value el, int64_t idx) {
@@ -760,7 +799,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
760799 if (options.peelEpilogue ) {
761800 // 4. Emit the epilogue after the new forOp.
762801 rewriter.setInsertionPointAfter (newForOp);
763- pipeliner.emitEpilogue (rewriter, returnValues);
802+ if (failed (pipeliner.emitEpilogue (rewriter, returnValues)))
803+ return failure ();
764804 }
765805 // 5. Erase the original loop and replace the uses with the epilogue output.
766806 if (forOp->getNumResults () > 0 )
0 commit comments