@@ -15474,11 +15474,26 @@ void ScalarEvolution::LoopGuards::collectFromPHI(
1547415474}
1547515475
1547615476// Return a new SCEV that modifies \p Expr to the closest number divides by
15477- // \p Divisor and greater or equal than Expr. For now, only handle constant
15477+ // \p Divisor and less or equal than Expr. For now, only handle constant
1547815478// Expr.
15479- static const SCEV *getNextSCEVDividesByDivisor(const SCEV *Expr,
15480- const APInt &DivisorVal,
15481- ScalarEvolution &SE) {
15479+ static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
15480+ const APInt &DivisorVal,
15481+ ScalarEvolution &SE) {
15482+ const APInt *ExprVal;
15483+ if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15484+ DivisorVal.isNonPositive())
15485+ return Expr;
15486+ APInt Rem = ExprVal->urem(DivisorVal);
15487+ // return the SCEV: Expr - Expr % Divisor
15488+ return SE.getConstant(*ExprVal - Rem);
15489+ }
15490+
15491+ // Return a new SCEV that modifies \p Expr to the closest number divisible by
15492+ // \p Divisor and greater than or equal to Expr.
15493+ // For now, only handle constant Expr and Divisor.
15494+ static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15495+ const APInt &DivisorVal,
15496+ ScalarEvolution &SE) {
1548215497 const APInt *ExprVal;
1548315498 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
1548415499 DivisorVal.isNonPositive())
@@ -15490,6 +15505,79 @@ static const SCEV *getNextSCEVDividesByDivisor(const SCEV *Expr,
1549015505 return SE.getConstant(*ExprVal + DivisorVal - Rem);
1549115506}
1549215507
15508+ static bool collectDivisibilityInformation(
15509+ ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15510+ DenseMap<const SCEV *, const SCEV *> &DivInfo,
15511+ DenseMap<const SCEV *, APInt> &Multiples, ScalarEvolution &SE) {
15512+ // If we have LHS == 0, check if LHS is computing a property of some unknown
15513+ // SCEV %v which we can rewrite %v to express explicitly.
15514+ if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
15515+ return false;
15516+ // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15517+ // explicitly express that.
15518+ const SCEVUnknown *URemLHS = nullptr;
15519+ const SCEV *URemRHS = nullptr;
15520+ if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15521+ return false;
15522+
15523+ const SCEV *Multiple =
15524+ SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15525+ DivInfo[URemLHS] = Multiple;
15526+ Multiples[URemLHS] = cast<SCEVConstant>(URemRHS)->getAPInt();
15527+ return true;
15528+ }
15529+
15530+ // Check if the condition is a divisibility guard (A % B == 0).
15531+ static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15532+ ScalarEvolution &SE) {
15533+ const SCEV *X, *Y;
15534+ return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15535+ }
15536+
15537+ // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15538+ // recursively. This is done by aligning up/down the constant value to the
15539+ // Divisor.
15540+ static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15541+ const SCEV *Divisor,
15542+ ScalarEvolution &SE) {
15543+ // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15544+ // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15545+ // the non-constant operand and in \p LHS the constant operand.
15546+ auto IsMinMaxSCEVWithNonNegativeConstant =
15547+ [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15548+ const SCEV *&RHS) {
15549+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15550+ if (MinMax->getNumOperands() != 2)
15551+ return false;
15552+ if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15553+ if (C->getAPInt().isNegative())
15554+ return false;
15555+ SCTy = MinMax->getSCEVType();
15556+ LHS = MinMax->getOperand(0);
15557+ RHS = MinMax->getOperand(1);
15558+ return true;
15559+ }
15560+ }
15561+ return false;
15562+ };
15563+
15564+ const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15565+ SCEVTypes SCTy;
15566+ if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15567+ MinMaxRHS))
15568+ return MinMaxExpr;
15569+ auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15570+ assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15571+ auto *DivisibleExpr =
15572+ IsMin ? getPreviousSCEVDivisibleByDivisor(
15573+ MinMaxLHS, cast<SCEVConstant>(Divisor)->getAPInt(), SE)
15574+ : getNextSCEVDivisibleByDivisor(
15575+ MinMaxLHS, cast<SCEVConstant>(Divisor)->getAPInt(), SE);
15576+ SmallVector<const SCEV *> Ops = {
15577+ applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15578+ return SE.getMinMaxExpr(SCTy, Ops);
15579+ }
15580+
1549315581void ScalarEvolution::LoopGuards::collectFromBlock(
1549415582 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1549515583 const BasicBlock *Block, const BasicBlock *Pred,
@@ -15500,19 +15588,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1550015588 SmallVector<const SCEV *> ExprsToRewrite;
1550115589 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1550215590 const SCEV *RHS,
15503- DenseMap<const SCEV *, const SCEV *>
15504- &RewriteMap) {
15591+ DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15592+ const DenseMap<const SCEV *, const SCEV *>
15593+ &DivInfo) {
1550515594 // WARNING: It is generally unsound to apply any wrap flags to the proposed
1550615595 // replacement SCEV which isn't directly implied by the structure of that
1550715596 // SCEV. In particular, using contextual facts to imply flags is *NOT*
1550815597 // legal. See the scoping rules for flags in the header to understand why.
1550915598
15510- // If LHS is a constant, apply information to the other expression.
15511- if (isa<SCEVConstant>(LHS)) {
15512- std::swap(LHS, RHS);
15513- Predicate = CmpInst::getSwappedPredicate(Predicate);
15514- }
15515-
1551615599 // Check for a condition of the form (-C1 + X < C2). InstCombine will
1551715600 // create this form when combining two checks of the form (X u< C2 + C1) and
1551815601 // (X >=u C1).
@@ -15545,80 +15628,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1554515628 if (MatchRangeCheckIdiom())
1554615629 return;
1554715630
15548- // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15549- // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15550- // the non-constant operand and in \p LHS the constant operand.
15551- auto IsMinMaxSCEVWithNonNegativeConstant =
15552- [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15553- const SCEV *&RHS) {
15554- const APInt *C;
15555- SCTy = Expr->getSCEVType();
15556- return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) &&
15557- match(LHS, m_scev_APInt(C)) && C->isNonNegative();
15558- };
15559-
15560- // Return a new SCEV that modifies \p Expr to the closest number divides by
15561- // \p Divisor and less or equal than Expr. For now, only handle constant
15562- // Expr.
15563- auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15564- const APInt &DivisorVal) {
15565- const APInt *ExprVal;
15566- if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15567- DivisorVal.isNonPositive())
15568- return Expr;
15569- APInt Rem = ExprVal->urem(DivisorVal);
15570- // return the SCEV: Expr - Expr % Divisor
15571- return SE.getConstant(*ExprVal - Rem);
15572- };
15573-
15574- // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15575- // recursively. This is done by aligning up/down the constant value to the
15576- // Divisor.
15577- std::function<const SCEV *(const SCEV *, const SCEV *)>
15578- ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15579- const SCEV *Divisor) {
15580- auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15581- if (!ConstDivisor)
15582- return MinMaxExpr;
15583- const APInt &DivisorVal = ConstDivisor->getAPInt();
15584-
15585- const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15586- SCEVTypes SCTy;
15587- if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15588- MinMaxRHS))
15589- return MinMaxExpr;
15590- auto IsMin =
15591- isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15592- assert(SE.isKnownNonNegative(MinMaxLHS) &&
15593- "Expected non-negative operand!");
15594- auto *DivisibleExpr =
15595- IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal)
15596- : getNextSCEVDividesByDivisor(MinMaxLHS, DivisorVal, SE);
15597- SmallVector<const SCEV *> Ops = {
15598- ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15599- return SE.getMinMaxExpr(SCTy, Ops);
15600- };
15601-
15602- // If we have LHS == 0, check if LHS is computing a property of some unknown
15603- // SCEV %v which we can rewrite %v to express explicitly.
15604- if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15605- // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15606- // explicitly express that.
15607- const SCEVUnknown *URemLHS = nullptr;
15608- const SCEV *URemRHS = nullptr;
15609- if (match(LHS,
15610- m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15611- auto I = RewriteMap.find(URemLHS);
15612- const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15613- RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15614- const auto *Multiple =
15615- SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15616- RewriteMap[URemLHS] = Multiple;
15617- ExprsToRewrite.push_back(URemLHS);
15618- return;
15619- }
15620- }
15621-
1562215631 // Do not apply information for constants or if RHS contains an AddRec.
1562315632 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
1562415633 return;
@@ -15648,7 +15657,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1564815657 };
1564915658
1565015659 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15651- const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS);
15660+ // Apply divisibility information when computing the constant multiple.
15661+ LoopGuards DivGuards(SE);
15662+ DivGuards.RewriteMap = DivInfo;
15663+ const APInt &DividesBy =
15664+ SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
1565215665
1565315666 // Collect rewrites for LHS and its transitive operands based on the
1565415667 // condition.
@@ -15663,31 +15676,31 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1566315676 // predicate.
1566415677 const SCEV *One = SE.getOne(RHS->getType());
1566515678 switch (Predicate) {
15666- case CmpInst::ICMP_ULT:
15667- if (RHS->getType()->isPointerTy())
15668- return;
15669- RHS = SE.getUMaxExpr(RHS, One);
15670- [[fallthrough]];
15671- case CmpInst::ICMP_SLT: {
15672- RHS = SE.getMinusSCEV(RHS, One);
15673- RHS = GetPreviousSCEVDividesByDivisor (RHS, DividesBy);
15674- break;
15675- }
15676- case CmpInst::ICMP_UGT:
15677- case CmpInst::ICMP_SGT:
15678- RHS = SE.getAddExpr(RHS, One);
15679- RHS = getNextSCEVDividesByDivisor (RHS, DividesBy, SE);
15680- break;
15681- case CmpInst::ICMP_ULE:
15682- case CmpInst::ICMP_SLE:
15683- RHS = GetPreviousSCEVDividesByDivisor (RHS, DividesBy);
15684- break;
15685- case CmpInst::ICMP_UGE:
15686- case CmpInst::ICMP_SGE:
15687- RHS = getNextSCEVDividesByDivisor (RHS, DividesBy, SE);
15688- break;
15689- default:
15690- break;
15679+ case CmpInst::ICMP_ULT:
15680+ if (RHS->getType()->isPointerTy())
15681+ return;
15682+ RHS = SE.getUMaxExpr(RHS, One);
15683+ [[fallthrough]];
15684+ case CmpInst::ICMP_SLT: {
15685+ RHS = SE.getMinusSCEV(RHS, One);
15686+ RHS = getPreviousSCEVDivisibleByDivisor (RHS, DividesBy, SE );
15687+ break;
15688+ }
15689+ case CmpInst::ICMP_UGT:
15690+ case CmpInst::ICMP_SGT:
15691+ RHS = SE.getAddExpr(RHS, One);
15692+ RHS = getNextSCEVDivisibleByDivisor (RHS, DividesBy, SE);
15693+ break;
15694+ case CmpInst::ICMP_ULE:
15695+ case CmpInst::ICMP_SLE:
15696+ RHS = getPreviousSCEVDivisibleByDivisor (RHS, DividesBy, SE );
15697+ break;
15698+ case CmpInst::ICMP_UGE:
15699+ case CmpInst::ICMP_SGE:
15700+ RHS = getNextSCEVDivisibleByDivisor (RHS, DividesBy, SE);
15701+ break;
15702+ default:
15703+ break;
1569115704 }
1569215705
1569315706 SmallVector<const SCEV *, 16> Worklist(1, LHS);
@@ -15738,7 +15751,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1573815751 case CmpInst::ICMP_NE:
1573915752 if (match(RHS, m_scev_Zero())) {
1574015753 const SCEV *OneAlignedUp =
15741- getNextSCEVDividesByDivisor (One, DividesBy, SE);
15754+ getNextSCEVDivisibleByDivisor (One, DividesBy, SE);
1574215755 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
1574315756 } else {
1574415757 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
@@ -15833,8 +15846,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1583315846
1583415847 // Now apply the information from the collected conditions to
1583515848 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15836- // earliest conditions is processed first. This ensures the SCEVs with the
15849+ // earliest conditions is processed first, except guards with divisibility
15850+ // information, which are moved to the back. This ensures the SCEVs with the
1583715851 // shortest dependency chains are constructed first.
15852+ SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15853+ GuardsToProcess;
1583815854 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
1583915855 SmallVector<Value *, 8> Worklist;
1584015856 SmallPtrSet<Value *, 8> Visited;
@@ -15849,7 +15865,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1584915865 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
1585015866 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
1585115867 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15852- CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15868+ // If LHS is a constant, apply information to the other expression.
15869+ if (isa<SCEVConstant>(LHS)) {
15870+ std::swap(LHS, RHS);
15871+ Predicate = CmpInst::getSwappedPredicate(Predicate);
15872+ }
15873+ GuardsToProcess.emplace_back(Predicate, LHS, RHS);
1585315874 continue;
1585415875 }
1585515876
@@ -15862,6 +15883,30 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1586215883 }
1586315884 }
1586415885
15886+ // Process divisibility guards in reverse order to populate DivInfo early.
15887+ DenseMap<const SCEV *, APInt> Multiples;
15888+ DenseMap<const SCEV *, const SCEV *> DivInfo;
15889+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15890+ if (!isDivisibilityGuard(LHS, RHS, SE))
15891+ continue;
15892+ collectDivisibilityInformation(Predicate, LHS, RHS, DivInfo, Multiples, SE);
15893+ }
15894+
15895+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15896+ CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivInfo);
15897+
15898+ // Apply divisibility information last. This ensures it is applied to the
15899+ // outermost expression after other rewrites for the given value.
15900+ for (const auto &[K, V] : Multiples) {
15901+ const SCEV *DivisorSCEV = SE.getConstant(V);
15902+ Guards.RewriteMap[K] =
15903+ SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15904+ Guards.rewrite(K), DivisorSCEV, SE),
15905+ DivisorSCEV),
15906+ DivisorSCEV);
15907+ ExprsToRewrite.push_back(K);
15908+ }
15909+
1586515910 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
1586615911 // the replacement expressions are contained in the ranges of the replaced
1586715912 // expressions.
@@ -15965,7 +16010,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
1596516010 if (LHS > RHS)
1596616011 std::swap(LHS, RHS);
1596716012 if (NotEqual.contains({LHS, RHS})) {
15968- const SCEV *OneAlignedUp = getNextSCEVDividesByDivisor (
16013+ const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor (
1596916014 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
1597016015 return SE.getUMaxExpr(OneAlignedUp, S);
1597116016 }
0 commit comments