diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 76482ad47c771..27c5d5ca08cd6 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1327,45 +1327,6 @@ inline BinaryOp_match<LHS, RHS, Instruction::AShr> m_AShr(const LHS &L,
   return BinaryOp_match<LHS, RHS, Instruction::AShr>(L, R);
 }
 
-template <typename LHS_t, unsigned Opcode> struct ShiftLike_match {
-  LHS_t L;
-  uint64_t &R;
-
-  ShiftLike_match(const LHS_t &LHS, uint64_t &RHS) : L(LHS), R(RHS) {}
-
-  template <typename OpTy> bool match(OpTy *V) const {
-    if (auto *Op = dyn_cast<BinaryOperator>(V)) {
-      if (Op->getOpcode() == Opcode)
-        return m_ConstantInt(R).match(Op->getOperand(1)) &&
-               L.match(Op->getOperand(0));
-    }
-    // Interpreted as shiftop V, 0
-    R = 0;
-    return L.match(V);
-  }
-};
-
-/// Matches shl L, ConstShAmt or L itself.
-template <typename LHS>
-inline ShiftLike_match<LHS, Instruction::Shl> m_ShlOrSelf(const LHS &L,
-                                                          uint64_t &R) {
-  return ShiftLike_match<LHS, Instruction::Shl>(L, R);
-}
-
-/// Matches lshr L, ConstShAmt or L itself.
-template <typename LHS>
-inline ShiftLike_match<LHS, Instruction::LShr> m_LShrOrSelf(const LHS &L,
-                                                            uint64_t &R) {
-  return ShiftLike_match<LHS, Instruction::LShr>(L, R);
-}
-
-/// Matches ashr L, ConstShAmt or L itself.
-template <typename LHS>
-inline ShiftLike_match<LHS, Instruction::AShr> m_AShrOrSelf(const LHS &L,
-                                                            uint64_t &R) {
-  return ShiftLike_match<LHS, Instruction::AShr>(L, R);
-}
-
 template <typename LHS_t, typename RHS_t, unsigned Opcode,
           unsigned WrapFlags = 0, bool Commutable = false>
 struct OverflowingBinaryOp_match {
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index e3c31f96f86d9..40a7f8043034e 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -617,7 +617,7 @@ struct LoadOps {
   LoadInst *RootInsert = nullptr;
   bool FoundRoot = false;
   uint64_t LoadSize = 0;
-  uint64_t Shift = 0;
+  const APInt *Shift = nullptr;
   Type *ZextType;
   AAMDNodes AATags;
 };
@@ -627,15 +627,17 @@ struct LoadOps {
 // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
 static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
                                AliasAnalysis &AA) {
-  uint64_t ShAmt2;
+  const APInt *ShAmt2 = nullptr;
   Value *X;
   Instruction *L1, *L2;
 
   // Go to the last node with loads.
-  if (match(V, m_OneUse(m_c_Or(m_Value(X),
-                               m_OneUse(m_ShlOrSelf(m_OneUse(m_ZExt(m_OneUse(
-                                            m_Instruction(L2)))),
-                                        ShAmt2)))))) {
+  if (match(V, m_OneUse(m_c_Or(
+                   m_Value(X),
+                   m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
+                                  m_APInt(ShAmt2)))))) ||
+      match(V, m_OneUse(m_Or(m_Value(X),
+                             m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
     if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
       // Avoid Partial chain merge.
       return false;
@@ -644,10 +646,11 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
 
   // Check if the pattern has loads
   LoadInst *LI1 = LOps.Root;
-  uint64_t ShAmt1 = LOps.Shift;
+  const APInt *ShAmt1 = LOps.Shift;
   if (LOps.FoundRoot == false &&
-      match(X, m_OneUse(m_ShlOrSelf(
-                   m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))), ShAmt1)))) {
+      (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
+       match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
+                               m_APInt(ShAmt1)))))) {
     LI1 = dyn_cast<LoadInst>(L1);
   }
   LoadInst *LI2 = dyn_cast<LoadInst>(L2);
@@ -723,6 +726,13 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
   if (IsBigEndian)
     std::swap(ShAmt1, ShAmt2);
 
