@@ -253,6 +253,59 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
253253// SCEV class definitions
254254//===----------------------------------------------------------------------===//
255255
256+ class SCEVDropFlags : public SCEVRewriteVisitor<SCEVDropFlags> {
257+ using Base = SCEVRewriteVisitor<SCEVDropFlags>;
258+
259+ public:
260+ SCEVDropFlags(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
261+
262+ static SCEVUse rewrite(SCEVUse Scev, ScalarEvolution &SE) {
263+ SCEVDropFlags Rewriter(SE);
264+ return Rewriter.visit(Scev);
265+ }
266+
267+ SCEVUse visitAddExpr(const SCEVAddExpr *Expr) {
268+ SmallVector<const SCEV *, 2> Operands;
269+ bool Changed = false;
270+ for (const auto Op : Expr->operands()) {
271+ Operands.push_back(visit(Op));
272+ Changed |= Op != Operands.back();
273+ }
274+ return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
275+ }
276+
277+ SCEVUse visitMulExpr(const SCEVMulExpr *Expr) {
278+ SmallVector<SCEVUse, 2> Operands;
279+ bool Changed = false;
280+ for (const auto Op : Expr->operands()) {
281+ Operands.push_back(visit(Op));
282+ Changed |= Op != Operands.back();
283+ }
284+ return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
285+ }
286+ };
287+
288+ const SCEV *SCEVUse::computeCanonical(ScalarEvolution &SE) const {
289+ return SCEVDropFlags::rewrite(*this, SE);
290+ }
291+
292+ bool SCEVUse::computeIsCanonical() const {
293+ if (!getRawPointer() ||
294+ DenseMapInfo<SCEVUse>::getEmptyKey().getRawPointer() == getRawPointer() ||
295+ DenseMapInfo<SCEVUse>::getTombstoneKey().getRawPointer() ==
296+ getRawPointer() ||
297+ isa<SCEVCouldNotCompute>(this))
298+ return true;
299+ return !SCEVExprContains(*this, [](SCEVUse U) { return U.getFlags() != 0; });
300+ }
301+
302+ bool SCEVUse::operator==(const SCEVUse &RHS) const {
303+ assert(isCanonical() && RHS.isCanonical());
304+ return getPointer() == RHS.getPointer();
305+ }
306+
307+ bool SCEVUse::operator==(const SCEV *RHS) const { return getPointer() == RHS; }
308+
256309#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
257310LLVM_DUMP_METHOD void SCEVUse::dump() const {
258311 print(dbgs());
@@ -677,9 +730,10 @@ static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
677730static std::optional<int>
678731CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
679732 const LoopInfo *const LI, SCEVUse LHS, SCEVUse RHS,
680- DominatorTree &DT, unsigned Depth = 0) {
733+ DominatorTree &DT, ScalarEvolution &SE,
734+ unsigned Depth = 0) {
681735 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
682- if (LHS == RHS)
736+ if (LHS.getCanonical(SE) == RHS.getCanonical(SE) )
683737 return 0;
684738
685739 // Primarily, sort the SCEVs by their getSCEVType().
@@ -769,7 +823,7 @@ CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
769823 return (int)LNumOps - (int)RNumOps;
770824
771825 for (unsigned i = 0; i != LNumOps; ++i) {
772- auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT,
826+ auto X = CompareSCEVComplexity(EqCacheSCEV, LI, LOps[i], ROps[i], DT, SE,
773827 Depth + 1);
774828 if (X != 0)
775829 return X;
@@ -794,14 +848,14 @@ CompareSCEVComplexity(EquivalenceClasses<SCEVUse> &EqCacheSCEV,
794848/// this to depend on where the addresses of various SCEV objects happened to
795849/// land in memory.
796850static void GroupByComplexity(SmallVectorImpl<SCEVUse> &Ops, LoopInfo *LI,
797- DominatorTree &DT) {
851+ DominatorTree &DT, ScalarEvolution &SE ) {
798852 if (Ops.size() < 2) return; // Noop
799853
800854 EquivalenceClasses<SCEVUse> EqCacheSCEV;
801855
802856 // Whether LHS has provably less complexity than RHS.
803857 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
804- auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT);
858+ auto Complexity = CompareSCEVComplexity(EqCacheSCEV, LI, LHS, RHS, DT, SE );
805859 return Complexity && *Complexity < 0;
806860 };
807861 if (Ops.size() == 2) {
@@ -882,7 +936,7 @@ constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
882936 if (Folded && IsAbsorber(Folded->getAPInt()))
883937 return Folded;
884938
885- GroupByComplexity(Ops, &LI, DT);
939+ GroupByComplexity(Ops, &LI, DT, SE );
886940 if (Folded && !IsIdentity(Folded->getAPInt()))
887941 Ops.insert(Ops.begin(), Folded);
888942
@@ -2585,7 +2639,9 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
25852639 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
25862640 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
25872641 Add->setNoWrapFlags(ComputeFlags(Ops));
2588- return S;
2642+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
2643+ int UseFlags = IsCanonical ? 0 : 1;
2644+ return {S, UseFlags};
25892645 }
25902646
25912647 // Okay, check to see if the same value occurs in the operand list more than
@@ -2594,7 +2650,8 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
25942650 Type *Ty = Ops[0]->getType();
25952651 bool FoundMatch = false;
25962652 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2597- if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2653+ if (Ops[i].getCanonical(*this) ==
2654+ Ops[i + 1].getCanonical(*this)) { // X + Y + Y --> X + Y*2
25982655 // Scan ahead to count how many equal operands there are.
25992656 unsigned Count = 2;
26002657 while (i+Count != e && Ops[i+Count] == Ops[i])
@@ -2816,7 +2873,7 @@ SCEVUse ScalarEvolution::getAddExpr(SmallVectorImpl<SCEVUse> &Ops,
28162873 if (isa<SCEVConstant>(MulOpSCEV))
28172874 continue;
28182875 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2819- if (MulOpSCEV == Ops[AddOp]) {
2876+ if (MulOpSCEV.getCanonical(*this) == Ops[AddOp].getCanonical(*this) ) {
28202877 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
28212878 SCEVUse InnerMul = Mul->getOperand(MulOp == 0);
28222879 if (Mul->getNumOperands() != 2) {
@@ -3017,7 +3074,9 @@ SCEVUse ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
30173074 registerUser(S, Ops);
30183075 }
30193076 S->setNoWrapFlags(Flags);
3020- return S;
3077+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3078+ int UseFlags = IsCanonical ? 0 : 1;
3079+ return {S, UseFlags};
30213080}
30223081
30233082SCEVUse ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
@@ -3062,7 +3121,9 @@ SCEVUse ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
30623121 registerUser(S, Ops);
30633122 }
30643123 S->setNoWrapFlags(Flags);
3065- return S;
3124+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3125+ int UseFlags = IsCanonical ? 0 : 1;
3126+ return {S, UseFlags};
30663127}
30673128
30683129static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
@@ -3164,7 +3225,9 @@ SCEVUse ScalarEvolution::getMulExpr(SmallVectorImpl<SCEVUse> &Ops,
31643225 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
31653226 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
31663227 Mul->setNoWrapFlags(ComputeFlags(Ops));
3167- return S;
3228+ bool IsCanonical = all_of(Ops, [](SCEVUse U) { return U.getFlags() == 0; });
3229+ int UseFlags = IsCanonical ? 0 : 1;
3230+ return {S, UseFlags};
31683231 }
31693232
31703233 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
@@ -13646,6 +13709,19 @@ ScalarEvolution::~ScalarEvolution() {
1364613709 HasRecMap.clear();
1364713710 BackedgeTakenCounts.clear();
1364813711 PredicatedBackedgeTakenCounts.clear();
13712+ UnsignedRanges.clear();
13713+ SignedRanges.clear();
13714+
13715+ BECountUsers.clear();
13716+ SCEVUsers.clear();
13717+ FoldCache.clear();
13718+ FoldCacheUser.clear();
13719+ ValuesAtScopes.clear();
13720+ ValuesAtScopesUsers.clear();
13721+ LoopDispositions.clear();
13722+
13723+ BlockDispositions.clear();
13724+ ConstantMultipleCache.clear();
1364913725
1365013726 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
1365113727 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
0 commit comments