diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index 2ab652ca258c6..d68be03d4ba7d 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -1328,6 +1328,45 @@ inline BinaryOp_match m_AShr(const LHS &L, return BinaryOp_match(L, R); } +template struct ShiftLike_match { + LHS_t L; + uint64_t &R; + + ShiftLike_match(const LHS_t &LHS, uint64_t &RHS) : L(LHS), R(RHS) {} + + template bool match(OpTy *V) const { + if (auto *Op = dyn_cast(V)) { + if (Op->getOpcode() == Opcode) + return m_ConstantInt(R).match(Op->getOperand(1)) && + L.match(Op->getOperand(0)); + } + // Interpreted as shiftop V, 0 + R = 0; + return L.match(V); + } +}; + +/// Matches shl L, ConstShAmt or L itself (R will be set to zero in this case). +template +inline ShiftLike_match m_ShlOrSelf(const LHS &L, + uint64_t &R) { + return ShiftLike_match(L, R); +} + +/// Matches lshr L, ConstShAmt or L itself (R will be set to zero in this case). +template +inline ShiftLike_match m_LShrOrSelf(const LHS &L, + uint64_t &R) { + return ShiftLike_match(L, R); +} + +/// Matches ashr L, ConstShAmt or L itself (R will be set to zero in this case). +template +inline ShiftLike_match m_AShrOrSelf(const LHS &L, + uint64_t &R) { + return ShiftLike_match(L, R); +} + template struct OverflowingBinaryOp_match { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index ac03593dd5552..a13d3ceb61320 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -3640,16 +3640,11 @@ static bool matchSubIntegerPackFromVector(Value *V, Value *&Vec, int64_t &VecOffset, SmallBitVector &Mask, const DataLayout &DL) { - static const auto m_ConstShlOrSelf = [](const auto &Base, uint64_t &ShlAmt) { - ShlAmt = 0; - return m_CombineOr(m_Shl(Base, m_ConstantInt(ShlAmt)), Base); - }; - // First try to match extractelement -> zext -> shl uint64_t VecIdx, ShlAmt; - if (match(V, m_ConstShlOrSelf(m_ZExtOrSelf(m_ExtractElt( - m_Value(Vec), m_ConstantInt(VecIdx))), - ShlAmt))) { + if (match(V, m_ShlOrSelf(m_ZExtOrSelf(m_ExtractElt(m_Value(Vec), + m_ConstantInt(VecIdx))), + ShlAmt))) { auto *VecTy = dyn_cast(Vec->getType()); if (!VecTy) return false; diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp index bb7cc0802b1df..972dac82d3331 100644 --- a/llvm/unittests/IR/PatternMatch.cpp +++ b/llvm/unittests/IR/PatternMatch.cpp @@ -2621,4 +2621,40 @@ TEST_F(PatternMatchTest, PtrAdd) { EXPECT_FALSE(match(OtherGEP, m_PtrAdd(m_Value(A), m_Value(B)))); } +TEST_F(PatternMatchTest, ShiftOrSelf) { + Type *I64Ty = Type::getInt64Ty(Ctx); + Constant *LHS = ConstantInt::get(I64Ty, 7); + Constant *ShAmt = ConstantInt::get(I64Ty, 16); + Value *Shl = IRB.CreateShl(LHS, ShAmt); + Value *LShr = IRB.CreateLShr(LHS, ShAmt); + Value *AShr = IRB.CreateAShr(LHS, ShAmt); + Value *Add = IRB.CreateAdd(LHS, LHS); + + uint64_t ShAmtC; + Value *A; + EXPECT_TRUE(match(Shl, m_ShlOrSelf(m_Value(A), ShAmtC))); + EXPECT_EQ(A, LHS); + EXPECT_EQ(ShAmtC, 16U); + + EXPECT_TRUE(match(Add, m_ShlOrSelf(m_Value(A), ShAmtC))); + EXPECT_EQ(A, Add); + EXPECT_EQ(ShAmtC, 0U); + + EXPECT_TRUE(match(LShr, m_LShrOrSelf(m_Value(A), ShAmtC))); + EXPECT_EQ(A, LHS); + EXPECT_EQ(ShAmtC, 16U); + + EXPECT_TRUE(match(Add, m_LShrOrSelf(m_Value(A), ShAmtC))); + EXPECT_EQ(A, Add); + EXPECT_EQ(ShAmtC, 0U); + + EXPECT_TRUE(match(AShr, m_AShrOrSelf(m_Value(A), ShAmtC))); + EXPECT_EQ(A, LHS); + EXPECT_EQ(ShAmtC, 16U); + + EXPECT_TRUE(match(Add, m_AShrOrSelf(m_Value(A), ShAmtC))); + EXPECT_EQ(A, Add); + EXPECT_EQ(ShAmtC, 0U); +} + } // anonymous namespace.