@@ -15510,6 +15510,78 @@ static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
1551015510  return SE.getConstant(*ExprVal + DivisorVal - Rem);
1551115511}
1551215512
15513+ static bool collectDivisibilityInformation(
15514+     ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15515+     DenseMap<const SCEV *, const SCEV *> &DivInfo,
15516+     DenseMap<const SCEV *, APInt> &Multiples, ScalarEvolution &SE) {
15517+   // If we have LHS == 0, check if LHS is computing a property of some unknown
15518+   // SCEV %v which we can rewrite %v to express explicitly.
15519+   if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
15520+     return false;
15521+   // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15522+   // explicitly express that.
15523+   const SCEVUnknown *URemLHS = nullptr;
15524+   const SCEV *URemRHS = nullptr;
15525+   if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15526+     return false;
15527+ 
15528+   const SCEV *Multiple =
15529+       SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15530+   DivInfo[URemLHS] = Multiple;
15531+   if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15532+     Multiples[URemLHS] = C->getAPInt();
15533+   return true;
15534+ }
15535+ 
15536+ // Check if the condition is a divisibility guard (A % B == 0).
15537+ static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15538+                                 ScalarEvolution &SE) {
15539+   const SCEV *X, *Y;
15540+   return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15541+ }
15542+ 
15543+ // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15544+ // recursively. This is done by aligning up/down the constant value to the
15545+ // Divisor.
15546+ static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15547+                                                  APInt Divisor,
15548+                                                  ScalarEvolution &SE) {
15549+   // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15550+   // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15551+   // the non-constant operand and in \p LHS the constant operand.
15552+   auto IsMinMaxSCEVWithNonNegativeConstant =
15553+       [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15554+           const SCEV *&RHS) {
15555+         if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15556+           if (MinMax->getNumOperands() != 2)
15557+             return false;
15558+           if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15559+             if (C->getAPInt().isNegative())
15560+               return false;
15561+             SCTy = MinMax->getSCEVType();
15562+             LHS = MinMax->getOperand(0);
15563+             RHS = MinMax->getOperand(1);
15564+             return true;
15565+           }
15566+         }
15567+         return false;
15568+       };
15569+ 
15570+   const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15571+   SCEVTypes SCTy;
15572+   if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15573+                                            MinMaxRHS))
15574+     return MinMaxExpr;
15575+   auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15576+   assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15577+   auto *DivisibleExpr =
15578+       IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15579+             : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15580+   SmallVector<const SCEV *> Ops = {
15581+       applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15582+   return SE.getMinMaxExpr(SCTy, Ops);
15583+ }
15584+ 
1551315585void ScalarEvolution::LoopGuards::collectFromBlock(
1551415586    ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1551515587    const BasicBlock *Block, const BasicBlock *Pred,
@@ -15520,19 +15592,13 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1552015592  SmallVector<const SCEV *> ExprsToRewrite;
1552115593  auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1552215594                              const SCEV *RHS,
15523-                               DenseMap<const SCEV *, const SCEV *>
15524-                                   &RewriteMap ) {
15595+                               DenseMap<const SCEV *, const SCEV *> &RewriteMap, 
15596+                               const LoopGuards &DivGuards ) {
1552515597    // WARNING: It is generally unsound to apply any wrap flags to the proposed
1552615598    // replacement SCEV which isn't directly implied by the structure of that
1552715599    // SCEV.  In particular, using contextual facts to imply flags is *NOT*
1552815600    // legal.  See the scoping rules for flags in the header to understand why.
1552915601
15530-     // If LHS is a constant, apply information to the other expression.
15531-     if (isa<SCEVConstant>(LHS)) {
15532-       std::swap(LHS, RHS);
15533-       Predicate = CmpInst::getSwappedPredicate(Predicate);
15534-     }
15535- 
1553615602    // Check for a condition of the form (-C1 + X < C2).  InstCombine will
1553715603    // create this form when combining two checks of the form (X u< C2 + C1) and
1553815604    // (X >=u C1).
@@ -15565,67 +15631,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1556515631    if (MatchRangeCheckIdiom())
1556615632      return;
1556715633
15568-     // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15569-     // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15570-     // the non-constant operand and in \p LHS the constant operand.
15571-     auto IsMinMaxSCEVWithNonNegativeConstant =
15572-         [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15573-             const SCEV *&RHS) {
15574-           const APInt *C;
15575-           SCTy = Expr->getSCEVType();
15576-           return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) &&
15577-                  match(LHS, m_scev_APInt(C)) && C->isNonNegative();
15578-         };
15579- 
15580-     // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15581-     // recursively. This is done by aligning up/down the constant value to the
15582-     // Divisor.
15583-     std::function<const SCEV *(const SCEV *, const SCEV *)>
15584-         ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15585-                                            const SCEV *Divisor) {
15586-           auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15587-           if (!ConstDivisor)
15588-             return MinMaxExpr;
15589-           const APInt &DivisorVal = ConstDivisor->getAPInt();
15590- 
15591-           const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15592-           SCEVTypes SCTy;
15593-           if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15594-                                                    MinMaxRHS))
15595-             return MinMaxExpr;
15596-           auto IsMin =
15597-               isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15598-           assert(SE.isKnownNonNegative(MinMaxLHS) &&
15599-                  "Expected non-negative operand!");
15600-           auto *DivisibleExpr =
15601-               IsMin
15602-                   ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE)
15603-                   : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE);
15604-           SmallVector<const SCEV *> Ops = {
15605-               ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15606-           return SE.getMinMaxExpr(SCTy, Ops);
15607-         };
15608- 
15609-     // If we have LHS == 0, check if LHS is computing a property of some unknown
15610-     // SCEV %v which we can rewrite %v to express explicitly.
15611-     if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15612-       // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15613-       // explicitly express that.
15614-       const SCEVUnknown *URemLHS = nullptr;
15615-       const SCEV *URemRHS = nullptr;
15616-       if (match(LHS,
15617-                 m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15618-         auto I = RewriteMap.find(URemLHS);
15619-         const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15620-         RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15621-         const auto *Multiple =
15622-             SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15623-         RewriteMap[URemLHS] = Multiple;
15624-         ExprsToRewrite.push_back(URemLHS);
15625-         return;
15626-       }
15627-     }
15628- 
1562915634    // Do not apply information for constants or if RHS contains an AddRec.
1563015635    if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
1563115636      return;
@@ -15655,7 +15660,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1565515660    };
1565615661
1565715662    const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15658-     const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS);
15663+     // Apply divisibility information when computing the constant multiple.
15664+     const APInt &DividesBy =
15665+         SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
1565915666
1566015667    // Collect rewrites for LHS and its transitive operands based on the
1566115668    // condition.
@@ -15840,8 +15847,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1584015847
1584115848  // Now apply the information from the collected conditions to
1584215849  // Guards.RewriteMap. Conditions are processed in reverse order, so the
15843-   // earliest conditions is processed first. This ensures the SCEVs with the
15850+   // earliest conditions is processed first, except guards with divisibility
15851+   // information, which are moved to the back. This ensures the SCEVs with the
1584415852  // shortest dependency chains are constructed first.
15853+   SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15854+       GuardsToProcess;
1584515855  for (auto [Term, EnterIfTrue] : reverse(Terms)) {
1584615856    SmallVector<Value *, 8> Worklist;
1584715857    SmallPtrSet<Value *, 8> Visited;
@@ -15856,7 +15866,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1585615866            EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
1585715867        const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
1585815868        const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15859-         CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15869+         // If LHS is a constant, apply information to the other expression.
15870+         // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
15871+         // can improve results.
15872+         if (isa<SCEVConstant>(LHS)) {
15873+           std::swap(LHS, RHS);
15874+           Predicate = CmpInst::getSwappedPredicate(Predicate);
15875+         }
15876+         GuardsToProcess.emplace_back(Predicate, LHS, RHS);
1586015877        continue;
1586115878      }
1586215879
@@ -15869,6 +15886,31 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1586915886    }
1587015887  }
1587115888
15889+   // Process divisibility guards in reverse order to populate DivGuards early.
15890+   DenseMap<const SCEV *, APInt> Multiples;
15891+   LoopGuards DivGuards(SE);
15892+   for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15893+     if (!isDivisibilityGuard(LHS, RHS, SE))
15894+       continue;
15895+     collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
15896+                                    Multiples, SE);
15897+   }
15898+ 
15899+   for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15900+     CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
15901+ 
15902+   // Apply divisibility information last. This ensures it is applied to the
15903+   // outermost expression after other rewrites for the given value.
15904+   for (const auto &[K, Divisor] : Multiples) {
15905+     const SCEV *DivisorSCEV = SE.getConstant(Divisor);
15906+     Guards.RewriteMap[K] =
15907+         SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15908+                                          Guards.rewrite(K), Divisor, SE),
15909+                                      DivisorSCEV),
15910+                       DivisorSCEV);
15911+     ExprsToRewrite.push_back(K);
15912+   }
15913+ 
1587215914  // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
1587315915  // the replacement expressions are contained in the ranges of the replaced
1587415916  // expressions.
0 commit comments