@@ -5725,8 +5725,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
57255725 return true;
57265726
57275727 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5728- if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5729- !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5728+ if (Expr1 != Expr2 &&
5729+ !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5730+ !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
57305731 return false;
57315732 return true;
57325733 };
@@ -14823,7 +14824,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
1482314824 bool addOverflowAssumption(const SCEVPredicate *P) {
1482414825 if (!NewPreds) {
1482514826 // Check if we've already made this assumption.
14826- return Pred && Pred->implies(P);
14827+ return Pred && Pred->implies(P, SE );
1482714828 }
1482814829 NewPreds->push_back(P);
1482914830 return true;
@@ -14904,7 +14905,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
1490414905 assert(LHS != RHS && "LHS and RHS are the same SCEV");
1490514906}
1490614907
14907- bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14908+ bool SCEVComparePredicate::implies(const SCEVPredicate *N,
14909+ ScalarEvolution &SE) const {
1490814910 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
1490914911
1491014912 if (!Op)
@@ -14934,10 +14936,52 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
1493414936
1493514937const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
1493614938
14937- bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14939+ bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
14940+ ScalarEvolution &SE) const {
1493814941 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14942+ if (!Op)
14943+ return false;
14944+
14945+ if (setFlags(Flags, Op->Flags) != Flags)
14946+ return false;
14947+
14948+ if (Op->AR == AR)
14949+ return true;
14950+
14951+ if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14952+ Flags != SCEVWrapPredicate::IncrementNUSW)
14953+ return false;
1493914954
14940- return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14955+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14956+ const SCEV *Step = AR->getStepRecurrence(SE);
14957+ const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14958+
14959+ // If both steps are positive, this implies N if N's start and step are both
14960+ // no greater than this' (compared ULE for NUSW, SLE for NSSW).
14961+ if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) {
14962+ const SCEV *OpStart = Op->AR->getStart();
14963+ const SCEV *Start = AR->getStart();
14964+ if (SE.getTypeSizeInBits(Step->getType()) >
14965+ SE.getTypeSizeInBits(OpStep->getType())) {
14966+ OpStep = SE.getZeroExtendExpr(OpStep, Step->getType());
14967+ } else {
14968+ Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType())
14969+ : SE.getNoopOrSignExtend(Step, OpStep->getType());
14970+ }
14971+ if (SE.getTypeSizeInBits(Start->getType()) >
14972+ SE.getTypeSizeInBits(OpStart->getType())) {
14973+ OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType())
14974+ : SE.getSignExtendExpr(OpStart, Start->getType());
14975+ } else {
14976+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType())
14977+ : SE.getNoopOrSignExtend(Start, OpStart->getType());
14978+ }
14979+
14980+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
14981+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
14982+ SE.isKnownPredicate(Pred, OpStart, Start);
14983+ }
14984+ return false;
1494114985}
1494214986
1494314987bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -14981,48 +15025,51 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
1498115025}
1498215026
1498315027/// Union predicates don't get cached, so create a dummy set ID for them.
14984- SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
14985- : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15028+ SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15029+ ScalarEvolution &SE)
15030+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
1498615031 for (const auto *P : Preds)
14987- add(P);
15032+ add(P, SE );
1498815033}
1498915034
1499015035bool SCEVUnionPredicate::isAlwaysTrue() const {
1499115036 return all_of(Preds,
1499215037 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
1499315038}
1499415039
14995- bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
15040+ bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15041+ ScalarEvolution &SE) const {
1499615042 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
14997- return all_of(Set->Preds,
14998- [this](const SCEVPredicate *I) { return this->implies(I); });
15043+ return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15044+ return this->implies(I, SE);
15045+ });
1499915046
1500015047 return any_of(Preds,
15001- [N](const SCEVPredicate *I) { return I->implies(N); });
15048+ [N, &SE ](const SCEVPredicate *I) { return I->implies(N, SE ); });
1500215049}
1500315050
1500415051void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
1500515052 for (const auto *Pred : Preds)
1500615053 Pred->print(OS, Depth);
1500715054}
1500815055
15009- void SCEVUnionPredicate::add(const SCEVPredicate *N) {
15056+ void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE ) {
1501015057 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
1501115058 for (const auto *Pred : Set->Preds)
15012- add(Pred);
15059+ add(Pred, SE );
1501315060 return;
1501415061 }
1501515062
1501615063 // Only add predicate if it is not already implied by this union predicate.
15017- if (!implies(N))
15064+ if (!implies(N, SE ))
1501815065 Preds.push_back(N);
1501915066}
1502015067
1502115068PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
1502215069 Loop &L)
1502315070 : SE(SE), L(L) {
1502415071 SmallVector<const SCEVPredicate*, 4> Empty;
15025- Preds = std::make_unique<SCEVUnionPredicate>(Empty);
15072+ Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE );
1502615073}
1502715074
1502815075void ScalarEvolution::registerUser(const SCEV *User,
@@ -15086,12 +15133,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
1508615133}
1508715134
1508815135void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15089- if (Preds->implies(&Pred))
15136+ if (Preds->implies(&Pred, SE ))
1509015137 return;
1509115138
1509215139 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
1509315140 NewPreds.push_back(&Pred);
15094- Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
15141+ Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE );
1509515142 updateGeneration();
1509615143}
1509715144
@@ -15158,9 +15205,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
1515815205
1515915206PredicatedScalarEvolution::PredicatedScalarEvolution(
1516015207 const PredicatedScalarEvolution &Init)
15161- : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15162- Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15163- Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15208+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15209+ Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15210+ SE)),
15211+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
1516415212 for (auto I : Init.FlagsMap)
1516515213 FlagsMap.insert(I);
1516615214}
0 commit comments