Skip to content

Commit 1c49935

Browse files
authored
[PatternMatch] Add m_[Shift]OrSelf matchers. (#152924)
Address the comment https://github.com/llvm/llvm-project/pull/147414/files#r2228612726. As they are usually used to match integer packing patterns, it is enough to handle constant shamts.
1 parent 628c0e3 commit 1c49935

File tree

4 files changed

+91
-33
lines changed

4 files changed

+91
-33
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,45 @@ inline BinaryOp_match<LHS, RHS, Instruction::AShr> m_AShr(const LHS &L,
13271327
return BinaryOp_match<LHS, RHS, Instruction::AShr>(L, R);
13281328
}
13291329

1330+
template <typename LHS_t, unsigned Opcode> struct ShiftLike_match {
1331+
LHS_t L;
1332+
uint64_t &R;
1333+
1334+
ShiftLike_match(const LHS_t &LHS, uint64_t &RHS) : L(LHS), R(RHS) {}
1335+
1336+
template <typename OpTy> bool match(OpTy *V) const {
1337+
if (auto *Op = dyn_cast<BinaryOperator>(V)) {
1338+
if (Op->getOpcode() == Opcode)
1339+
return m_ConstantInt(R).match(Op->getOperand(1)) &&
1340+
L.match(Op->getOperand(0));
1341+
}
1342+
// Interpreted as shiftop V, 0
1343+
R = 0;
1344+
return L.match(V);
1345+
}
1346+
};
1347+
1348+
/// Matches shl L, ConstShAmt or L itself.
1349+
template <typename LHS>
1350+
inline ShiftLike_match<LHS, Instruction::Shl> m_ShlOrSelf(const LHS &L,
1351+
uint64_t &R) {
1352+
return ShiftLike_match<LHS, Instruction::Shl>(L, R);
1353+
}
1354+
1355+
/// Matches lshr L, ConstShAmt or L itself.
1356+
template <typename LHS>
1357+
inline ShiftLike_match<LHS, Instruction::LShr> m_LShrOrSelf(const LHS &L,
1358+
uint64_t &R) {
1359+
return ShiftLike_match<LHS, Instruction::LShr>(L, R);
1360+
}
1361+
1362+
/// Matches ashr L, ConstShAmt or L itself.
1363+
template <typename LHS>
1364+
inline ShiftLike_match<LHS, Instruction::AShr> m_AShrOrSelf(const LHS &L,
1365+
uint64_t &R) {
1366+
return ShiftLike_match<LHS, Instruction::AShr>(L, R);
1367+
}
1368+
13301369
template <typename LHS_t, typename RHS_t, unsigned Opcode,
13311370
unsigned WrapFlags = 0, bool Commutable = false>
13321371
struct OverflowingBinaryOp_match {

llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ struct LoadOps {
617617
LoadInst *RootInsert = nullptr;
618618
bool FoundRoot = false;
619619
uint64_t LoadSize = 0;
620-
const APInt *Shift = nullptr;
620+
uint64_t Shift = 0;
621621
Type *ZextType;
622622
AAMDNodes AATags;
623623
};
@@ -627,17 +627,15 @@ struct LoadOps {
627627
// (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
628628
static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
629629
AliasAnalysis &AA) {
630-
const APInt *ShAmt2 = nullptr;
630+
uint64_t ShAmt2;
631631
Value *X;
632632
Instruction *L1, *L2;
633633

634634
// Go to the last node with loads.
635-
if (match(V, m_OneUse(m_c_Or(
636-
m_Value(X),
637-
m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
638-
m_APInt(ShAmt2)))))) ||
639-
match(V, m_OneUse(m_Or(m_Value(X),
640-
m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
635+
if (match(V, m_OneUse(m_c_Or(m_Value(X),
636+
m_OneUse(m_ShlOrSelf(m_OneUse(m_ZExt(m_OneUse(
637+
m_Instruction(L2)))),
638+
ShAmt2)))))) {
641639
if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
642640
// Avoid Partial chain merge.
643641
return false;
@@ -646,11 +644,10 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
646644

647645
// Check if the pattern has loads
648646
LoadInst *LI1 = LOps.Root;
649-
const APInt *ShAmt1 = LOps.Shift;
647+
uint64_t ShAmt1 = LOps.Shift;
650648
if (LOps.FoundRoot == false &&
651-
(match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
652-
match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
653-
m_APInt(ShAmt1)))))) {
649+
match(X, m_OneUse(m_ShlOrSelf(
650+
m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))), ShAmt1)))) {
654651
LI1 = dyn_cast<LoadInst>(L1);
655652
}
656653
LoadInst *LI2 = dyn_cast<LoadInst>(L2);
@@ -726,13 +723,6 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
726723
if (IsBigEndian)
727724
std::swap(ShAmt1, ShAmt2);
728725