+  // Find Shifts values.
+  uint64_t Shift1 = 0, Shift2 = 0;
+  if (ShAmt1)
+    Shift1 = ShAmt1->getZExtValue();
+  if (ShAmt2)
+    Shift2 = ShAmt2->getZExtValue();
+
   // First load is always LI1. This is where we put the new load.
   // Use the merged load size available from LI1 for forward loads.
   if (LOps.FoundRoot) {
@@ -737,7 +747,7 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
     uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
     uint64_t PrevSize =
         DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
-    if ((ShAmt2 - ShAmt1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
+    if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
       return false;
 
     // Update LOps
@@ -814,7 +824,7 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
   // Check if shift needed. We need to shift with the amount of load1
   // shift if not zero.
   if (LOps.Shift)
-    NewOp = Builder.CreateShl(NewOp, LOps.Shift);
+    NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
   I.replaceAllUsesWith(NewOp);
 
   return true;
@@ -850,9 +860,11 @@ static std::optional<PartStore> matchPartStore(Instruction &I,
     return std::nullopt;
 
   uint64_t ValWidth = StoredTy->getPrimitiveSizeInBits();
-  uint64_t ValOffset;
+  uint64_t ValOffset = 0;
   Value *Val;
-  if (!match(StoredVal, m_Trunc(m_LShrOrSelf(m_Value(Val), ValOffset))))
+  if (!match(StoredVal, m_CombineOr(m_Trunc(m_LShr(m_Value(Val),
+                                                   m_ConstantInt(ValOffset))),
+                                    m_Trunc(m_Value(Val)))))
     return std::nullopt;
 
   Value *Ptr = Store->getPointerOperand();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index e9cefa6d73de7..d7971e8e3caea 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3605,11 +3605,16 @@ static bool matchSubIntegerPackFromVector(Value *V, Value *&Vec,
                                           int64_t &VecOffset,
                                           SmallBitVector &Mask,
                                           const DataLayout &DL) {
+  static const auto m_ConstShlOrSelf = [](const auto &Base, uint64_t &ShlAmt) {
+    ShlAmt = 0;
+    return m_CombineOr(m_Shl(Base, m_ConstantInt(ShlAmt)), Base);
+  };
+
   // First try to match extractelement -> zext -> shl
   uint64_t VecIdx, ShlAmt;
-  if (match(V, m_ShlOrSelf(m_ZExtOrSelf(m_ExtractElt(m_Value(Vec),
-                                                     m_ConstantInt(VecIdx))),
-                           ShlAmt))) {
+  if (match(V, m_ConstShlOrSelf(m_ZExtOrSelf(m_ExtractElt(
+                                    m_Value(Vec), m_ConstantInt(VecIdx))),
+                                ShlAmt))) {
     auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
     if (!VecTy)
       return false;
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 972dac82d3331..bb7cc0802b1df 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -2621,40 +2621,4 @@ TEST_F(PatternMatchTest, PtrAdd) {
   EXPECT_FALSE(match(OtherGEP, m_PtrAdd(m_Value(A), m_Value(B))));
 }
 
-TEST_F(PatternMatchTest, ShiftOrSelf) {
-  Type *I64Ty = Type::getInt64Ty(Ctx);
-  Constant *LHS = ConstantInt::get(I64Ty, 7);
-  Constant *ShAmt = ConstantInt::get(I64Ty, 16);
-  Value *Shl = IRB.CreateShl(LHS, ShAmt);
-  Value *LShr = IRB.CreateLShr(LHS, ShAmt);
-  Value *AShr = IRB.CreateAShr(LHS, ShAmt);
-  Value *Add = IRB.CreateAdd(LHS, LHS);
-
-  uint64_t ShAmtC;
-  Value *A;
-  EXPECT_TRUE(match(Shl, m_ShlOrSelf(m_Value(A), ShAmtC)));
-  EXPECT_EQ(A, LHS);
-  EXPECT_EQ(ShAmtC, 16U);
-
-  EXPECT_TRUE(match(Add, m_ShlOrSelf(m_Value(A), ShAmtC)));
-  EXPECT_EQ(A, Add);
-  EXPECT_EQ(ShAmtC, 0U);
-
-  EXPECT_TRUE(match(LShr, m_LShrOrSelf(m_Value(A), ShAmtC)));
-  EXPECT_EQ(A, LHS);
-  EXPECT_EQ(ShAmtC, 16U);
-
-  EXPECT_TRUE(match(Add, m_LShrOrSelf(m_Value(A), ShAmtC)));
-  EXPECT_EQ(A, Add);
-  EXPECT_EQ(ShAmtC, 0U);
-
-  EXPECT_TRUE(match(AShr, m_AShrOrSelf(m_Value(A), ShAmtC)));
-  EXPECT_EQ(A, LHS);
-  EXPECT_EQ(ShAmtC, 16U);
-
-  EXPECT_TRUE(match(Add, m_AShrOrSelf(m_Value(A), ShAmtC)));
-  EXPECT_EQ(A, Add);
-  EXPECT_EQ(ShAmtC, 0U);
-}
-
 } // anonymous namespace.