@@ -1767,6 +1767,32 @@ struct ForallOpSingleOrZeroIterationDimsFolder
17671767 }
17681768};
17691769
1770+ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern <ForallOp> {
1771+ using OpRewritePattern<ForallOp>::OpRewritePattern;
1772+
1773+ LogicalResult matchAndRewrite (ForallOp op,
1774+ PatternRewriter &rewriter) const override {
1775+ // Replace all induction vars with a single trip count with their lower
1776+ // bound.
1777+ Location loc = op.getLoc ();
1778+ bool replacedIv = false ;
1779+ for (auto [lb, ub, step, iv] :
1780+ llvm::zip (op.getMixedLowerBound (), op.getMixedUpperBound (),
1781+ op.getMixedStep (), op.getInductionVars ())) {
1782+ if (iv.getUses ().begin () == iv.getUses ().end ())
1783+ continue ;
1784+ auto numIterations = constantTripCount (lb, ub, step);
1785+ if (!numIterations.has_value () || numIterations.value () != 1 ) {
1786+ continue ;
1787+ }
1788+ rewriter.replaceAllUsesWith (
1789+ iv, getValueOrCreateConstantIndexOp (rewriter, loc, lb));
1790+ return success ();
1791+ }
1792+ return failure ();
1793+ }
1794+ };
1795+
17701796struct FoldTensorCastOfOutputIntoForallOp
17711797 : public OpRewritePattern<scf::ForallOp> {
17721798 using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
@@ -1851,7 +1877,8 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
18511877 MLIRContext *context) {
18521878 results.add <DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
18531879 ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
1854- ForallOpSingleOrZeroIterationDimsFolder>(context);
1880+ ForallOpSingleOrZeroIterationDimsFolder,
1881+ ForallOpReplaceConstantInductionVar>(context);
18551882}
18561883
18571884// / Given the region at `index`, or the parent operation if `index` is None,
0 commit comments