@@ -1840,19 +1840,19 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
18401840 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
18411841 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
18421842 //
1843- if (SM->getNumOperands() == 2)
1844- if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1845- if (MulLHS->getAPInt().isPowerOf2())
1846- if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1847- int NewTruncBits = getTypeSizeInBits(TruncRHS->getType ()) -
1848- MulLHS->getAPInt().logBase2();
1849- Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits );
1850- return getMulExpr(
1851- getZeroExtendExpr(MulLHS, Ty),
1852- getZeroExtendExpr(
1853- getTruncateExpr(TruncRHS->getOperand() , NewTruncTy), Ty),
1854- SCEV::FlagNUW, Depth + 1);
1855- }
1843+ const APInt *C;
1844+ const SCEV *TruncRHS;
1845+ if (match(SM,
1846+ m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1847+ C->isPowerOf2 ()) {
1848+ int NewTruncBits =
1849+ getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2( );
1850+ Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1851+ return getMulExpr(
1852+ getZeroExtendExpr(SM->getOperand(0), Ty),
1853+ getZeroExtendExpr( getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1854+ SCEV::FlagNUW, Depth + 1);
1855+ }
18561856 }
18571857
18581858 // zext(umin(x, y)) -> umin(zext(x), zext(y))
@@ -3144,20 +3144,19 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
31443144 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
31453145 if (Ops.size() == 2) {
31463146 // C1*(C2+V) -> C1*C2 + C1*V
3147- if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3148- // If any of Add's ops are Adds or Muls with a constant, apply this
3149- // transformation as well.
3150- //
3151- // TODO: There are some cases where this transformation is not
3152- // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3153- // this transformation should be narrowed down.
3154- if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3155- const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3156- SCEV::FlagAnyWrap, Depth + 1);
3157- const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3158- SCEV::FlagAnyWrap, Depth + 1);
3159- return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3160- }
3147+ // If any of Add's ops are Adds or Muls with a constant, apply this
3148+ // transformation as well.
3149+ //
3150+ // TODO: There are some cases where this transformation is not
3151+ // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3152+ // this transformation should be narrowed down.
3153+ const SCEV *Op0, *Op1;
3154+ if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3155+ containsConstantInAddMulChain(Ops[1])) {
3156+ const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3157+ const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3158+ return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3159+ }
31613160
31623161 if (Ops[0]->isAllOnesValue()) {
31633162 // If we have a mul by -1 of an add, try distributing the -1 among the
@@ -3578,20 +3577,12 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
35783577 }
35793578
35803579 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3581- if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
3582- AE && AE->getNumOperands() == 2) {
3583- if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
3584- const APInt &NegC = VC->getAPInt();
3585- if (NegC.isNegative() && !NegC.isMinSignedValue()) {
3586- const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
3587- if (MME && MME->getNumOperands() == 2 &&
3588- isa<SCEVConstant>(MME->getOperand(0)) &&
3589- cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
3590- MME->getOperand(1) == RHS)
3591- return getZero(LHS->getType());
3592- }
3593- }
3594- }
3580+ const APInt *NegC, *C;
3581+ if (match(LHS,
3582+ m_scev_Add(m_scev_APInt(NegC),
3583+ m_scev_SMax(m_scev_APInt(C), m_scev_Specific(RHS)))) &&
3584+ NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3585+ return getZero(LHS->getType());
35953586
35963587 // TODO: Generalize to handle any common factors.
35973588 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
@@ -10791,19 +10782,15 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) {
1079110782}
1079210783
1079310784static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10794- const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S) ;
10795- if (!Add || Add->getNumOperands() != 2 )
10785+ const SCEV *Op0, *Op1 ;
10786+ if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) )
1079610787 return false;
10797- if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10798- ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10799- LHS = Add->getOperand(1);
10800- RHS = ME->getOperand(1);
10788+ if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10789+ LHS = Op1;
1080110790 return true;
1080210791 }
10803- if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10804- ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10805- LHS = Add->getOperand(0);
10806- RHS = ME->getOperand(1);
10792+ if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10793+ LHS = Op0;
1080710794 return true;
1080810795 }
1080910796 return false;
@@ -12166,13 +12153,10 @@ bool ScalarEvolution::isImpliedCondBalancedTypes(
1216612153bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
1216712154 const SCEV *&L, const SCEV *&R,
1216812155 SCEV::NoWrapFlags &Flags) {
12169- const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
12170- if (!AE || AE->getNumOperands() != 2)
12156+ if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
1217112157 return false;
1217212158
12173- L = AE->getOperand(0);
12174- R = AE->getOperand(1);
12175- Flags = AE->getNoWrapFlags();
12159+ Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
1217612160 return true;
1217712161}
1217812162
@@ -15550,19 +15534,10 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1555015534 auto IsMinMaxSCEVWithNonNegativeConstant =
1555115535 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
1555215536 const SCEV *&RHS) {
15553- if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15554- if (MinMax->getNumOperands() != 2)
15555- return false;
15556- if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15557- if (C->getAPInt().isNegative())
15558- return false;
15559- SCTy = MinMax->getSCEVType();
15560- LHS = MinMax->getOperand(0);
15561- RHS = MinMax->getOperand(1);
15562- return true;
15563- }
15564- }
15565- return false;
15537+ const APInt *C;
15538+ SCTy = Expr->getSCEVType();
15539+ return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) &&
15540+ match(LHS, m_scev_APInt(C)) && C->isNonNegative();
1556615541 };
1556715542
1556815543 // Return a new SCEV that modifies \p Expr to the closest number divides by
0 commit comments