729-
// Find Shifts values.
730-
uint64_t Shift1 = 0, Shift2 = 0;
731-
if (ShAmt1)
732-
Shift1 = ShAmt1->getZExtValue();
733-
if (ShAmt2)
734-
Shift2 = ShAmt2->getZExtValue();
735-
736726
// First load is always LI1. This is where we put the new load.
737727
// Use the merged load size available from LI1 for forward loads.
738728
if (LOps.FoundRoot) {
@@ -747,7 +737,7 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
747737
uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
748738
uint64_t PrevSize =
749739
DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
750-
if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
740+
if ((ShAmt2 - ShAmt1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
751741
return false;
752742

753743
// Update LOps
@@ -824,7 +814,7 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
824814
// Check if shift needed. We need to shift with the amount of load1
825815
// shift if not zero.
826816
if (LOps.Shift)
827-
NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
817+
NewOp = Builder.CreateShl(NewOp, LOps.Shift);
828818
I.replaceAllUsesWith(NewOp);
829819

830820
return true;
@@ -860,11 +850,9 @@ static std::optional<PartStore> matchPartStore(Instruction &I,
860850
return std::nullopt;
861851

862852
uint64_t ValWidth = StoredTy->getPrimitiveSizeInBits();
863-
uint64_t ValOffset = 0;
853+
uint64_t ValOffset;
864854
Value *Val;
865-
if (!match(StoredVal, m_CombineOr(m_Trunc(m_LShr(m_Value(Val),
866-
m_ConstantInt(ValOffset))),
867-
m_Trunc(m_Value(Val)))))
855+
if (!match(StoredVal, m_Trunc(m_LShrOrSelf(m_Value(Val), ValOffset))))
868856
return std::nullopt;
869857

870858
Value *Ptr = Store->getPointerOperand();

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3605,16 +3605,11 @@ static bool matchSubIntegerPackFromVector(Value *V, Value *&Vec,
36053605
int64_t &VecOffset,
36063606
SmallBitVector &Mask,
36073607
const DataLayout &DL) {
3608-
static const auto m_ConstShlOrSelf = [](const auto &Base, uint64_t &ShlAmt) {
3609-
ShlAmt = 0;
3610-
return m_CombineOr(m_Shl(Base, m_ConstantInt(ShlAmt)), Base);
3611-
};
3612-
36133608
// First try to match extractelement -> zext -> shl
36143609
uint64_t VecIdx, ShlAmt;
3615-
if (match(V, m_ConstShlOrSelf(m_ZExtOrSelf(m_ExtractElt(
3616-
m_Value(Vec), m_ConstantInt(VecIdx))),
3617-
ShlAmt))) {
3610+
if (match(V, m_ShlOrSelf(m_ZExtOrSelf(m_ExtractElt(m_Value(Vec),
3611+
m_ConstantInt(VecIdx))),
3612+
ShlAmt))) {
36183613
auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
36193614
if (!VecTy)
36203615
return false;

llvm/unittests/IR/PatternMatch.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2621,4 +2621,40 @@ TEST_F(PatternMatchTest, PtrAdd) {
26212621
EXPECT_FALSE(match(OtherGEP, m_PtrAdd(m_Value(A), m_Value(B))));
26222622
}
26232623

2624+
TEST_F(PatternMatchTest, ShiftOrSelf) {
2625+
Type *I64Ty = Type::getInt64Ty(Ctx);
2626+
Constant *LHS = ConstantInt::get(I64Ty, 7);
2627+
Constant *ShAmt = ConstantInt::get(I64Ty, 16);
2628+
Value *Shl = IRB.CreateShl(LHS, ShAmt);
2629+
Value *LShr = IRB.CreateLShr(LHS, ShAmt);
2630+
Value *AShr = IRB.CreateAShr(LHS, ShAmt);
2631+
Value *Add = IRB.CreateAdd(LHS, LHS);
2632+
2633+
uint64_t ShAmtC;
2634+
Value *A;
2635+
EXPECT_TRUE(match(Shl, m_ShlOrSelf(m_Value(A), ShAmtC)));
2636+
EXPECT_EQ(A, LHS);
2637+
EXPECT_EQ(ShAmtC, 16U);
2638+
2639+
EXPECT_TRUE(match(Add, m_ShlOrSelf(m_Value(A), ShAmtC)));
2640+
EXPECT_EQ(A, Add);
2641+
EXPECT_EQ(ShAmtC, 0U);
2642+
2643+
EXPECT_TRUE(match(LShr, m_LShrOrSelf(m_Value(A), ShAmtC)));
2644+
EXPECT_EQ(A, LHS);
2645+
EXPECT_EQ(ShAmtC, 16U);
2646+
2647+
EXPECT_TRUE(match(Add, m_LShrOrSelf(m_Value(A), ShAmtC)));
2648+
EXPECT_EQ(A, Add);
2649+
EXPECT_EQ(ShAmtC, 0U);
2650+
2651+
EXPECT_TRUE(match(AShr, m_AShrOrSelf(m_Value(A), ShAmtC)));
2652+
EXPECT_EQ(A, LHS);
2653+
EXPECT_EQ(ShAmtC, 16U);
2654+
2655+
EXPECT_TRUE(match(Add, m_AShrOrSelf(m_Value(A), ShAmtC)));
2656+
EXPECT_EQ(A, Add);
2657+
EXPECT_EQ(ShAmtC, 0U);
2658+
}
2659+
26242660
} // anonymous namespace.

0 commit comments

Comments
 (0)