@@ -15510,6 +15510,78 @@ static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
1551015510 return SE.getConstant(*ExprVal + DivisorVal - Rem);
1551115511}
1551215512
15513+ static bool collectDivisibilityInformation(
15514+ ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15515+ DenseMap<const SCEV *, const SCEV *> &DivInfo,
15516+ DenseMap<const SCEV *, APInt> &Multiples, ScalarEvolution &SE) {
15517+ // If we have LHS == 0, check if LHS is computing a property of some unknown
15518+ // SCEV %v which we can rewrite %v to express explicitly.
15519+ if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
15520+ return false;
15521+ // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15522+ // explicitly express that.
15523+ const SCEVUnknown *URemLHS = nullptr;
15524+ const SCEV *URemRHS = nullptr;
15525+ if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15526+ return false;
15527+
15528+ const SCEV *Multiple =
15529+ SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15530+ DivInfo[URemLHS] = Multiple;
15531+ if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15532+ Multiples[URemLHS] = C->getAPInt();
15533+ return true;
15534+ }
15535+
15536+ // Check if the condition is a divisibility guard (A % B == 0).
15537+ static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15538+ ScalarEvolution &SE) {
15539+ const SCEV *X, *Y;
15540+ return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15541+ }
15542+
15543+ // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15544+ // recursively. This is done by aligning up/down the constant value to the
15545+ // Divisor.
15546+ static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15547+ APInt Divisor,
15548+ ScalarEvolution &SE) {
15549+ // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15550+ // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15551+ // the non-constant operand and in \p LHS the constant operand.
15552+ auto IsMinMaxSCEVWithNonNegativeConstant =
15553+ [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15554+ const SCEV *&RHS) {
15555+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15556+ if (MinMax->getNumOperands() != 2)
15557+ return false;
15558+ if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15559+ if (C->getAPInt().isNegative())
15560+ return false;
15561+ SCTy = MinMax->getSCEVType();
15562+ LHS = MinMax->getOperand(0);
15563+ RHS = MinMax->getOperand(1);
15564+ return true;
15565+ }
15566+ }
15567+ return false;
15568+ };
15569+
15570+ const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15571+ SCEVTypes SCTy;
15572+ if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15573+ MinMaxRHS))
15574+ return MinMaxExpr;
15575+ auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15576+ assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15577+ auto *DivisibleExpr =
15578+ IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15579+ : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15580+ SmallVector<const SCEV *> Ops = {
15581+ applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15582+ return SE.getMinMaxExpr(SCTy, Ops);
15583+ }
15584+
1551315585void ScalarEvolution::LoopGuards::collectFromBlock(
1551415586 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1551515587 const BasicBlock *Block, const BasicBlock *Pred,
@@ -15520,19 +15592,13 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1552015592 SmallVector<const SCEV *> ExprsToRewrite;
1552115593 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1552215594 const SCEV *RHS,
15523- DenseMap<const SCEV *, const SCEV *>
15524- &RewriteMap ) {
15595+ DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15596+ const LoopGuards &DivGuards ) {
1552515597 // WARNING: It is generally unsound to apply any wrap flags to the proposed
1552615598 // replacement SCEV which isn't directly implied by the structure of that
1552715599 // SCEV. In particular, using contextual facts to imply flags is *NOT*
1552815600 // legal. See the scoping rules for flags in the header to understand why.
1552915601
15530- // If LHS is a constant, apply information to the other expression.
15531- if (isa<SCEVConstant>(LHS)) {
15532- std::swap(LHS, RHS);
15533- Predicate = CmpInst::getSwappedPredicate(Predicate);
15534- }
15535-
1553615602 // Check for a condition of the form (-C1 + X < C2). InstCombine will
1553715603 // create this form when combining two checks of the form (X u< C2 + C1) and
1553815604 // (X >=u C1).
@@ -15565,67 +15631,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1556515631 if (MatchRangeCheckIdiom())
1556615632 return;
1556715633
15568- // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15569- // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15570- // the non-constant operand and in \p LHS the constant operand.
15571- auto IsMinMaxSCEVWithNonNegativeConstant =
15572- [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15573- const SCEV *&RHS) {
15574- const APInt *C;
15575- SCTy = Expr->getSCEVType();
15576- return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) &&
15577- match(LHS, m_scev_APInt(C)) && C->isNonNegative();
15578- };
15579-
15580- // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15581- // recursively. This is done by aligning up/down the constant value to the
15582- // Divisor.
15583- std::function<const SCEV *(const SCEV *, const SCEV *)>
15584- ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15585- const SCEV *Divisor) {
15586- auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15587- if (!ConstDivisor)
15588- return MinMaxExpr;
15589- const APInt &DivisorVal = ConstDivisor->getAPInt();
15590-
15591- const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15592- SCEVTypes SCTy;
15593- if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15594- MinMaxRHS))
15595- return MinMaxExpr;
15596- auto IsMin =
15597- isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15598- assert(SE.isKnownNonNegative(MinMaxLHS) &&
15599- "Expected non-negative operand!");
15600- auto *DivisibleExpr =
15601- IsMin
15602- ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE)
15603- : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE);
15604- SmallVector<const SCEV *> Ops = {
15605- ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15606- return SE.getMinMaxExpr(SCTy, Ops);
15607- };
15608-
15609- // If we have LHS == 0, check if LHS is computing a property of some unknown
15610- // SCEV %v which we can rewrite %v to express explicitly.
15611- if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15612- // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15613- // explicitly express that.
15614- const SCEVUnknown *URemLHS = nullptr;
15615- const SCEV *URemRHS = nullptr;
15616- if (match(LHS,
15617- m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15618- auto I = RewriteMap.find(URemLHS);
15619- const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15620- RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15621- const auto *Multiple =
15622- SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15623- RewriteMap[URemLHS] = Multiple;
15624- ExprsToRewrite.push_back(URemLHS);
15625- return;
15626- }
15627- }
15628-
1562915634 // Do not apply information for constants or if RHS contains an AddRec.
1563015635 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
1563115636 return;
@@ -15655,7 +15660,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1565515660 };
1565615661
1565715662 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15658- const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS);
15663+ // Apply divisibility information when computing the constant multiple.
15664+ const APInt &DividesBy =
15665+ SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
1565915666
1566015667 // Collect rewrites for LHS and its transitive operands based on the
1566115668 // condition.
@@ -15840,8 +15847,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1584015847
1584115848 // Now apply the information from the collected conditions to
1584215849 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15843- // earliest conditions is processed first. This ensures the SCEVs with the
15850+ // earliest conditions is processed first, except guards with divisibility
15851+ // information, which are moved to the back. This ensures the SCEVs with the
1584415852 // shortest dependency chains are constructed first.
15853+ SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15854+ GuardsToProcess;
1584515855 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
1584615856 SmallVector<Value *, 8> Worklist;
1584715857 SmallPtrSet<Value *, 8> Visited;
@@ -15856,7 +15866,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1585615866 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
1585715867 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
1585815868 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15859- CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15869+ // If LHS is a constant, apply information to the other expression.
15870+ // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
15871+ // can improve results.
15872+ if (isa<SCEVConstant>(LHS)) {
15873+ std::swap(LHS, RHS);
15874+ Predicate = CmpInst::getSwappedPredicate(Predicate);
15875+ }
15876+ GuardsToProcess.emplace_back(Predicate, LHS, RHS);
1586015877 continue;
1586115878 }
1586215879
@@ -15869,6 +15886,31 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1586915886 }
1587015887 }
1587115888
15889+ // Process divisibility guards in reverse order to populate DivGuards early.
15890+ DenseMap<const SCEV *, APInt> Multiples;
15891+ LoopGuards DivGuards(SE);
15892+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15893+ if (!isDivisibilityGuard(LHS, RHS, SE))
15894+ continue;
15895+ collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
15896+ Multiples, SE);
15897+ }
15898+
15899+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15900+ CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
15901+
15902+ // Apply divisibility information last. This ensures it is applied to the
15903+ // outermost expression after other rewrites for the given value.
15904+ for (const auto &[K, Divisor] : Multiples) {
15905+ const SCEV *DivisorSCEV = SE.getConstant(Divisor);
15906+ Guards.RewriteMap[K] =
15907+ SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15908+ Guards.rewrite(K), Divisor, SE),
15909+ DivisorSCEV),
15910+ DivisorSCEV);
15911+ ExprsToRewrite.push_back(K);
15912+ }
15913+
1587215914 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
1587315915 // the replacement expressions are contained in the ranges of the replaced
1587415916 // expressions.
0 commit comments