@@ -2600,6 +2600,65 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) {
26002600 return success (folded);
26012601}
26022602
2603+ // / Returns constant trip count in trivial cases.
2604+ static std::optional<uint64_t > getTrivialConstantTripCount (AffineForOp forOp) {
2605+ int64_t step = forOp.getStepAsInt ();
2606+ if (!forOp.hasConstantBounds () || step <= 0 )
2607+ return std::nullopt ;
2608+ int64_t lb = forOp.getConstantLowerBound ();
2609+ int64_t ub = forOp.getConstantUpperBound ();
2610+ return ub - lb <= 0 ? 0 : (ub - lb + step - 1 ) / step;
2611+ }
2612+
2613+ // / Fold the empty loop.
2614+ static SmallVector<OpFoldResult> AffineForEmptyLoopFolder (AffineForOp forOp) {
2615+ if (!llvm::hasSingleElement (*forOp.getBody ()))
2616+ return {};
2617+ if (forOp.getNumResults () == 0 )
2618+ return {};
2619+ std::optional<uint64_t > tripCount = getTrivialConstantTripCount (forOp);
2620+ if (tripCount == 0 ) {
2621+ // The initial values of the iteration arguments would be the op's
2622+ // results.
2623+ return forOp.getInits ();
2624+ }
2625+ SmallVector<Value, 4 > replacements;
2626+ auto yieldOp = cast<AffineYieldOp>(forOp.getBody ()->getTerminator ());
2627+ auto iterArgs = forOp.getRegionIterArgs ();
2628+ bool hasValDefinedOutsideLoop = false ;
2629+ bool iterArgsNotInOrder = false ;
2630+ for (unsigned i = 0 , e = yieldOp->getNumOperands (); i < e; ++i) {
2631+ Value val = yieldOp.getOperand (i);
2632+ BlockArgument *iterArgIt = llvm::find (iterArgs, val);
2633+ // TODO: It should be possible to perform a replacement by computing the
2634+ // last value of the IV based on the bounds and the step.
2635+ if (val == forOp.getInductionVar ())
2636+ return {};
2637+ if (iterArgIt == iterArgs.end ()) {
2638+ // `val` is defined outside of the loop.
2639+ assert (forOp.isDefinedOutsideOfLoop (val) &&
2640+ " must be defined outside of the loop" );
2641+ hasValDefinedOutsideLoop = true ;
2642+ replacements.push_back (val);
2643+ } else {
2644+ unsigned pos = std::distance (iterArgs.begin (), iterArgIt);
2645+ if (pos != i)
2646+ iterArgsNotInOrder = true ;
2647+ replacements.push_back (forOp.getInits ()[pos]);
2648+ }
2649+ }
2650+ // Bail out when the trip count is unknown and the loop returns any value
2651+ // defined outside of the loop or any iterArg out of order.
2652+ if (!tripCount.has_value () &&
2653+ (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2654+ return {};
2655+ // Bail out when the loop iterates more than once and it returns any iterArg
2656+ // out of order.
2657+ if (tripCount.has_value () && tripCount.value () >= 2 && iterArgsNotInOrder)
2658+ return {};
2659+ return llvm::to_vector_of<OpFoldResult>(replacements);
2660+ }
2661+
26032662// / Canonicalize the bounds of the given loop.
26042663static LogicalResult canonicalizeLoopBounds (AffineForOp forOp) {
26052664 SmallVector<Value, 4 > lbOperands (forOp.getLowerBoundOperands ());
@@ -2631,79 +2690,30 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
26312690 return success ();
26322691}
26332692
2634- namespace {
2635- // / Returns constant trip count in trivial cases.
2636- static std::optional<uint64_t > getTrivialConstantTripCount (AffineForOp forOp) {
2637- int64_t step = forOp.getStepAsInt ();
2638- if (!forOp.hasConstantBounds () || step <= 0 )
2639- return std::nullopt ;
2640- int64_t lb = forOp.getConstantLowerBound ();
2641- int64_t ub = forOp.getConstantUpperBound ();
2642- return ub - lb <= 0 ? 0 : (ub - lb + step - 1 ) / step;
2693+ // / Returns true if the affine.for has zero iterations in trivial cases.
2694+ static bool hasTrivialZeroTripCount (AffineForOp op) {
2695+ return getTrivialConstantTripCount (op) == 0 ;
26432696}
26442697
2645- // / This is a pattern to fold trivially empty loop bodies.
2646- // / TODO: This should be moved into the folding hook.
2647- struct AffineForEmptyLoopFolder : public OpRewritePattern <AffineForOp> {
2648- using OpRewritePattern<AffineForOp>::OpRewritePattern;
2649-
2650- LogicalResult matchAndRewrite (AffineForOp forOp,
2651- PatternRewriter &rewriter) const override {
2652- // Check that the body only contains a yield.
2653- if (!llvm::hasSingleElement (*forOp.getBody ()))
2654- return failure ();
2655- if (forOp.getNumResults () == 0 )
2656- return success ();
2657- std::optional<uint64_t > tripCount = getTrivialConstantTripCount (forOp);
2658- if (tripCount == 0 ) {
2659- // The initial values of the iteration arguments would be the op's
2660- // results.
2661- rewriter.replaceOp (forOp, forOp.getInits ());
2662- return success ();
2663- }
2664- SmallVector<Value, 4 > replacements;
2665- auto yieldOp = cast<AffineYieldOp>(forOp.getBody ()->getTerminator ());
2666- auto iterArgs = forOp.getRegionIterArgs ();
2667- bool hasValDefinedOutsideLoop = false ;
2668- bool iterArgsNotInOrder = false ;
2669- for (unsigned i = 0 , e = yieldOp->getNumOperands (); i < e; ++i) {
2670- Value val = yieldOp.getOperand (i);
2671- auto *iterArgIt = llvm::find (iterArgs, val);
2672- // TODO: It should be possible to perform a replacement by computing the
2673- // last value of the IV based on the bounds and the step.
2674- if (val == forOp.getInductionVar ())
2675- return failure ();
2676- if (iterArgIt == iterArgs.end ()) {
2677- // `val` is defined outside of the loop.
2678- assert (forOp.isDefinedOutsideOfLoop (val) &&
2679- " must be defined outside of the loop" );
2680- hasValDefinedOutsideLoop = true ;
2681- replacements.push_back (val);
2682- } else {
2683- unsigned pos = std::distance (iterArgs.begin (), iterArgIt);
2684- if (pos != i)
2685- iterArgsNotInOrder = true ;
2686- replacements.push_back (forOp.getInits ()[pos]);
2687- }
2688- }
2689- // Bail out when the trip count is unknown and the loop returns any value
2690- // defined outside of the loop or any iterArg out of order.
2691- if (!tripCount.has_value () &&
2692- (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2693- return failure ();
2694- // Bail out when the loop iterates more than once and it returns any iterArg
2695- // out of order.
2696- if (tripCount.has_value () && tripCount.value () >= 2 && iterArgsNotInOrder)
2697- return failure ();
2698- rewriter.replaceOp (forOp, replacements);
2699- return success ();
2698+ LogicalResult AffineForOp::fold (FoldAdaptor adaptor,
2699+ SmallVectorImpl<OpFoldResult> &results) {
2700+ bool folded = succeeded (foldLoopBounds (*this ));
2701+ folded |= succeeded (canonicalizeLoopBounds (*this ));
2702+ if (hasTrivialZeroTripCount (*this ) && getNumResults () != 0 ) {
2703+ // The initial values of the loop-carried variables (iter_args) are the
2704+ // results of the op. But this must be avoided for an affine.for op that
2705+ // does not return any results. Since ops that do not return results cannot
2706+ // be folded away, we would enter an infinite loop of folds on the same
2707+ // affine.for op.
2708+ results.assign (getInits ().begin (), getInits ().end ());
2709+ folded = true ;
27002710 }
2701- } ;
2702- } // namespace
2703-
2704- void AffineForOp::getCanonicalizationPatterns (RewritePatternSet &results,
2705- MLIRContext *context) {
2706- results. add <AffineForEmptyLoopFolder>(context );
2711+ SmallVector<OpFoldResult> foldResults = AffineForEmptyLoopFolder (* this ) ;
2712+ if (!foldResults. empty ()) {
2713+ results. assign (foldResults);
2714+ folded = true ;
2715+ }
2716+ return success (folded );
27072717}
27082718
27092719OperandRange AffineForOp::getEntrySuccessorOperands (RegionBranchPoint point) {
@@ -2746,27 +2756,6 @@ void AffineForOp::getSuccessorRegions(
27462756 regions.push_back (RegionSuccessor (getResults ()));
27472757}
27482758
2749- // / Returns true if the affine.for has zero iterations in trivial cases.
2750- static bool hasTrivialZeroTripCount (AffineForOp op) {
2751- return getTrivialConstantTripCount (op) == 0 ;
2752- }
2753-
2754- LogicalResult AffineForOp::fold (FoldAdaptor adaptor,
2755- SmallVectorImpl<OpFoldResult> &results) {
2756- bool folded = succeeded (foldLoopBounds (*this ));
2757- folded |= succeeded (canonicalizeLoopBounds (*this ));
2758- if (hasTrivialZeroTripCount (*this ) && getNumResults () != 0 ) {
2759- // The initial values of the loop-carried variables (iter_args) are the
2760- // results of the op. But this must be avoided for an affine.for op that
2761- // does not return any results. Since ops that do not return results cannot
2762- // be folded away, we would enter an infinite loop of folds on the same
2763- // affine.for op.
2764- results.assign (getInits ().begin (), getInits ().end ());
2765- folded = true ;
2766- }
2767- return success (folded);
2768- }
2769-
27702759AffineBound AffineForOp::getLowerBound () {
27712760 return AffineBound (*this , getLowerBoundOperands (), getLowerBoundMap ());
27722761}
0 commit comments