Skip to content

Commit a8d2568

Browse files
authored
[PatternMatch] Allow m_ConstantInt to match integer splats (#153692)
When matching integers, `m_ConstantInt` is a convenient alternative to `m_APInt` for matching unsigned 64-bit integers, allowing one to simplify ```cpp const APInt *IntC; if (match(V, m_APInt(IntC))) { if (IntC->ule(UINT64_MAX)) { uint64_t Int = IntC->getZExtValue(); // ... } } ``` to ```cpp uint64_t Int; if (match(V, m_ConstantInt(Int))) { // ... } ``` However, this simplification is only true if `V` is a scalar type. Specifically, `m_APInt` also matches integer splats, but `m_ConstantInt` does not. This patch ensures that the matching behaviour of `m_ConstantInt` parallels that of `m_APInt`, and also incorporates it in some obvious places.
1 parent af96ed6 commit a8d2568

File tree

6 files changed

+28
-28
lines changed

6 files changed

+28
-28
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,12 +1013,13 @@ struct bind_const_intval_ty {
10131013
bind_const_intval_ty(uint64_t &V) : VR(V) {}
10141014

10151015
template <typename ITy> bool match(ITy *V) const {
1016-
if (const auto *CV = dyn_cast<ConstantInt>(V))
1017-
if (CV->getValue().ule(UINT64_MAX)) {
1018-
VR = CV->getZExtValue();
1019-
return true;
1020-
}
1021-
return false;
1016+
const APInt *ConstInt;
1017+
if (!apint_match(ConstInt, /*AllowPoison=*/false).match(V))
1018+
return false;
1019+
if (ConstInt->getActiveBits() > 64)
1020+
return false;
1021+
VR = ConstInt->getZExtValue();
1022+
return true;
10221023
}
10231024
};
10241025

llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,9 +1677,9 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
16771677
return m_CombineOr(m_LShr(V, S), m_AShr(V, S));
16781678
};
16791679

1680-
const APInt *Qn = nullptr;
1681-
if (Value * T; match(Exp, m_Shr(m_Value(T), m_APInt(Qn)))) {
1682-
Op.Frac = Qn->getZExtValue();
1680+
uint64_t Qn = 0;
1681+
if (Value *T; match(Exp, m_Shr(m_Value(T), m_ConstantInt(Qn)))) {
1682+
Op.Frac = Qn;
16831683
Exp = T;
16841684
} else {
16851685
Op.Frac = 0;
@@ -1689,9 +1689,9 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
16891689
return std::nullopt;
16901690

16911691
// Check if there is rounding added.
1692-
const APInt *C = nullptr;
1693-
if (Value * T; Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_APInt(C)))) {
1694-
uint64_t CV = C->getZExtValue();
1692+
uint64_t CV;
1693+
if (Value *T;
1694+
Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_ConstantInt(CV)))) {
16951695
if (CV != 0 && !isPowerOf2_64(CV))
16961696
return std::nullopt;
16971697
if (CV != 0)

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,11 +1131,10 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
11311131
case Instruction::Shl: {
11321132
// We can promote shl(x, cst) if we can promote x. Since shl overwrites the
11331133
// upper bits we can reduce BitsToClear by the shift amount.
1134-
const APInt *Amt;
1135-
if (match(I->getOperand(1), m_APInt(Amt))) {
1134+
uint64_t ShiftAmt;
1135+
if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
11361136
if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
11371137
return false;
1138-
uint64_t ShiftAmt = Amt->getZExtValue();
11391138
BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0;
11401139
return true;
11411140
}
@@ -1144,11 +1143,11 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
11441143
case Instruction::LShr: {
11451144
// We can promote lshr(x, cst) if we can promote x. This requires the
11461145
// ultimate 'and' to clear out the high zero bits we're clearing out though.
1147-
const APInt *Amt;
1148-
if (match(I->getOperand(1), m_APInt(Amt))) {
1146+
uint64_t ShiftAmt;
1147+
if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
11491148
if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
11501149
return false;
1151-
BitsToClear += Amt->getZExtValue();
1150+
BitsToClear += ShiftAmt;
11521151
if (BitsToClear > V->getType()->getScalarSizeInBits())
11531152
BitsToClear = V->getType()->getScalarSizeInBits();
11541153
return true;

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,11 +1550,11 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
15501550
// trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] < 0 --> ShOp < 0
15511551
// trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] > -1 --> ShOp > -1
15521552
Value *ShOp;
1553-
const APInt *ShAmtC;
1553+
uint64_t ShAmt;
15541554
bool TrueIfSigned;
15551555
if (isSignBitCheck(Pred, C, TrueIfSigned) &&
1556-
match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) &&
1557-
DstBits == SrcBits - ShAmtC->getZExtValue()) {
1556+
match(X, m_Shr(m_Value(ShOp), m_ConstantInt(ShAmt))) &&
1557+
DstBits == SrcBits - ShAmt) {
15581558
return TrueIfSigned ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp,
15591559
ConstantInt::getNullValue(SrcTy))
15601560
: new ICmpInst(ICmpInst::ICMP_SGT, ShOp,

llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,10 @@ static void annotateNonNullAndDereferenceable(CallInst *CI, ArrayRef<unsigned> A
319319
annotateDereferenceableBytes(CI, ArgNos, LenC->getZExtValue());
320320
} else if (isKnownNonZero(Size, DL)) {
321321
annotateNonNullNoUndefBasedOnAccess(CI, ArgNos);
322-
const APInt *X, *Y;
322+
uint64_t X, Y;
323323
uint64_t DerefMin = 1;
324-
if (match(Size, m_Select(m_Value(), m_APInt(X), m_APInt(Y)))) {
325-
DerefMin = std::min(X->getZExtValue(), Y->getZExtValue());
324+
if (match(Size, m_Select(m_Value(), m_ConstantInt(X), m_ConstantInt(Y)))) {
325+
DerefMin = std::min(X, Y);
326326
annotateDereferenceableBytes(CI, ArgNos, DerefMin);
327327
}
328328
}

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,15 +1863,15 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) {
18631863
unsigned ExtCnt = 0;
18641864
bool ExtLane0 = false;
18651865
for (User *U : Ext->users()) {
1866-
const APInt *Idx;
1867-
if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
1866+
uint64_t Idx;
1867+
if (!match(U, m_ExtractElt(m_Value(), m_ConstantInt(Idx))))
18681868
return false;
18691869
if (cast<Instruction>(U)->use_empty())
18701870
continue;
18711871
ExtCnt += 1;
1872-
ExtLane0 |= Idx->isZero();
1872+
ExtLane0 |= !Idx;
18731873
VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
1874-
CostKind, Idx->getZExtValue(), U);
1874+
CostKind, Idx, U);
18751875
}
18761876

18771877
InstructionCost ScalarCost =

0 commit comments

Comments
 (0)