@@ -831,6 +831,49 @@ static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
831831 });
832832}
833833
834+ /// Performs a number of common optimizations on the passed \p Ops. If the
835+ /// whole expression reduces down to a single operand, it will be returned.
836+ ///
837+ /// The following optimizations are performed:
838+ /// * Fold constants using the \p Fold function.
839+ /// * Remove identity constants satisfying \p IsIdentity.
840+ /// * If a constant satisfies \p IsAbsorber, return it.
841+ /// * Sort operands by complexity.
842+ template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
843+ static const SCEV *
844+ constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
845+ SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
846+ IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
847+ const SCEVConstant *Folded = nullptr;
848+ for (unsigned Idx = 0; Idx < Ops.size();) {
849+ const SCEV *Op = Ops[Idx];
850+ if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
851+ if (!Folded)
852+ Folded = C;
853+ else
854+ Folded = cast<SCEVConstant>(
855+ SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
856+ Ops.erase(Ops.begin() + Idx);
857+ continue;
858+ }
859+ ++Idx;
860+ }
861+
862+ if (Ops.empty()) {
863+ assert(Folded && "Must have folded value");
864+ return Folded;
865+ }
866+
867+ if (Folded && IsAbsorber(Folded->getAPInt()))
868+ return Folded;
869+
870+ GroupByComplexity(Ops, &LI, DT);
871+ if (Folded && !IsIdentity(Folded->getAPInt()))
872+ Ops.insert(Ops.begin(), Folded);
873+
874+ return Ops.size() == 1 ? Ops[0] : nullptr;
875+ }
876+
834877//===----------------------------------------------------------------------===//
835878// Simple SCEV method implementations
836879//===----------------------------------------------------------------------===//
@@ -2504,30 +2547,15 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
25042547 assert(NumPtrs <= 1 && "add has at most one pointer operand");
25052548#endif
25062549
2507- // Sort by complexity, this groups all similar expression types together.
2508- GroupByComplexity(Ops, &LI, DT);
2509-
2510- // If there are any constants, fold them together.
2511- unsigned Idx = 0;
2512- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2513- ++Idx;
2514- assert(Idx < Ops.size());
2515- while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2516- // We found two constants, fold them together!
2517- Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2518- if (Ops.size() == 2) return Ops[0];
2519- Ops.erase(Ops.begin()+1); // Erase the folded element
2520- LHSC = cast<SCEVConstant>(Ops[0]);
2521- }
2522-
2523- // If we are left with a constant zero being added, strip it off.
2524- if (LHSC->getValue()->isZero()) {
2525- Ops.erase(Ops.begin());
2526- --Idx;
2527- }
2550+ const SCEV *Folded = constantFoldAndGroupOps(
2551+ *this, LI, DT, Ops,
2552+ [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2553+ [](const APInt &C) { return C.isZero(); }, // identity
2554+ [](const APInt &C) { return false; }); // absorber
2555+ if (Folded)
2556+ return Folded;
25282557
2529- if (Ops.size() == 1) return Ops[0];
2530- }
2558+ unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
25312559
25322560 // Delay expensive flag strengthening until necessary.
25332561 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
@@ -3097,35 +3125,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
30973125 "SCEVMulExpr operand types don't match!");
30983126#endif
30993127
3100- // Sort by complexity, this groups all similar expression types together.
3101- GroupByComplexity(Ops, &LI, DT);
3102-
3103- // If there are any constants, fold them together.
3104- unsigned Idx = 0;
3105- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3106- ++Idx;
3107- assert(Idx < Ops.size());
3108- while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3109- // We found two constants, fold them together!
3110- Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3111- if (Ops.size() == 2) return Ops[0];
3112- Ops.erase(Ops.begin()+1); // Erase the folded element
3113- LHSC = cast<SCEVConstant>(Ops[0]);
3114- }
3115-
3116- // If we have a multiply of zero, it will always be zero.
3117- if (LHSC->getValue()->isZero())
3118- return LHSC;
3119-
3120- // If we are left with a constant one being multiplied, strip it off.
3121- if (LHSC->getValue()->isOne()) {
3122- Ops.erase(Ops.begin());
3123- --Idx;
3124- }
3125-
3126- if (Ops.size() == 1)
3127- return Ops[0];
3128- }
3128+ const SCEV *Folded = constantFoldAndGroupOps(
3129+ *this, LI, DT, Ops,
3130+ [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3131+ [](const APInt &C) { return C.isOne(); }, // identity
3132+ [](const APInt &C) { return C.isZero(); }); // absorber
3133+ if (Folded)
3134+ return Folded;
31293135
31303136 // Delay expensive flag strengthening until necessary.
31313137 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
@@ -3202,6 +3208,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
32023208 }
32033209
32043210 // Skip over the add expression until we get to a multiply.
3211+ unsigned Idx = 0;
32053212 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
32063213 ++Idx;
32073214
@@ -3829,61 +3836,46 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
38293836 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
38303837 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
38313838
3832- // Sort by complexity, this groups all similar expression types together.
3833- GroupByComplexity(Ops, &LI, DT);
3839+ const SCEV *Folded = constantFoldAndGroupOps(
3840+ *this, LI, DT, Ops,
3841+ [&](const APInt &C1, const APInt &C2) {
3842+ switch (Kind) {
3843+ case scSMaxExpr:
3844+ return APIntOps::smax(C1, C2);
3845+ case scSMinExpr:
3846+ return APIntOps::smin(C1, C2);
3847+ case scUMaxExpr:
3848+ return APIntOps::umax(C1, C2);
3849+ case scUMinExpr:
3850+ return APIntOps::umin(C1, C2);
3851+ default:
3852+ llvm_unreachable("Unknown SCEV min/max opcode");
3853+ }
3854+ },
3855+ [&](const APInt &C) {
3856+ // identity
3857+ if (IsMax)
3858+ return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3859+ else
3860+ return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3861+ },
3862+ [&](const APInt &C) {
3863+ // absorber
3864+ if (IsMax)
3865+ return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3866+ else
3867+ return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3868+ });
3869+ if (Folded)
3870+ return Folded;
38343871
38353872 // Check if we have created the same expression before.
38363873 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
38373874 return S;
38383875 }
38393876
3840- // If there are any constants, fold them together.
3841- unsigned Idx = 0;
3842- if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3843- ++Idx;
3844- assert(Idx < Ops.size());
3845- auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3846- switch (Kind) {
3847- case scSMaxExpr:
3848- return APIntOps::smax(LHS, RHS);
3849- case scSMinExpr:
3850- return APIntOps::smin(LHS, RHS);
3851- case scUMaxExpr:
3852- return APIntOps::umax(LHS, RHS);
3853- case scUMinExpr:
3854- return APIntOps::umin(LHS, RHS);
3855- default:
3856- llvm_unreachable("Unknown SCEV min/max opcode");
3857- }
3858- };
3859-
3860- while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3861- // We found two constants, fold them together!
3862- ConstantInt *Fold = ConstantInt::get(
3863- getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3864- Ops[0] = getConstant(Fold);
3865- Ops.erase(Ops.begin()+1); // Erase the folded element
3866- if (Ops.size() == 1) return Ops[0];
3867- LHSC = cast<SCEVConstant>(Ops[0]);
3868- }
3869-
3870- bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3871- bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3872-
3873- if (IsMax ? IsMinV : IsMaxV) {
3874- // If we are left with a constant minimum(/maximum)-int, strip it off.
3875- Ops.erase(Ops.begin());
3876- --Idx;
3877- } else if (IsMax ? IsMaxV : IsMinV) {
3878- // If we have a max(/min) with a constant maximum(/minimum)-int,
3879- // it will always be the extremum.
3880- return LHSC;
3881- }
3882-
3883- if (Ops.size() == 1) return Ops[0];
3884- }
3885-
38863877 // Find the first operation of the same kind
3878+ unsigned Idx = 0;
38873879 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
38883880 ++Idx;
38893881
0 commit comments