Skip to content

Commit 98a6b6e

Browse files
authored
[SCEV] Improve code using SCEVPatternMatch (NFC) (#163946)
1 parent f3cd6b7 commit 98a6b6e

File tree

2 files changed

+56
-69
lines changed

2 files changed

+56
-69
lines changed

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,18 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
256256
return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
257257
}
258258

259+
template <typename Op0_t, typename Op1_t>
260+
inline SCEVBinaryExpr_match<SCEVSMaxExpr, Op0_t, Op1_t>
261+
m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1) {
262+
return m_scev_Binary<SCEVSMaxExpr>(Op0, Op1);
263+
}
264+
265+
template <typename Op0_t, typename Op1_t>
266+
inline SCEVBinaryExpr_match<SCEVMinMaxExpr, Op0_t, Op1_t>
267+
m_scev_MinMax(const Op0_t &Op0, const Op1_t &Op1) {
268+
return m_scev_Binary<SCEVMinMaxExpr>(Op0, Op1);
269+
}
270+
259271
/// Match unsigned remainder pattern.
260272
/// Matches patterns generated by getURemExpr.
261273
template <typename Op0_t, typename Op1_t> struct SCEVURem_match {

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 44 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,19 +1840,19 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
18401840
// = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
18411841
// = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
18421842
//
1843-
if (SM->getNumOperands() == 2)
1844-
if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1845-
if (MulLHS->getAPInt().isPowerOf2())
1846-
if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1847-
int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1848-
MulLHS->getAPInt().logBase2();
1849-
Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1850-
return getMulExpr(
1851-
getZeroExtendExpr(MulLHS, Ty),
1852-
getZeroExtendExpr(
1853-
getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1854-
SCEV::FlagNUW, Depth + 1);
1855-
}
1843+
const APInt *C;
1844+
const SCEV *TruncRHS;
1845+
if (match(SM,
1846+
m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1847+
C->isPowerOf2()) {
1848+
int NewTruncBits =
1849+
getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1850+
Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1851+
return getMulExpr(
1852+
getZeroExtendExpr(SM->getOperand(0), Ty),
1853+
getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1854+
SCEV::FlagNUW, Depth + 1);
1855+
}
18561856
}
18571857

18581858
// zext(umin(x, y)) -> umin(zext(x), zext(y))
@@ -3144,20 +3144,19 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
31443144
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
31453145
if (Ops.size() == 2) {
31463146
// C1*(C2+V) -> C1*C2 + C1*V
3147-
if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3148-
// If any of Add's ops are Adds or Muls with a constant, apply this
3149-
// transformation as well.
3150-
//
3151-
// TODO: There are some cases where this transformation is not
3152-
// profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3153-
// this transformation should be narrowed down.
3154-
if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3155-
const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3156-
SCEV::FlagAnyWrap, Depth + 1);
3157-
const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3158-
SCEV::FlagAnyWrap, Depth + 1);
3159-
return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3160-
}
3147+
// If any of Add's ops are Adds or Muls with a constant, apply this
3148+
// transformation as well.
3149+
//
3150+
// TODO: There are some cases where this transformation is not
3151+
// profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3152+
// this transformation should be narrowed down.
3153+
const SCEV *Op0, *Op1;
3154+
if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3155+
containsConstantInAddMulChain(Ops[1])) {
3156+
const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3157+
const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3158+
return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3159+
}
31613160

31623161
if (Ops[0]->isAllOnesValue()) {
31633162
// If we have a mul by -1 of an add, try distributing the -1 among the
@@ -3578,20 +3577,12 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
35783577
}
35793578

35803579
// ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3581-
if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
3582-
AE && AE->getNumOperands() == 2) {
3583-
if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
3584-
const APInt &NegC = VC->getAPInt();
3585-
if (NegC.isNegative() && !NegC.isMinSignedValue()) {
3586-
const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
3587-
if (MME && MME->getNumOperands() == 2 &&
3588-
isa<SCEVConstant>(MME->getOperand(0)) &&
3589-
cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
3590-
MME->getOperand(1) == RHS)
3591-
return getZero(LHS->getType());
3592-
}
3593-
}
3594-
}
3580+
const APInt *NegC, *C;
3581+
if (match(LHS,
3582+
m_scev_Add(m_scev_APInt(NegC),
3583+
m_scev_SMax(m_scev_APInt(C), m_scev_Specific(RHS)))) &&
3584+
NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3585+
return getZero(LHS->getType());
35953586

35963587
// TODO: Generalize to handle any common factors.
35973588
// udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
@@ -10791,19 +10782,15 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) {
1079110782
}
1079210783

1079310784
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10794-
const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10795-
if (!Add || Add->getNumOperands() != 2)
10785+
const SCEV *Op0, *Op1;
10786+
if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
1079610787
return false;
10797-
if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10798-
ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10799-
LHS = Add->getOperand(1);
10800-
RHS = ME->getOperand(1);
10788+
if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10789+
LHS = Op1;
1080110790
return true;
1080210791
}
10803-
if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10804-
ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10805-
LHS = Add->getOperand(0);
10806-
RHS = ME->getOperand(1);
10792+
if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10793+
LHS = Op0;
1080710794
return true;
1080810795
}
1080910796
return false;
@@ -12166,13 +12153,10 @@ bool ScalarEvolution::isImpliedCondBalancedTypes(
1216612153
bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
1216712154
const SCEV *&L, const SCEV *&R,
1216812155
SCEV::NoWrapFlags &Flags) {
12169-
const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
12170-
if (!AE || AE->getNumOperands() != 2)
12156+
if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
1217112157
return false;
1217212158

12173-
L = AE->getOperand(0);
12174-
R = AE->getOperand(1);
12175-
Flags = AE->getNoWrapFlags();
12159+
Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
1217612160
return true;
1217712161
}
1217812162

@@ -15550,19 +15534,10 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1555015534
auto IsMinMaxSCEVWithNonNegativeConstant =
1555115535
[&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
1555215536
const SCEV *&RHS) {
15553-
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15554-
if (MinMax->getNumOperands() != 2)
15555-
return false;
15556-
if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15557-
if (C->getAPInt().isNegative())
15558-
return false;
15559-
SCTy = MinMax->getSCEVType();
15560-
LHS = MinMax->getOperand(0);
15561-
RHS = MinMax->getOperand(1);
15562-
return true;
15563-
}
15564-
}
15565-
return false;
15537+
const APInt *C;
15538+
SCTy = Expr->getSCEVType();
15539+
return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) &&
15540+
match(LHS, m_scev_APInt(C)) && C->isNonNegative();
1556615541
};
1556715542

1556815543
// Return a new SCEV that modifies \p Expr to the closest number divides by

0 commit comments

Comments
 (0)