Skip to content

Commit ee329a0

Browse files
authored
[InstCombine] Fold integer unpack/repack patterns through ZExt (llvm#3662)
2 parents c8e4efe + 81d048a commit ee329a0

File tree

9 files changed

+386
-27
lines changed

9 files changed

+386
-27
lines changed

llvm/include/llvm/IR/IRBuilder.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1537,10 +1537,14 @@ class IRBuilderBase {
15371537
return Accum;
15381538
}
15391539

1540-
Value *CreateOr(Value *LHS, Value *RHS, const Twine &Name = "") {
1540+
Value *CreateOr(Value *LHS, Value *RHS, const Twine &Name = "",
1541+
bool IsDisjoint = false) {
15411542
if (auto *V = Folder.FoldBinOp(Instruction::Or, LHS, RHS))
15421543
return V;
1543-
return Insert(BinaryOperator::CreateOr(LHS, RHS), Name);
1544+
return Insert(
1545+
IsDisjoint ? BinaryOperator::CreateDisjoint(Instruction::Or, LHS, RHS)
1546+
: BinaryOperator::CreateOr(LHS, RHS),
1547+
Name);
15441548
}
15451549

15461550
Value *CreateOr(Value *LHS, const APInt &RHS, const Twine &Name = "") {

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -935,13 +935,14 @@ struct bind_const_intval_ty {
935935

936936
bind_const_intval_ty(uint64_t &V) : VR(V) {}
937937

938-
template <typename ITy> bool match(ITy *V) {
939-
if (const auto *CV = dyn_cast<ConstantInt>(V))
940-
if (CV->getValue().ule(UINT64_MAX)) {
941-
VR = CV->getZExtValue();
942-
return true;
943-
}
944-
return false;
938+
template <typename ITy> bool match(ITy *V) const {
939+
const APInt *ConstInt;
940+
if (!apint_match(ConstInt, /*AllowPoison=*/false).match(V))
941+
return false;
942+
if (ConstInt->getActiveBits() > 64)
943+
return false;
944+
VR = ConstInt->getZExtValue();
945+
return true;
945946
}
946947
};
947948

llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,9 +1676,9 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
16761676
return m_CombineOr(m_LShr(V, S), m_AShr(V, S));
16771677
};
16781678

1679-
const APInt *Qn = nullptr;
1680-
if (Value * T; match(Exp, m_Shr(m_Value(T), m_APInt(Qn)))) {
1681-
Op.Frac = Qn->getZExtValue();
1679+
uint64_t Qn = 0;
1680+
if (Value *T; match(Exp, m_Shr(m_Value(T), m_ConstantInt(Qn)))) {
1681+
Op.Frac = Qn;
16821682
Exp = T;
16831683
} else {
16841684
Op.Frac = 0;
@@ -1688,9 +1688,9 @@ auto HvxIdioms::matchFxpMul(Instruction &In) const -> std::optional<FxpOp> {
16881688
return std::nullopt;
16891689

16901690
// Check if there is rounding added.
1691-
const APInt *C = nullptr;
1692-
if (Value * T; Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_APInt(C)))) {
1693-
uint64_t CV = C->getZExtValue();
1691+
uint64_t CV;
1692+
if (Value *T;
1693+
Op.Frac > 0 && match(Exp, m_Add(m_Value(T), m_ConstantInt(CV)))) {
16941694
if (CV != 0 && !isPowerOf2_64(CV))
16951695
return std::nullopt;
16961696
if (CV != 0)

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3549,6 +3549,109 @@ static Value *foldOrOfInversions(BinaryOperator &I,
35493549
return nullptr;
35503550
}
35513551

3552+
/// Match \p V as "lshr -> mask -> zext -> shl".
3553+
///
3554+
/// \p Int is the underlying integer being extracted from.
3555+
/// \p Mask is a bitmask identifying which bits of the integer are being
3556+
/// extracted. \p Offset identifies which bit of the result \p V corresponds to
3557+
/// the least significant bit of \p Int
3558+
static bool matchZExtedSubInteger(Value *V, Value *&Int, APInt &Mask,
3559+
uint64_t &Offset, bool &IsShlNUW,
3560+
bool &IsShlNSW) {
3561+
Value *ShlOp0;
3562+
uint64_t ShlAmt = 0;
3563+
if (!match(V, m_OneUse(m_Shl(m_Value(ShlOp0), m_ConstantInt(ShlAmt)))))
3564+
return false;
3565+
3566+
IsShlNUW = cast<BinaryOperator>(V)->hasNoUnsignedWrap();
3567+
IsShlNSW = cast<BinaryOperator>(V)->hasNoSignedWrap();
3568+
3569+
Value *ZExtOp0;
3570+
if (!match(ShlOp0, m_OneUse(m_ZExt(m_Value(ZExtOp0)))))
3571+
return false;
3572+
3573+
Value *MaskedOp0;
3574+
const APInt *ShiftedMaskConst = nullptr;
3575+
if (!match(ZExtOp0, m_CombineOr(m_OneUse(m_And(m_Value(MaskedOp0),
3576+
m_APInt(ShiftedMaskConst))),
3577+
m_Value(MaskedOp0))))
3578+
return false;
3579+
3580+
uint64_t LShrAmt = 0;
3581+
if (!match(MaskedOp0,
3582+
m_CombineOr(m_OneUse(m_LShr(m_Value(Int), m_ConstantInt(LShrAmt))),
3583+
m_Value(Int))))
3584+
return false;
3585+
3586+
if (LShrAmt > ShlAmt)
3587+
return false;
3588+
Offset = ShlAmt - LShrAmt;
3589+
3590+
Mask = ShiftedMaskConst ? ShiftedMaskConst->shl(LShrAmt)
3591+
: APInt::getBitsSetFrom(
3592+
Int->getType()->getScalarSizeInBits(), LShrAmt);
3593+
3594+
return true;
3595+
}
3596+
3597+
/// Try to fold the join of two scalar integers whose bits are unpacked and
3598+
/// zexted from the same source integer.
3599+
static Value *foldIntegerRepackThroughZExt(Value *Lhs, Value *Rhs,
3600+
InstCombiner::BuilderTy &Builder) {
3601+
3602+
Value *LhsInt, *RhsInt;
3603+
APInt LhsMask, RhsMask;
3604+
uint64_t LhsOffset, RhsOffset;
3605+
bool IsLhsShlNUW, IsLhsShlNSW, IsRhsShlNUW, IsRhsShlNSW;
3606+
if (!matchZExtedSubInteger(Lhs, LhsInt, LhsMask, LhsOffset, IsLhsShlNUW,
3607+
IsLhsShlNSW))
3608+
return nullptr;
3609+
if (!matchZExtedSubInteger(Rhs, RhsInt, RhsMask, RhsOffset, IsRhsShlNUW,
3610+
IsRhsShlNSW))
3611+
return nullptr;
3612+
if (LhsInt != RhsInt || LhsOffset != RhsOffset)
3613+
return nullptr;
3614+
3615+
APInt Mask = LhsMask | RhsMask;
3616+
3617+
Type *DestTy = Lhs->getType();
3618+
Value *Res = Builder.CreateShl(
3619+
Builder.CreateZExt(
3620+
Builder.CreateAnd(LhsInt, Mask, LhsInt->getName() + ".mask"), DestTy,
3621+
LhsInt->getName() + ".zext"),
3622+
ConstantInt::get(DestTy, LhsOffset), "", IsLhsShlNUW && IsRhsShlNUW,
3623+
IsLhsShlNSW && IsRhsShlNSW);
3624+
Res->takeName(Lhs);
3625+
return Res;
3626+
}
3627+
3628+
/// Try to fold the two operands \p LHS and \p RHS of a disjoint 'or' into a
/// single value. Currently the only fold attempted is the integer repack
/// through zext.
///
/// \returns the folded value, or nullptr if no fold applies.
Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS) {
  // A null result from the helper already means "no fold".
  return foldIntegerRepackThroughZExt(LHS, RHS, Builder);
}
3634+
3635+
/// Look through a single-use disjoint 'or' on either side of \p LHS | \p RHS
/// and try to fold the other side with one of its two operands. When a fold
/// succeeds, the folded value is joined with the remaining operand by a new
/// disjoint 'or'.
///
/// \returns the rebuilt value, or nullptr if nothing could be folded.
Value *InstCombinerImpl::reassociateDisjointOr(Value *LHS, Value *RHS) {
  // Fold (A, B); on success rejoin the result with the left-over operand.
  auto FoldThenRejoin = [this](Value *A, Value *B, Value *Rest) -> Value * {
    Value *Folded = foldDisjointOr(A, B);
    if (!Folded)
      return nullptr;
    return Builder.CreateOr(Folded, Rest, "", /*IsDisjoint=*/true);
  };

  Value *Op0, *Op1;

  // LHS | (Op0 | Op1): pair LHS with Op0 first, then with Op1.
  if (match(RHS, m_OneUse(m_DisjointOr(m_Value(Op0), m_Value(Op1))))) {
    if (Value *Joined = FoldThenRejoin(LHS, Op0, Op1))
      return Joined;
    if (Value *Joined = FoldThenRejoin(LHS, Op1, Op0))
      return Joined;
  }

  // (Op0 | Op1) | RHS: pair Op0 with RHS first, then Op1 with RHS.
  if (match(LHS, m_OneUse(m_DisjointOr(m_Value(Op0), m_Value(Op1))))) {
    if (Value *Joined = FoldThenRejoin(Op0, RHS, Op1))
      return Joined;
    if (Value *Joined = FoldThenRejoin(Op1, RHS, Op0))
      return Joined;
  }

  return nullptr;
}
3654+
35523655
/// Match \p V as "shufflevector -> bitcast" or "extractelement -> zext -> shl"
35533656
/// patterns, which extract vector elements and pack them in the same relative
35543657
/// positions.
@@ -3781,6 +3884,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
37813884
foldAddLikeCommutative(I.getOperand(1), I.getOperand(0),
37823885
/*NSW=*/true, /*NUW=*/true))
37833886
return R;
3887+
3888+
if (Value *Res = foldDisjointOr(I.getOperand(0), I.getOperand(1)))
3889+
return replaceInstUsesWith(I, Res);
3890+
3891+
if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1)))
3892+
return replaceInstUsesWith(I, Res);
37843893
}
37853894

