@@ -2610,6 +2610,19 @@ static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
26102610 return ub - lb <= 0 ? 0 : (ub - lb + step - 1 ) / step;
26112611}
26122612
2613+ // / Calculate the constant value of the loop's induction variable for its last
2614+ // / trip, construct an OpFoldResult using this value and return it.
2615+ static OpFoldResult getConstantInductionVarForLastTrip (AffineForOp forOp) {
2616+ std::optional<uint64_t > tripCount = getTrivialConstantTripCount (forOp);
2617+ if (!tripCount.has_value ())
2618+ return {};
2619+ int64_t lb = forOp.getConstantLowerBound ();
2620+ int64_t step = forOp.getStepAsInt ();
2621+ int64_t lastTripIv = lb + (tripCount.value () - 1 ) * step;
2622+ return OpFoldResult (
2623+ IntegerAttr::get (IndexType::get (forOp.getContext ()), lastTripIv));
2624+ }
2625+
26132626// / Fold the empty loop.
26142627static SmallVector<OpFoldResult> AffineForEmptyLoopFolder (AffineForOp forOp) {
26152628 if (!llvm::hasSingleElement (*forOp.getBody ()))
@@ -2622,7 +2635,7 @@ static SmallVector<OpFoldResult> AffineForEmptyLoopFolder(AffineForOp forOp) {
26222635 // results.
26232636 return forOp.getInits ();
26242637 }
2625- SmallVector<Value , 4 > replacements;
2638+ SmallVector<OpFoldResult , 4 > replacements;
26262639 auto yieldOp = cast<AffineYieldOp>(forOp.getBody ()->getTerminator ());
26272640 auto iterArgs = forOp.getRegionIterArgs ();
26282641 bool hasValDefinedOutsideLoop = false ;
@@ -2632,8 +2645,15 @@ static SmallVector<OpFoldResult> AffineForEmptyLoopFolder(AffineForOp forOp) {
26322645 BlockArgument *iterArgIt = llvm::find (iterArgs, val);
26332646 // TODO: It should be possible to perform a replacement by computing the
26342647 // last value of the IV based on the bounds and the step.
2635- if (val == forOp.getInductionVar ())
2636- return {};
2648+ if (val == forOp.getInductionVar ()) {
2649+ OpFoldResult lastTripIv = getConstantInductionVarForLastTrip (forOp);
2650+ if (lastTripIv) {
2651+ replacements.push_back (lastTripIv);
2652+ continue ;
2653+ } else {
2654+ return {};
2655+ }
2656+ }
26372657 if (iterArgIt == iterArgs.end ()) {
26382658 // `val` is defined outside of the loop.
26392659 assert (forOp.isDefinedOutsideOfLoop (val) &&
@@ -2656,7 +2676,7 @@ static SmallVector<OpFoldResult> AffineForEmptyLoopFolder(AffineForOp forOp) {
26562676 // out of order.
26572677 if (tripCount.has_value () && tripCount.value () >= 2 && iterArgsNotInOrder)
26582678 return {};
2659- return llvm::to_vector_of<OpFoldResult>( replacements) ;
2679+ return replacements;
26602680}
26612681
26622682// / Canonicalize the bounds of the given loop.
0 commit comments