Skip to content

Commit 128dcb1

Browse files
committed
[PatternMatch] Allow m_ConstantInt to match integer splats (llvm#153692)
When matching integers, `m_ConstantInt` is a convenient alternative to `m_APInt` for matching unsigned 64-bit integers, allowing one to simplify ```cpp const APInt *IntC; if (match(V, m_APInt(IntC))) { if (IntC->ule(UINT64_MAX)) { uint64_t Int = IntC->getZExtValue(); // ... } } ``` to ```cpp uint64_t Int; if (match(V, m_ConstantInt(Int))) { // ... } ``` However, this simplification is only true if `V` is a scalar type. Specifically, `m_APInt` also matches integer splats, but `m_ConstantInt` does not. This patch ensures that the matching behaviour of `m_ConstantInt` parallels that of `m_APInt`, and also incorporates it in some obvious places.
1 parent 1b5ca05 commit 128dcb1

File tree

5 files changed

+25
-25
lines changed

5 files changed

+25
-25
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -935,13 +935,14 @@ struct bind_const_intval_ty {
935935

936936
bind_const_intval_ty(uint64_t &V) : VR(V) {}
937937

938-
template <typename ITy> bool match(ITy *V) {
939-
if (const auto *CV = dyn_cast<ConstantInt>(V))
940-
if (CV->getValue().ule(UINT64_MAX)) {
941-
VR = CV->getZExtValue();
942-
return true;
943-
}
944-
return false;
938+
template <typename ITy> bool match(ITy *V) const {
939+
const APInt *ConstInt;
940+
if (!apint_match(ConstInt, /*AllowPoison=*/false).match(V))
941+
return false;
942+
if (ConstInt->getActiveBits() > 64)
943+
return false;
944+
VR = ConstInt->getZExtValue();
945+
return true;
945946
}
946947
};
947948

llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,9 +1676,9 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
16761676
return m_CombineOr(m_LShr(V, S), m_AShr(V, S));
16771677
};
16781678

1679-
const APInt *Qn = nullptr;
1680-
if (Value * T; match(Exp, m_Shr(m_Value(T), m_APInt(Qn)))) {
1681-
Op.Frac = Qn->getZExtValue();
1679+
uint64_t Qn = 0;
1680+
if (Value *T; match(Exp, m_Shr(m_Value(T), m_ConstantInt(Qn)))) {
1681+
Op.Frac = Qn;
16821682
Exp = T;
16831683
} else {
16841684
Op.Frac = 0;
@@ -1688,9 +1688,9 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
16881688
return std::nullopt;
16891689

16901690
// Check if there is rounding added.
1691-
const APInt *C = nullptr;
1692-
if (Value * T; Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_APInt(C)))) {
1693-
uint64_t CV = C->getZExtValue();
1691+
uint64_t CV;
1692+
if (Value *T;
1693+
Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_ConstantInt(CV)))) {
16941694
if (CV != 0 && !isPowerOf2_64(CV))
16951695
return std::nullopt;
16961696
if (CV != 0)

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,11 +1114,10 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
11141114
case Instruction::Shl: {
11151115
// We can promote shl(x, cst) if we can promote x. Since shl overwrites the
11161116
// upper bits we can reduce BitsToClear by the shift amount.
1117-
const APInt *Amt;
1118-
if (match(I->getOperand(1), m_APInt(Amt))) {
1117+
uint64_t ShiftAmt;
1118+
if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
11191119
if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
11201120
return false;
1121-
uint64_t ShiftAmt = Amt->getZExtValue();
11221121
BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0;
11231122
return true;
11241123
}
@@ -1127,11 +1126,11 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
11271126
case Instruction::LShr: {
11281127
// We can promote lshr(x, cst) if we can promote x. This requires the
11291128
// ultimate 'and' to clear out the high zero bits we're clearing out though.
1130-
const APInt *Amt;
1131-
if (match(I->getOperand(1), m_APInt(Amt))) {
1129+
uint64_t ShiftAmt;
1130+
if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
11321131
if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
11331132
return false;
1134-
BitsToClear += Amt->getZExtValue();
1133+
BitsToClear += ShiftAmt;
11351134
if (BitsToClear > V->getType()->getScalarSizeInBits())
11361135
BitsToClear = V->getType()->getScalarSizeInBits();
11371136
return true;

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,11 +1499,11 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
14991499
// trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] < 0 --> ShOp < 0
15001500
// trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] > -1 --> ShOp > -1
15011501
Value *ShOp;
1502-
const APInt *ShAmtC;
1502+
uint64_t ShAmt;
15031503
bool TrueIfSigned;
15041504
if (isSignBitCheck(Pred, C, TrueIfSigned) &&
1505-
match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) &&
1506-
DstBits == SrcBits - ShAmtC->getZExtValue()) {
1505+
match(X, m_Shr(m_Value(ShOp), m_ConstantInt(ShAmt))) &&
1506+
DstBits == SrcBits - ShAmt) {
15071507
return TrueIfSigned ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp,
15081508
ConstantInt::getNullValue(SrcTy))
15091509
: new ICmpInst(ICmpInst::ICMP_SGT, ShOp,

llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,10 +317,10 @@ static void annotateNonNullAndDereferenceable(CallInst *CI, ArrayRef<unsigned> A
317317
annotateDereferenceableBytes(CI, ArgNos, LenC->getZExtValue());
318318
} else if (isKnownNonZero(Size, DL)) {
319319
annotateNonNullNoUndefBasedOnAccess(CI, ArgNos);
320-
const APInt *X, *Y;
320+
uint64_t X, Y;
321321
uint64_t DerefMin = 1;
322-
if (match(Size, m_Select(m_Value(), m_APInt(X), m_APInt(Y)))) {
323-
DerefMin = std::min(X->getZExtValue(), Y->getZExtValue());
322+
if (match(Size, m_Select(m_Value(), m_ConstantInt(X), m_ConstantInt(Y)))) {
323+
DerefMin = std::min(X, Y);
324324
annotateDereferenceableBytes(CI, ArgNos, DerefMin);
325325
}
326326
}

0 commit comments

Comments
 (0)