37863895
Value *X, *Y;

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,11 +1114,10 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
11141114
case Instruction::Shl: {
11151115
// We can promote shl(x, cst) if we can promote x. Since shl overwrites the
11161116
// upper bits we can reduce BitsToClear by the shift amount.
1117-
const APInt *Amt;
1118-
if (match(I->getOperand(1), m_APInt(Amt))) {
1117+
uint64_t ShiftAmt;
1118+
if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
11191119
if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
11201120
return false;
1121-
uint64_t ShiftAmt = Amt->getZExtValue();
11221121
BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0;
11231122
return true;
11241123
}
@@ -1127,11 +1126,11 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
11271126
case Instruction::LShr: {
11281127
// We can promote lshr(x, cst) if we can promote x. This requires the
11291128
// ultimate 'and' to clear out the high zero bits we're clearing out though.
1130-
const APInt *Amt;
1131-
if (match(I->getOperand(1), m_APInt(Amt))) {
1129+
uint64_t ShiftAmt;
1130+
if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
11321131
if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
11331132
return false;
1134-
BitsToClear += Amt->getZExtValue();
1133+
BitsToClear += ShiftAmt;
11351134
if (BitsToClear > V->getType()->getScalarSizeInBits())
11361135
BitsToClear = V->getType()->getScalarSizeInBits();
11371136
return true;

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,11 +1499,11 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
14991499
// trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] < 0 --> ShOp < 0
15001500
// trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] > -1 --> ShOp > -1
15011501
Value *ShOp;
1502-
const APInt *ShAmtC;
1502+
uint64_t ShAmt;
15031503
bool TrueIfSigned;
15041504
if (isSignBitCheck(Pred, C, TrueIfSigned) &&
1505-
match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) &&
1506-
DstBits == SrcBits - ShAmtC->getZExtValue()) {
1505+
match(X, m_Shr(m_Value(ShOp), m_ConstantInt(ShAmt))) &&
1506+
DstBits == SrcBits - ShAmt) {
15071507
return TrueIfSigned ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp,
15081508
ConstantInt::getNullValue(SrcTy))
15091509
: new ICmpInst(ICmpInst::ICMP_SGT, ShOp,

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
432432
Value *reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y, Instruction &I,
433433
bool IsAnd, bool RHSIsLogical);
434434

435+
Value *foldDisjointOr(Value *LHS, Value *RHS);
436+
437+
Value *reassociateDisjointOr(Value *LHS, Value *RHS);
438+
435439
Instruction *
436440
canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i);
437441

llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,10 +317,10 @@ static void annotateNonNullAndDereferenceable(CallInst *CI, ArrayRef<unsigned> A
317317
annotateDereferenceableBytes(CI, ArgNos, LenC->getZExtValue());
318318
} else if (isKnownNonZero(Size, DL)) {
319319
annotateNonNullNoUndefBasedOnAccess(CI, ArgNos);
320-
const APInt *X, *Y;
320+
uint64_t X, Y;
321321
uint64_t DerefMin = 1;
322-
if (match(Size, m_Select(m_Value(), m_APInt(X), m_APInt(Y)))) {
323-
DerefMin = std::min(X->getZExtValue(), Y->getZExtValue());
322+
if (match(Size, m_Select(m_Value(), m_ConstantInt(X), m_ConstantInt(Y)))) {
323+
DerefMin = std::min(X, Y);
324324
annotateDereferenceableBytes(CI, ArgNos, DerefMin);
325325
}
326326
}

0 commit comments

Comments
 (0)