@@ -15740,19 +15740,26 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1574015740 GetNextSCEVDividesByDivisor(One, DividesBy);
1574115741 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
1574215742 } else {
15743+ // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15744+ // but creating the subtraction eagerly is expensive. Track the
15745+ // inequalities in a separate map, and materialize the rewrite lazily
15746+ // when encountering a suitable subtraction while re-writing.
1574315747 if (LHS->getType()->isPointerTy()) {
1574415748 LHS = SE.getLosslessPtrToIntExpr(LHS);
1574515749 RHS = SE.getLosslessPtrToIntExpr(RHS);
1574615750 if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
1574715751 break;
1574815752 }
15749- auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) {
15750- const SCEV *Sub = SE.getMinusSCEV(A, B);
15751- AddRewrite(Sub, Sub,
15752- SE.getUMaxExpr(Sub, SE.getOne(From->getType())));
15753- };
15754- AddSubRewrite(LHS, RHS);
15755- AddSubRewrite(RHS, LHS);
15753+ const SCEVConstant *C;
15754+ const SCEV *A, *B;
15755+ if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
15756+ match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
15757+ RHS = A;
15758+ LHS = B;
15759+ }
15760+ if (LHS > RHS)
15761+ std::swap(LHS, RHS);
15762+ Guards.NotEqual.insert({LHS, RHS});
1575615763 continue;
1575715764 }
1575815765 break;
@@ -15886,13 +15893,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
1588615893 class SCEVLoopGuardRewriter
1588715894 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
1588815895 const DenseMap<const SCEV *, const SCEV *> &Map;
15896+ const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
1588915897
1589015898 SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
1589115899
1589215900 public:
1589315901 SCEVLoopGuardRewriter(ScalarEvolution &SE,
1589415902 const ScalarEvolution::LoopGuards &Guards)
15895- : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15903+ : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
15904+ NotEqual(Guards.NotEqual) {
1589615905 if (Guards.PreserveNUW)
1589715906 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
1589815907 if (Guards.PreserveNSW)
@@ -15947,14 +15956,36 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
1594715956 }
1594815957
1594915958 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15959+ // Helper to check if S is a subtraction (A - B) where A != B, and if so,
15960+ // return UMax(S, 1).
15961+ auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
15962+ const SCEV *LHS, *RHS;
15963+ if (MatchBinarySub(S, LHS, RHS)) {
15964+ if (LHS > RHS)
15965+ std::swap(LHS, RHS);
15966+ if (NotEqual.contains({LHS, RHS}))
15967+ return SE.getUMaxExpr(S, SE.getOne(S->getType()));
15968+ }
15969+ return nullptr;
15970+ };
15971+
15972+ // Check if Expr itself is a subtraction pattern with guard info.
15973+ if (const SCEV *Rewritten = RewriteSubtraction(Expr))
15974+ return Rewritten;
15975+
1595015976 // Trip count expressions sometimes consist of adding 3 operands, i.e.
1595115977 // (Const + A + B). There may be guard info for A + B, and if so, apply
1595215978 // it.
1595315979 // TODO: Could more generally apply guards to Add sub-expressions.
1595415980 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
1595515981 Expr->getNumOperands() == 3) {
15956- if (const SCEV *S = Map.lookup(
15957- SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
15982+ const SCEV *Add =
15983+ SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
15984+ if (const SCEV *Rewritten = RewriteSubtraction(Add))
15985+ return SE.getAddExpr(
15986+ Expr->getOperand(0), Rewritten,
15987+ ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
15988+ if (const SCEV *S = Map.lookup(Add))
1595815989 return SE.getAddExpr(Expr->getOperand(0), S);
1595915990 }
1596015991 SmallVector<const SCEV *, 2> Operands;
@@ -15989,7 +16020,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
1598916020 }
1599016021 };
1599116022
15992- if (RewriteMap.empty())
16023+ if (RewriteMap.empty() && NotEqual.empty() )
1599316024 return Expr;
1599416025
1599516026 SCEVLoopGuardRewriter Rewriter(SE, *this);
0 commit comments