@@ -253,6 +253,59 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
253253// SCEV class definitions
254254//===----------------------------------------------------------------------===//
255255
256+ class SCEVDropFlags : public SCEVRewriteVisitor<SCEVDropFlags> {
257+ using Base = SCEVRewriteVisitor<SCEVDropFlags>;
258+
259+ public:
260+ SCEVDropFlags(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
261+
262+ static SCEVUse rewrite(SCEVUse Scev, ScalarEvolution &SE) {
263+ SCEVDropFlags Rewriter(SE);
264+ return Rewriter.visit(Scev);
265+ }
266+
267+ SCEVUse visitAddExpr(const SCEVAddExpr *Expr) {
268+ SmallVector<const SCEV *, 2> Operands;
269+ bool Changed = false;
270+ for (const auto Op : Expr->operands()) {
271+ Operands.push_back(visit(Op));
272+ Changed |= Op != Operands.back();
273+ }
274+ return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
275+ }
276+
277+ SCEVUse visitMulExpr(const SCEVMulExpr *Expr) {
278+ SmallVector<SCEVUse, 2> Operands;
279+ bool Changed = false;
280+ for (const auto Op : Expr->operands()) {
281+ Operands.push_back(visit(Op));
282+ Changed |= Op != Operands.back();
283+ }
284+ return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
285+ }
286+ };
287+
288+ const SCEV *SCEVUse::computeCanonical(ScalarEvolution &SE) const {
289+ return SCEVDropFlags::rewrite(*this, SE);
290+ }
291+
292+ bool SCEVUse::computeIsCanonical() const {
293+ if (!getRawPointer() ||
294+ DenseMapInfo<SCEVUse>::getEmptyKey().getRawPointer() == getRawPointer() ||
295+ DenseMapInfo<SCEVUse>::getTombstoneKey().getRawPointer() ==
296+ getRawPointer() ||
297+ isa<SCEVCouldNotCompute>(this))
298+ return true;
299+ return !SCEVExprContains(*this, [](SCEVUse U) { return U.getFlags() != 0; });
300+ }
301+
302+ bool SCEVUse::operator==(const SCEVUse &RHS) const {
303+ assert(isCanonical() && RHS.isCanonical());
304+ return getPointer() == RHS.getPointer();
305+ }
306+
307+ bool SCEVUse::operator==(const SCEV *RHS) const { return getPointer() == RHS; }
308+
256309#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
257310LLVM_DUMP_METHOD void SCEVUse::dump() const {
258311 print(dbgs());
@@ -677,9 +730,10 @@ static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
677730static std::optional<int>
678731CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
679732 const LoopInfo *const LI, SCEVUse LHS, SCEVUse RHS,
680- DominatorTree &DT, unsigned Depth = 0) {
733+ DominatorTree &DT, ScalarEvolution &SE,
734+ unsigned Depth = 0) {
681735 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
682- if (LHS == RHS)
736+ if (LHS.getCanonical(SE) == RHS.getCanonical(SE) )
683737 return 0;
684738
685739 // Primarily, sort the SCEVs by their getSCEVType().
@@ -769,7 +823,7 @@ CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
769823 return (int)LNumOps - (int)RNumOps;
770824
771825 for (unsigned i = 0; i != LNumOps; ++i) {
772- auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT,
826+ auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT, SE,
773827 Depth + 1);
774828 if (X != 0)
775829 return X;
@@ -794,14 +848,14 @@ CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
794848/// this to depend on where the addresses of various SCEV objects happened to
795849/// land in memory.
796850static void GroupByComplexity(SmallVectorImpl<SCEVUse> &Ops, LoopInfo *LI,
797- DominatorTree &DT) {
851+ DominatorTree &DT, ScalarEvolution &SE ) {
798852 if (Ops.size() < 2) return; // Noop
799853
800854 EquivalenceClasses<SCEVUse> EqCacheSCEV;
801855
802856 // Whether LHS has provably less complexity than RHS.
803857 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
804- auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT);
858+ auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT, SE );
805859 return Complexity && *Complexity < 0;
806860 };
807861 if (Ops.size() == 2) {
@@ -882,7 +936,7 @@ constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
882936 if (Folded && IsAbsorber(Folded->getAPInt()))
883937 return Folded;
884938
885- GroupByComplexity(Ops, &LI, DT);
939+ GroupByComplexity(Ops, &LI, DT, SE );
886940 if (Folded && !IsIdentity(Folded->getAPInt()))
887941 Ops.insert(Ops.begin(), Folded);
888942
@@ -2586,7 +2640,9 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
25862640 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
25872641 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
25882642 Add->setNoWrapFlags(ComputeFlags(Ops));
2589- return S;
2643+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
2644+ int UseFlags = IsCanonical ? 0 : 1;
2645+ return {S, UseFlags};
25902646 }
25912647
25922648 // Okay, check to see if the same value occurs in the operand list more than
@@ -2595,7 +2651,8 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
25952651 Type *Ty = Ops[0]->getType();
25962652 bool FoundMatch = false;
25972653 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2598- if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2654+ if (Ops[i].getCanonical(*this) ==
2655+ Ops[i + 1].getCanonical(*this)) { // X + Y + Y --> X + Y*2
25992656 // Scan ahead to count how many equal operands there are.
26002657 unsigned Count = 2;
26012658 while (i+Count != e && Ops[i+Count] == Ops[i])
@@ -2817,7 +2874,7 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
28172874 if (isa<SCEVConstant>(MulOpSCEV))
28182875 continue;
28192876 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2820- if (MulOpSCEV == Ops[AddOp]) {
2877+ if (MulOpSCEV.getCanonical(*this) == Ops[AddOp].getCanonical(*this) ) {
28212878 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
28222879 SCEVUse InnerMul = Mul->getOperand(MulOp == 0);
28232880 if (Mul->getNumOperands() != 2) {
@@ -3018,7 +3075,9 @@ SCEVUse ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
30183075 registerUser(S, Ops);
30193076 }
30203077 S->setNoWrapFlags(Flags);
3021- return S;
3078+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3079+ int UseFlags = IsCanonical ? 0 : 1;
3080+ return {S, UseFlags};
30223081}
30233082
30243083SCEVUse ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
@@ -3063,7 +3122,9 @@ SCEVUse ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
30633122 registerUser(S, Ops);
30643123 }
30653124 S->setNoWrapFlags(Flags);
3066- return S;
3125+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3126+ int UseFlags = IsCanonical ? 0 : 1;
3127+ return {S, UseFlags};
30673128}
30683129
30693130static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
@@ -3165,7 +3226,9 @@ SCEVUse ScalarEvolution::getMulExpr(SmallVectorImpl<SCEVUse> &Ops,
31653226 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
31663227 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
31673228 Mul->setNoWrapFlags(ComputeFlags(Ops));
3168- return S;
3229+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3230+ int UseFlags = IsCanonical ? 0 : 1;
3231+ return {S, UseFlags};
31693232 }
31703233
31713234 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
@@ -13647,6 +13710,19 @@ ScalarEvolution::~ScalarEvolution() {
1364713710 HasRecMap.clear();
1364813711 BackedgeTakenCounts.clear();
1364913712 PredicatedBackedgeTakenCounts.clear();
13713+ UnsignedRanges.clear();
13714+ SignedRanges.clear();
13715+
13716+ BECountUsers.clear();
13717+ SCEVUsers.clear();
13718+ FoldCache.clear();
13719+ FoldCacheUser.clear();
13720+ ValuesAtScopes.clear();
13721+ ValuesAtScopesUsers.clear();
13722+ LoopDispositions.clear();
13723+
13724+ BlockDispositions.clear();
13725+ ConstantMultipleCache.clear();
1365013726
1365113727 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
1365213728 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
0 commit comments