@@ -2610,6 +2610,21 @@ static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
26102610 return ub - lb <= 0 ? 0 : (ub - lb + step - 1 ) / step;
26112611}
26122612
2613+ // / Calculate the constant value of the loop's induction variable for its last
2614+ // / trip.
2615+ static std::optional<int64_t >
2616+ getConstantInductionVarForLastTrip (AffineForOp forOp) {
2617+ std::optional<uint64_t > tripCount = getTrivialConstantTripCount (forOp);
2618+ if (!tripCount.has_value ())
2619+ return std::nullopt ;
2620+ if (tripCount.value () == 0 )
2621+ return std::nullopt ;
2622+ int64_t lb = forOp.getConstantLowerBound ();
2623+ int64_t step = forOp.getStepAsInt ();
2624+ int64_t lastTripIv = lb + (tripCount.value () - 1 ) * step;
2625+ return lastTripIv;
2626+ }
2627+
26132628// / Fold the empty loop.
26142629static SmallVector<OpFoldResult> AffineForEmptyLoopFolder (AffineForOp forOp) {
26152630 if (!llvm::hasSingleElement (*forOp.getBody ()))
@@ -2622,18 +2637,22 @@ static SmallVector<OpFoldResult> AffineForEmptyLoopFolder(AffineForOp forOp) {
26222637 // results.
26232638 return forOp.getInits ();
26242639 }
2625- SmallVector<Value , 4 > replacements;
2640+ SmallVector<OpFoldResult , 4 > replacements;
26262641 auto yieldOp = cast<AffineYieldOp>(forOp.getBody ()->getTerminator ());
26272642 auto iterArgs = forOp.getRegionIterArgs ();
26282643 bool hasValDefinedOutsideLoop = false ;
26292644 bool iterArgsNotInOrder = false ;
26302645 for (unsigned i = 0 , e = yieldOp->getNumOperands (); i < e; ++i) {
26312646 Value val = yieldOp.getOperand (i);
26322647 BlockArgument *iterArgIt = llvm::find (iterArgs, val);
2633- // TODO: It should be possible to perform a replacement by computing the
2634- // last value of the IV based on the bounds and the step.
2635- if (val == forOp.getInductionVar ())
2648+ if (val == forOp.getInductionVar ()) {
2649+ if (auto lastTripIv = getConstantInductionVarForLastTrip (forOp)) {
2650+ replacements.push_back (IntegerAttr::get (
2651+ IndexType::get (forOp.getContext ()), lastTripIv.value ()));
2652+ continue ;
2653+ }
26362654 return {};
2655+ }
26372656 if (iterArgIt == iterArgs.end ()) {
26382657 // `val` is defined outside of the loop.
26392658 assert (forOp.isDefinedOutsideOfLoop (val) &&
@@ -2656,7 +2675,7 @@ static SmallVector<OpFoldResult> AffineForEmptyLoopFolder(AffineForOp forOp) {
26562675 // out of order.
26572676 if (tripCount.has_value () && tripCount.value () >= 2 && iterArgsNotInOrder)
26582677 return {};
2659- return llvm::to_vector_of<OpFoldResult>( replacements) ;
2678+ return replacements;
26602679}
26612680
26622681// / Canonicalize the bounds of the given loop.
0 commit comments