[PatternMatch] Add m_[Shift]OrSelf matchers. #152924

Merged 1 commit on Aug 11, 2025.

llvm/include/llvm/IR/PatternMatch.h (+39, -0)

@@ -1327,6 +1327,45 @@ inline BinaryOp_match<LHS, RHS, Instruction::AShr> m_AShr(const LHS &L,
   return BinaryOp_match<LHS, RHS, Instruction::AShr>(L, R);
 }
 
+template <typename LHS_t, unsigned Opcode> struct ShiftLike_match {
+  LHS_t L;
+  uint64_t &R;

Inline review comment on uint64_t &R above (Contributor):

  Sorry I'm late to the party.

  Since this PR got reverted for not being NFC: perhaps it caused changes
  because m_APInt matches splat values, but m_ConstantInt does not. So making
  R a const APInt *& rather than a uint64_t & might ensure this PR doesn't
  change any existing behaviour?
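
A minimal sketch of the difference the comment is pointing at (illustrative,
not part of the PR): m_APInt also matches constant vector splats, while the
uint64_t-binding m_ConstantInt overload only matches a scalar ConstantInt, so
swapping one for the other can change which IR a pattern accepts.

  // Assume Value *V points at:  %r = shl <2 x i64> %x, splat (i64 3)
  const APInt *C;
  uint64_t CV;
  match(V, m_Shl(m_Value(), m_APInt(C)));        // true: binds the splat value 3
  match(V, m_Shl(m_Value(), m_ConstantInt(CV))); // false: the shift amount is a
                                                 // vector splat, not a ConstantInt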

+  ShiftLike_match(const LHS_t &LHS, uint64_t &RHS) : L(LHS), R(RHS) {}
+
+  template <typename OpTy> bool match(OpTy *V) const {
+    if (auto *Op = dyn_cast<BinaryOperator>(V)) {
+      if (Op->getOpcode() == Opcode)
+        return m_ConstantInt(R).match(Op->getOperand(1)) &&
+               L.match(Op->getOperand(0));
+    }
+    // Interpreted as shiftop V, 0
+    R = 0;
+    return L.match(V);
+  }
+};
+
+/// Matches shl L, ConstShAmt or L itself.
+template <typename LHS>
+inline ShiftLike_match<LHS, Instruction::Shl> m_ShlOrSelf(const LHS &L,
+                                                          uint64_t &R) {
+  return ShiftLike_match<LHS, Instruction::Shl>(L, R);
+}
+
+/// Matches lshr L, ConstShAmt or L itself.
+template <typename LHS>
+inline ShiftLike_match<LHS, Instruction::LShr> m_LShrOrSelf(const LHS &L,
+                                                            uint64_t &R) {
+  return ShiftLike_match<LHS, Instruction::LShr>(L, R);
+}
+
+/// Matches ashr L, ConstShAmt or L itself.
+template <typename LHS>
+inline ShiftLike_match<LHS, Instruction::AShr> m_AShrOrSelf(const LHS &L,
+                                                            uint64_t &R) {
+  return ShiftLike_match<LHS, Instruction::AShr>(L, R);
+}
+
 template <typename LHS_t, typename RHS_t, unsigned Opcode,
           unsigned WrapFlags = 0, bool Commutable = false>
 struct OverflowingBinaryOp_match {
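
For orientation, a hand-written usage sketch (not taken from the PR; V stands
for whatever Value* is being inspected): the new matchers fold the usual
two-arm m_CombineOr idiom into one pattern, and the shift-amount out-parameter
is written on every successful match, with 0 standing in for the no-shift case.

  // Old idiom: two alternatives, amount reset by hand for the "self" arm.
  uint64_t ShAmt = 0;
  Value *X;
  bool Matched = match(V, m_CombineOr(m_Shl(m_Value(X), m_ConstantInt(ShAmt)),
                                      m_Value(X)));

  // New matcher: one pattern; ShAmt is set to 0 when V itself matched.
  bool MatchedNew = match(V, m_ShlOrSelf(m_Value(X), ShAmt));
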
llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp (+13, -25)

@@ -617,7 +617,7 @@ struct LoadOps {
   LoadInst *RootInsert = nullptr;
   bool FoundRoot = false;
   uint64_t LoadSize = 0;
-  const APInt *Shift = nullptr;
+  uint64_t Shift = 0;
   Type *ZextType;
   AAMDNodes AATags;
 };
@@ -627,17 +627,15 @@ struct LoadOps {
 // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
 static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
                                AliasAnalysis &AA) {
-  const APInt *ShAmt2 = nullptr;
+  uint64_t ShAmt2;
   Value *X;
   Instruction *L1, *L2;
 
   // Go to the last node with loads.
-  if (match(V, m_OneUse(m_c_Or(
-                   m_Value(X),
-                   m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
-                                  m_APInt(ShAmt2)))))) ||
-      match(V, m_OneUse(m_Or(m_Value(X),
-                             m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
+  if (match(V, m_OneUse(m_c_Or(m_Value(X),
+                               m_OneUse(m_ShlOrSelf(m_OneUse(m_ZExt(m_OneUse(
+                                                        m_Instruction(L2)))),
+                                                    ShAmt2)))))) {
     if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
       // Avoid Partial chain merge.
       return false;
@@ -646,11 +644,10 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,

   // Check if the pattern has loads
   LoadInst *LI1 = LOps.Root;
-  const APInt *ShAmt1 = LOps.Shift;
+  uint64_t ShAmt1 = LOps.Shift;
   if (LOps.FoundRoot == false &&
-      (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
-       match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
-                               m_APInt(ShAmt1)))))) {
+      match(X, m_OneUse(m_ShlOrSelf(
+                   m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))), ShAmt1)))) {
     LI1 = dyn_cast<LoadInst>(L1);
   }
   LoadInst *LI2 = dyn_cast<LoadInst>(L2);
@@ -726,13 +723,6 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
   if (IsBigEndian)
     std::swap(ShAmt1, ShAmt2);
 
-  // Find Shifts values.
-  uint64_t Shift1 = 0, Shift2 = 0;
-  if (ShAmt1)
-    Shift1 = ShAmt1->getZExtValue();
-  if (ShAmt2)
-    Shift2 = ShAmt2->getZExtValue();
-
   // First load is always LI1. This is where we put the new load.
   // Use the merged load size available from LI1 for forward loads.
   if (LOps.FoundRoot) {
@@ -747,7 +737,7 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
     uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
     uint64_t PrevSize =
         DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
-    if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
+    if ((ShAmt2 - ShAmt1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
       return false;
 
   // Update LOps
@@ -824,7 +814,7 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
   // Check if shift needed. We need to shift with the amount of load1
   // shift if not zero.
   if (LOps.Shift)
-    NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
+    NewOp = Builder.CreateShl(NewOp, LOps.Shift);
   I.replaceAllUsesWith(NewOp);
 
   return true;
@@ -860,11 +850,9 @@ static std::optional<PartStore> matchPartStore(Instruction &I,
     return std::nullopt;
 
   uint64_t ValWidth = StoredTy->getPrimitiveSizeInBits();
-  uint64_t ValOffset = 0;
+  uint64_t ValOffset;
   Value *Val;
-  if (!match(StoredVal, m_CombineOr(m_Trunc(m_LShr(m_Value(Val),
-                                                   m_ConstantInt(ValOffset))),
-                                    m_Trunc(m_Value(Val)))))
+  if (!match(StoredVal, m_Trunc(m_LShrOrSelf(m_Value(Val), ValOffset))))
     return std::nullopt;
 
   Value *Ptr = Store->getPointerOperand();
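
To make the new bookkeeping concrete, a worked little-endian example (values
and names are illustrative, not from the PR). Because the matchers hand back
plain integers, the nullptr-guarded getZExtValue() block deleted above is no
longer needed:

  // IR being folded:
  //   %l1 = load i8, ptr %p           ; "self" arm of m_ShlOrSelf -> ShAmt1 = 0
  //   %l2 = load i8, ptr %p1          ; %p1 = getelementptr i8, ptr %p, i64 1
  //   %or = or i16 (zext %l1 to i16), (shl (zext %l2 to i16), 8)  ; ShAmt2 = 8
  //
  // The adjacency test in foldLoadsRecursive then reduces to integers:
  //   (ShAmt2 - ShAmt1) == ShiftDiff   ->  8 - 0 == LoadSize1 (8 bits)
  //   (Offset2 - Offset1) == PrevSize  ->  1 - 0 == store size of i8 (1 byte)
  // and foldConsecutiveLoads rewrites the whole tree to "load i16, ptr %p".
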
llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+3, -8)

@@ -3605,16 +3605,11 @@ static bool matchSubIntegerPackFromVector(Value *V, Value *&Vec,
                                           int64_t &VecOffset,
                                           SmallBitVector &Mask,
                                           const DataLayout &DL) {
-  static const auto m_ConstShlOrSelf = [](const auto &Base, uint64_t &ShlAmt) {
-    ShlAmt = 0;
-    return m_CombineOr(m_Shl(Base, m_ConstantInt(ShlAmt)), Base);
-  };
-
   // First try to match extractelement -> zext -> shl
   uint64_t VecIdx, ShlAmt;
-  if (match(V, m_ConstShlOrSelf(m_ZExtOrSelf(m_ExtractElt(
-                   m_Value(Vec), m_ConstantInt(VecIdx))),
-               ShlAmt))) {
+  if (match(V, m_ShlOrSelf(m_ZExtOrSelf(m_ExtractElt(m_Value(Vec),
+                                                     m_ConstantInt(VecIdx))),
+                           ShlAmt))) {
     auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
     if (!VecTy)
       return false;

llvm/unittests/IR/PatternMatch.cpp (+36, -0)

@@ -2621,4 +2621,40 @@ TEST_F(PatternMatchTest, PtrAdd) {
   EXPECT_FALSE(match(OtherGEP, m_PtrAdd(m_Value(A), m_Value(B))));
 }

+TEST_F(PatternMatchTest, ShiftOrSelf) {
+  Type *I64Ty = Type::getInt64Ty(Ctx);
+  Constant *LHS = ConstantInt::get(I64Ty, 7);
+  Constant *ShAmt = ConstantInt::get(I64Ty, 16);
+  Value *Shl = IRB.CreateShl(LHS, ShAmt);
+  Value *LShr = IRB.CreateLShr(LHS, ShAmt);
+  Value *AShr = IRB.CreateAShr(LHS, ShAmt);
+  Value *Add = IRB.CreateAdd(LHS, LHS);
+
+  uint64_t ShAmtC;
+  Value *A;
+  EXPECT_TRUE(match(Shl, m_ShlOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, LHS);
+  EXPECT_EQ(ShAmtC, 16U);
+
+  EXPECT_TRUE(match(Add, m_ShlOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, Add);
+  EXPECT_EQ(ShAmtC, 0U);
+
+  EXPECT_TRUE(match(LShr, m_LShrOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, LHS);
+  EXPECT_EQ(ShAmtC, 16U);
+
+  EXPECT_TRUE(match(Add, m_LShrOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, Add);
+  EXPECT_EQ(ShAmtC, 0U);
+
+  EXPECT_TRUE(match(AShr, m_AShrOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, LHS);
+  EXPECT_EQ(ShAmtC, 16U);
+
+  EXPECT_TRUE(match(Add, m_AShrOrSelf(m_Value(A), ShAmtC)));
+  EXPECT_EQ(A, Add);
+  EXPECT_EQ(ShAmtC, 0U);
+}
+
 } // anonymous namespace.