@@ -285,19 +285,6 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
285285 Location loc = forOp.getLoc ();
286286 SmallVector<Value> predicates (maxStage);
287287 for (int64_t i = 0 ; i < maxStage; i++) {
288- if (dynamicLoop) {
289- Type t = ub.getType ();
290- // pred = ub > lb + (i * step)
291- Value iv = rewriter.create <arith::AddIOp>(
292- loc, lb,
293- rewriter.create <arith::MulIOp>(
294- loc, step,
295- rewriter.create <arith::ConstantOp>(
296- loc, rewriter.getIntegerAttr (t, i))));
297- predicates[i] = rewriter.create <arith::CmpIOp>(
298- loc, arith::CmpIPredicate::slt, iv, ub);
299- }
300-
301288 // special handling for induction variable as the increment is implicit.
302289 // iv = lb + i * step
303290 Type t = lb.getType ();
@@ -308,6 +295,13 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
308295 rewriter.create <arith::ConstantOp>(loc,
309296 rewriter.getIntegerAttr (t, i))));
310297 setValueMapping (forOp.getInductionVar (), iv, i);
298+
299+ if (dynamicLoop) {
300+ // pred = ub > lb + (i * step)
301+ predicates[i] = rewriter.create <arith::CmpIOp>(
302+ loc, arith::CmpIPredicate::slt, iv, ub);
303+ }
304+
311305 for (Operation *op : opOrder) {
312306 if (stages[op] > i)
313307 continue ;
@@ -655,50 +649,56 @@ LogicalResult
655649LoopPipelinerInternal::emitEpilogue (RewriterBase &rewriter,
656650 llvm::SmallVector<Value> &returnValues) {
657651 Location loc = forOp.getLoc ();
652+ Type t = lb.getType ();
658653 // Emit different versions of the induction variable. They will be
659654 // removed by dead code if not used.
660655
661- // range_diff = ub - lb
662- // total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
663- Type t = lb.getType ();
664- Value zero =
665- rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, 0 ));
666- Value one =
667- rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, 1 ));
668- Value minusOne =
669- rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -1 ));
656+ auto createConst = [&](int v) {
657+ return rewriter.create <arith::ConstantOp>(loc,
658+ rewriter.getIntegerAttr (t, v));
659+ };
660+
661+ // total_iterations = cdiv(range_diff, step);
662+ // - range_diff = ub - lb
663+ // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
664+ Value zero = createConst (0 );
665+ Value one = createConst (1 );
670666 Value stepLessZero = rewriter.create <arith::CmpIOp>(
671667 loc, arith::CmpIPredicate::slt, step, zero);
672668 Value stepDecr =
673- rewriter.create <arith::SelectOp>(loc, stepLessZero, one, minusOne );
669+ rewriter.create <arith::SelectOp>(loc, stepLessZero, one, createConst (- 1 ) );
674670
675671 Value rangeDiff = rewriter.create <arith::SubIOp>(loc, ub, lb);
676672 Value rangeIncrStep = rewriter.create <arith::AddIOp>(loc, rangeDiff, step);
677673 Value rangeDecr =
678674 rewriter.create <arith::AddIOp>(loc, rangeIncrStep, stepDecr);
679675 Value totalIterations = rewriter.create <arith::DivSIOp>(loc, rangeDecr, step);
680676
677+ // If total_iters < max_stage, start the epilogue at zero to match the
678+ // ramp-up in the prologue.
679+ // start_iter = max(0, total_iters - max_stage)
680+ Value iterI = rewriter.create <arith::SubIOp>(loc, totalIterations,
681+ createConst (maxStage));
682+ iterI = rewriter.create <arith::MaxSIOp>(loc, zero, iterI);
683+
681684 // Capture predicates for dynamic loops.
682685 SmallVector<Value> predicates (maxStage + 1 );
683686
684- for (int64_t i = 0 ; i < maxStage; i++) {
685- // iterI = total_iters - 1 - i
686- // May go negative...
687- Value minusI =
688- rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -i));
689- Value iterI = rewriter.create <arith::AddIOp>(
690- loc, rewriter.create <arith::AddIOp>(loc, totalIterations, minusOne),
691- minusI);
687+ for (int64_t i = 1 ; i <= maxStage; i++) {
692688 // newLastIter = lb + step * iterI
693689 Value newlastIter = rewriter.create <arith::AddIOp>(
694690 loc, lb, rewriter.create <arith::MulIOp>(loc, step, iterI));
695691
696- setValueMapping (forOp.getInductionVar (), newlastIter, maxStage - i);
692+ setValueMapping (forOp.getInductionVar (), newlastIter, i);
693+
694+ // increment to next iterI
695+ iterI = rewriter.create <arith::AddIOp>(loc, iterI, one);
697696
698697 if (dynamicLoop) {
699- // pred = iterI >= 0
700- predicates[i + 1 ] = rewriter.create <arith::CmpIOp>(
701- loc, arith::CmpIPredicate::sge, iterI, zero);
698+ // Disable stages when `i` is greater than total_iters.
699+ // pred = total_iters >= i
700+ predicates[i] = rewriter.create <arith::CmpIOp>(
701+ loc, arith::CmpIPredicate::sge, totalIterations, createConst (i));
702702 }
703703 }
704704
0 commit comments