@@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
57065706 return true;
57075707
57085708 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5709- if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5710- !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5709+ if (Expr1 != Expr2 &&
5710+ !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5711+ !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
57115712 return false;
57125713 return true;
57135714 };
@@ -14857,7 +14858,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
1485714858 bool addOverflowAssumption(const SCEVPredicate *P) {
1485814859 if (!NewPreds) {
1485914860 // Check if we've already made this assumption.
14860- return Pred && Pred->implies(P);
14861+ return Pred && Pred->implies(P, SE );
1486114862 }
1486214863 NewPreds->push_back(P);
1486314864 return true;
@@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
1493814939 assert(LHS != RHS && "LHS and RHS are the same SCEV");
1493914940}
1494014941
14941- bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14942+ bool SCEVComparePredicate::implies(const SCEVPredicate *N,
14943+ ScalarEvolution &SE) const {
1494214944 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
1494314945
1494414946 if (!Op)
@@ -14968,10 +14970,52 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
1496814970
1496914971const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
1497014972
14971- bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14973+ bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
14974+ ScalarEvolution &SE) const {
1497214975 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14976+ if (!Op)
14977+ return false;
14978+
14979+ if (setFlags(Flags, Op->Flags) != Flags)
14980+ return false;
14981+
14982+ if (Op->AR == AR)
14983+ return true;
14984+
14985+ if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14986+ Flags != SCEVWrapPredicate::IncrementNUSW)
14987+ return false;
1497314988
14974- return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14989+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14990+ const SCEV *Step = AR->getStepRecurrence(SE);
14991+ const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14992+
14993+ // If both steps are positive, this implies N, if N's start and step are
14994+ // ULE/SLE (for NSUW/NSSW) than this'.
14995+ if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) {
14996+ const SCEV *OpStart = Op->AR->getStart();
14997+ const SCEV *Start = AR->getStart();
14998+ if (SE.getTypeSizeInBits(Step->getType()) >
14999+ SE.getTypeSizeInBits(OpStep->getType())) {
15000+ OpStep = SE.getZeroExtendExpr(OpStep, Step->getType());
15001+ } else {
15002+ Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType())
15003+ : SE.getNoopOrSignExtend(Step, OpStep->getType());
15004+ }
15005+ if (SE.getTypeSizeInBits(Start->getType()) >
15006+ SE.getTypeSizeInBits(OpStart->getType())) {
15007+ OpStart = IsNUW ? SE.getZeroExtendExpr(OpStart, Start->getType())
15008+ : SE.getSignExtendExpr(OpStart, Start->getType());
15009+ } else {
15010+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, OpStart->getType())
15011+ : SE.getNoopOrSignExtend(Start, OpStart->getType());
15012+ }
15013+
15014+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
15015+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
15016+ SE.isKnownPredicate(Pred, OpStart, Start);
15017+ }
15018+ return false;
1497515019}
1497615020
1497715021bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -15015,48 +15059,51 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
1501515059}
1501615060
1501715061/// Union predicates don't get cached so create a dummy set ID for it.
15018- SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
15019- : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15062+ SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15063+ ScalarEvolution &SE)
15064+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
1502015065 for (const auto *P : Preds)
15021- add(P);
15066+ add(P, SE );
1502215067}
1502315068
1502415069bool SCEVUnionPredicate::isAlwaysTrue() const {
1502515070 return all_of(Preds,
1502615071 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
1502715072}
1502815073
15029- bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
15074+ bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15075+ ScalarEvolution &SE) const {
1503015076 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15031- return all_of(Set->Preds,
15032- [this](const SCEVPredicate *I) { return this->implies(I); });
15077+ return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15078+ return this->implies(I, SE);
15079+ });
1503315080
1503415081 return any_of(Preds,
15035- [N](const SCEVPredicate *I) { return I->implies(N); });
15082+ [N, &SE ](const SCEVPredicate *I) { return I->implies(N, SE ); });
1503615083}
1503715084
1503815085void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
1503915086 for (const auto *Pred : Preds)
1504015087 Pred->print(OS, Depth);
1504115088}
1504215089
15043- void SCEVUnionPredicate::add(const SCEVPredicate *N) {
15090+ void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE ) {
1504415091 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
1504515092 for (const auto *Pred : Set->Preds)
15046- add(Pred);
15093+ add(Pred, SE );
1504715094 return;
1504815095 }
1504915096
1505015097 // Only add predicate if it is not already implied by this union predicate.
15051- if (!implies(N))
15098+ if (!implies(N, SE ))
1505215099 Preds.push_back(N);
1505315100}
1505415101
1505515102PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
1505615103 Loop &L)
1505715104 : SE(SE), L(L) {
1505815105 SmallVector<const SCEVPredicate*, 4> Empty;
15059- Preds = std::make_unique<SCEVUnionPredicate>(Empty);
15106+ Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE );
1506015107}
1506115108
1506215109void ScalarEvolution::registerUser(const SCEV *User,
@@ -15120,12 +15167,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
1512015167}
1512115168
1512215169void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15123- if (Preds->implies(&Pred))
15170+ if (Preds->implies(&Pred, SE ))
1512415171 return;
1512515172
1512615173 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
1512715174 NewPreds.push_back(&Pred);
15128- Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
15175+ Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE );
1512915176 updateGeneration();
1513015177}
1513115178
@@ -15192,9 +15239,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
1519215239
1519315240PredicatedScalarEvolution::PredicatedScalarEvolution(
1519415241 const PredicatedScalarEvolution &Init)
15195- : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15196- Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15197- Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15242+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15243+ Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15244+ SE)),
15245+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
1519815246 for (auto I : Init.FlagsMap)
1519915247 FlagsMap.insert(I);
1520015248}
0 commit comments