@@ -15505,6 +15505,79 @@ static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
1550515505 return SE.getConstant(*ExprVal + DivisorVal - Rem);
1550615506}
1550715507
15508+ static bool collectDivisibilityInformation(
15509+ ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15510+ DenseMap<const SCEV *, const SCEV *> &DivInfo,
15511+ DenseMap<const SCEV *, APInt> &Multiples, ScalarEvolution &SE) {
15512+ // If we have LHS == 0, check if LHS is computing a property of some unknown
15513+ // SCEV %v which we can rewrite %v to express explicitly.
15514+ if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
15515+ return false;
15516+ // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15517+ // explicitly express that.
15518+ const SCEVUnknown *URemLHS = nullptr;
15519+ const SCEV *URemRHS = nullptr;
15520+ if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15521+ return false;
15522+
15523+ const SCEV *Multiple =
15524+ SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15525+ DivInfo[URemLHS] = Multiple;
15526+ Multiples[URemLHS] = cast<SCEVConstant>(URemRHS)->getAPInt();
15527+ return true;
15528+ }
15529+
15530+ // Check if the condition is a divisibility guard (A % B == 0).
15531+ static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15532+ ScalarEvolution &SE) {
15533+ const SCEV *X, *Y;
15534+ return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15535+ }
15536+
15537+ // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15538+ // recursively. This is done by aligning up/down the constant value to the
15539+ // Divisor.
15540+ static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15541+ const SCEV *Divisor,
15542+ ScalarEvolution &SE) {
15543+ // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15544+ // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15545+ // the non-constant operand and in \p LHS the constant operand.
15546+ auto IsMinMaxSCEVWithNonNegativeConstant =
15547+ [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15548+ const SCEV *&RHS) {
15549+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15550+ if (MinMax->getNumOperands() != 2)
15551+ return false;
15552+ if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15553+ if (C->getAPInt().isNegative())
15554+ return false;
15555+ SCTy = MinMax->getSCEVType();
15556+ LHS = MinMax->getOperand(0);
15557+ RHS = MinMax->getOperand(1);
15558+ return true;
15559+ }
15560+ }
15561+ return false;
15562+ };
15563+
15564+ const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15565+ SCEVTypes SCTy;
15566+ if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15567+ MinMaxRHS))
15568+ return MinMaxExpr;
15569+ auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15570+ assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15571+ auto *DivisibleExpr =
15572+ IsMin ? getPreviousSCEVDivisibleByDivisor(
15573+ MinMaxLHS, cast<SCEVConstant>(Divisor)->getAPInt(), SE)
15574+ : getNextSCEVDivisibleByDivisor(
15575+ MinMaxLHS, cast<SCEVConstant>(Divisor)->getAPInt(), SE);
15576+ SmallVector<const SCEV *> Ops = {
15577+ applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15578+ return SE.getMinMaxExpr(SCTy, Ops);
15579+ }
15580+
1550815581void ScalarEvolution::LoopGuards::collectFromBlock(
1550915582 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1551015583 const BasicBlock *Block, const BasicBlock *Pred,
@@ -15515,19 +15588,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1551515588 SmallVector<const SCEV *> ExprsToRewrite;
1551615589 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1551715590 const SCEV *RHS,
15518- DenseMap<const SCEV *, const SCEV *>
15519- &RewriteMap) {
15591+ DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15592+ const DenseMap<const SCEV *, const SCEV *>
15593+ &DivInfo) {
1552015594 // WARNING: It is generally unsound to apply any wrap flags to the proposed
1552115595 // replacement SCEV which isn't directly implied by the structure of that
1552215596 // SCEV. In particular, using contextual facts to imply flags is *NOT*
1552315597 // legal. See the scoping rules for flags in the header to understand why.
1552415598
15525- // If LHS is a constant, apply information to the other expression.
15526- if (isa<SCEVConstant>(LHS)) {
15527- std::swap(LHS, RHS);
15528- Predicate = CmpInst::getSwappedPredicate(Predicate);
15529- }
15530-
1553115599 // Check for a condition of the form (-C1 + X < C2). InstCombine will
1553215600 // create this form when combining two checks of the form (X u< C2 + C1) and
1553315601 // (X >=u C1).
@@ -15560,67 +15628,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1556015628 if (MatchRangeCheckIdiom())
1556115629 return;
1556215630
15563- // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15564- // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15565- // the non-constant operand and in \p LHS the constant operand.
15566- auto IsMinMaxSCEVWithNonNegativeConstant =
15567- [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15568- const SCEV *&RHS) {
15569- const APInt *C;
15570- SCTy = Expr->getSCEVType();
15571- return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) &&
15572- match(LHS, m_scev_APInt(C)) && C->isNonNegative();
15573- };
15574-
15575- // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15576- // recursively. This is done by aligning up/down the constant value to the
15577- // Divisor.
15578- std::function<const SCEV *(const SCEV *, const SCEV *)>
15579- ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15580- const SCEV *Divisor) {
15581- auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15582- if (!ConstDivisor)
15583- return MinMaxExpr;
15584- const APInt &DivisorVal = ConstDivisor->getAPInt();
15585-
15586- const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15587- SCEVTypes SCTy;
15588- if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15589- MinMaxRHS))
15590- return MinMaxExpr;
15591- auto IsMin =
15592- isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15593- assert(SE.isKnownNonNegative(MinMaxLHS) &&
15594- "Expected non-negative operand!");
15595- auto *DivisibleExpr =
15596- IsMin
15597- ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE)
15598- : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE);
15599- SmallVector<const SCEV *> Ops = {
15600- ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15601- return SE.getMinMaxExpr(SCTy, Ops);
15602- };
15603-
15604- // If we have LHS == 0, check if LHS is computing a property of some unknown
15605- // SCEV %v which we can rewrite %v to express explicitly.
15606- if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15607- // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15608- // explicitly express that.
15609- const SCEVUnknown *URemLHS = nullptr;
15610- const SCEV *URemRHS = nullptr;
15611- if (match(LHS,
15612- m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15613- auto I = RewriteMap.find(URemLHS);
15614- const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15615- RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15616- const auto *Multiple =
15617- SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15618- RewriteMap[URemLHS] = Multiple;
15619- ExprsToRewrite.push_back(URemLHS);
15620- return;
15621- }
15622- }
15623-
1562415631 // Do not apply information for constants or if RHS contains an AddRec.
1562515632 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
1562615633 return;
@@ -15650,7 +15657,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1565015657 };
1565115658
1565215659 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15653- const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS);
15660+ // Apply divisibility information when computing the constant multiple.
15661+ LoopGuards DivGuards(SE);
15662+ DivGuards.RewriteMap = DivInfo;
15663+ const APInt &DividesBy =
15664+ SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
1565415665
1565515666 // Collect rewrites for LHS and its transitive operands based on the
1565615667 // condition.
@@ -15835,8 +15846,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1583515846
1583615847 // Now apply the information from the collected conditions to
1583715848 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15838- // earliest conditions is processed first. This ensures the SCEVs with the
15849+ // earliest conditions is processed first, except guards with divisibility
15850+ // information, which are moved to the back. This ensures the SCEVs with the
1583915851 // shortest dependency chains are constructed first.
15852+ SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15853+ GuardsToProcess;
1584015854 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
1584115855 SmallVector<Value *, 8> Worklist;
1584215856 SmallPtrSet<Value *, 8> Visited;
@@ -15851,7 +15865,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1585115865 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
1585215866 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
1585315867 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15854- CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15868+ // If LHS is a constant, apply information to the other expression.
15869+ if (isa<SCEVConstant>(LHS)) {
15870+ std::swap(LHS, RHS);
15871+ Predicate = CmpInst::getSwappedPredicate(Predicate);
15872+ }
15873+ GuardsToProcess.emplace_back(Predicate, LHS, RHS);
1585515874 continue;
1585615875 }
1585715876
@@ -15864,6 +15883,30 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1586415883 }
1586515884 }
1586615885
15886+ // Process divisibility guards in reverse order to populate DivInfo early.
15887+ DenseMap<const SCEV *, APInt> Multiples;
15888+ DenseMap<const SCEV *, const SCEV *> DivInfo;
15889+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15890+ if (!isDivisibilityGuard(LHS, RHS, SE))
15891+ continue;
15892+ collectDivisibilityInformation(Predicate, LHS, RHS, DivInfo, Multiples, SE);
15893+ }
15894+
15895+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15896+ CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivInfo);
15897+
15898+ // Apply divisibility information last. This ensures it is applied to the
15899+ // outermost expression after other rewrites for the given value.
15900+ for (const auto &[K, V] : Multiples) {
15901+ const SCEV *DivisorSCEV = SE.getConstant(V);
15902+ Guards.RewriteMap[K] =
15903+ SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15904+ Guards.rewrite(K), DivisorSCEV, SE),
15905+ DivisorSCEV),
15906+ DivisorSCEV);
15907+ ExprsToRewrite.push_back(K);
15908+ }
15909+
1586715910 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
1586815911 // the replacement expressions are contained in the ranges of the replaced
1586915912 // expressions.
0 commit comments