@@ -5706,8 +5706,9 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
57065706 return true;
57075707
57085708 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5709- if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5710- !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5709+ if (Expr1 != Expr2 &&
5710+ !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5711+ !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
57115712 return false;
57125713 return true;
57135714 };
@@ -14857,7 +14858,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
1485714858 bool addOverflowAssumption(const SCEVPredicate *P) {
1485814859 if (!NewPreds) {
1485914860 // Check if we've already made this assumption.
14860- return Pred && Pred->implies(P);
14861+ return Pred && Pred->implies(P, SE );
1486114862 }
1486214863 NewPreds->push_back(P);
1486314864 return true;
@@ -14938,7 +14939,8 @@ SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
1493814939 assert(LHS != RHS && "LHS and RHS are the same SCEV");
1493914940}
1494014941
14941- bool SCEVComparePredicate::implies(const SCEVPredicate *N) const {
14942+ bool SCEVComparePredicate::implies(const SCEVPredicate *N,
14943+ ScalarEvolution &SE) const {
1494214944 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
1494314945
1494414946 if (!Op)
@@ -14968,10 +14970,40 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
1496814970
1496914971const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
1497014972
14971- bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
14973+ bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
14974+ ScalarEvolution &SE) const {
1497214975 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14976+ if (!Op || setFlags(Flags, Op->Flags) != Flags)
14977+ return false;
14978+
14979+ if (Op->AR == AR)
14980+ return true;
14981+
14982+ if (Flags != SCEVWrapPredicate::IncrementNSSW &&
14983+ Flags != SCEVWrapPredicate::IncrementNUSW)
14984+ return false;
1497314985
14974- return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14986+ const SCEV *Step = AR->getStepRecurrence(SE);
14987+ const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
14988+ if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
14989+ return false;
14990+
14991+ // If both steps are positive, this implies N, if N's start and step are
14992+ // ULE/SLE (for NSUW/NSSW) than this'.
14993+ Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
14994+ Step = SE.getNoopOrZeroExtend(Step, WiderTy);
14995+ OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
14996+
14997+ bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
14998+ const SCEV *OpStart = Op->AR->getStart();
14999+ const SCEV *Start = AR->getStart();
15000+ OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15001+ : SE.getNoopOrSignExtend(OpStart, WiderTy);
15002+ Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15003+ : SE.getNoopOrSignExtend(Start, WiderTy);
15004+ CmpInst::Predicate Pred = IsNUW ? CmpInst::ICMP_ULE : CmpInst::ICMP_SLE;
15005+ return SE.isKnownPredicate(Pred, OpStep, Step) &&
15006+ SE.isKnownPredicate(Pred, OpStart, Start);
1497515007}
1497615008
1497715009bool SCEVWrapPredicate::isAlwaysTrue() const {
@@ -15015,48 +15047,51 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
1501515047}
1501615048
1501715049/// Union predicates don't get cached so create a dummy set ID for it.
15018- SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
15019- : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15050+ SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15051+ ScalarEvolution &SE)
15052+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
1502015053 for (const auto *P : Preds)
15021- add(P);
15054+ add(P, SE );
1502215055}
1502315056
1502415057bool SCEVUnionPredicate::isAlwaysTrue() const {
1502515058 return all_of(Preds,
1502615059 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
1502715060}
1502815061
15029- bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
15062+ bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15063+ ScalarEvolution &SE) const {
1503015064 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15031- return all_of(Set->Preds,
15032- [this](const SCEVPredicate *I) { return this->implies(I); });
15065+ return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15066+ return this->implies(I, SE);
15067+ });
1503315068
1503415069 return any_of(Preds,
15035- [N](const SCEVPredicate *I) { return I->implies(N); });
15070+ [N, &SE ](const SCEVPredicate *I) { return I->implies(N, SE ); });
1503615071}
1503715072
1503815073void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
1503915074 for (const auto *Pred : Preds)
1504015075 Pred->print(OS, Depth);
1504115076}
1504215077
15043- void SCEVUnionPredicate::add(const SCEVPredicate *N) {
15078+ void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE ) {
1504415079 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
1504515080 for (const auto *Pred : Set->Preds)
15046- add(Pred);
15081+ add(Pred, SE );
1504715082 return;
1504815083 }
1504915084
1505015085 // Only add predicate if it is not already implied by this union predicate.
15051- if (!implies(N))
15086+ if (!implies(N, SE ))
1505215087 Preds.push_back(N);
1505315088}
1505415089
1505515090PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
1505615091 Loop &L)
1505715092 : SE(SE), L(L) {
1505815093 SmallVector<const SCEVPredicate*, 4> Empty;
15059- Preds = std::make_unique<SCEVUnionPredicate>(Empty);
15094+ Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE );
1506015095}
1506115096
1506215097void ScalarEvolution::registerUser(const SCEV *User,
@@ -15120,12 +15155,12 @@ unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
1512015155}
1512115156
1512215157void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15123- if (Preds->implies(&Pred))
15158+ if (Preds->implies(&Pred, SE ))
1512415159 return;
1512515160
1512615161 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
1512715162 NewPreds.push_back(&Pred);
15128- Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
15163+ Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE );
1512915164 updateGeneration();
1513015165}
1513115166
@@ -15192,9 +15227,10 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
1519215227
1519315228PredicatedScalarEvolution::PredicatedScalarEvolution(
1519415229 const PredicatedScalarEvolution &Init)
15195- : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15196- Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
15197- Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15230+ : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15231+ Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15232+ SE)),
15233+ Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
1519815234 for (auto I : Init.FlagsMap)
1519915235 FlagsMap.insert(I);
1520015236}
0 commit comments