diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index 88aef4a368f29..39a65d8b7db99 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -2046,15 +2046,17 @@ struct m_SplatOrPoisonMask { }; template struct PtrAdd_match { + const DataLayout &DL; PointerOpTy PointerOp; OffsetOpTy OffsetOp; - PtrAdd_match(const PointerOpTy &PointerOp, const OffsetOpTy &OffsetOp) - : PointerOp(PointerOp), OffsetOp(OffsetOp) {} + PtrAdd_match(const DataLayout &DL, const PointerOpTy &PointerOp, + const OffsetOpTy &OffsetOp) + : DL(DL), PointerOp(PointerOp), OffsetOp(OffsetOp) {} template bool match(OpTy *V) const { auto *GEP = dyn_cast(V); - return GEP && GEP->getSourceElementType()->isIntegerTy(8) && + return GEP && GEP->getSourceElementType()->isIntegerTy(DL.getByteWidth()) && PointerOp.match(GEP->getPointerOperand()) && OffsetOp.match(GEP->idx_begin()->get()); } @@ -2096,8 +2098,9 @@ inline auto m_GEP(const OperandTypes &...Ops) { /// Matches GEP with i8 source element type template inline PtrAdd_match -m_PtrAdd(const PointerOpTy &PointerOp, const OffsetOpTy &OffsetOp) { - return PtrAdd_match(PointerOp, OffsetOp); +m_PtrAdd(const DataLayout &DL, const PointerOpTy &PointerOp, + const OffsetOpTy &OffsetOp) { + return PtrAdd_match(DL, PointerOp, OffsetOp); } //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 2a0a6a2d302b1..ba2aedc5fc714 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -5533,7 +5533,7 @@ static Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, Value *Ptr, *X; if ((CastOpc == Instruction::PtrToInt || CastOpc == Instruction::PtrToAddr) && match(Op, - m_PtrAdd(m_Value(Ptr), + m_PtrAdd(Q.DL, m_Value(Ptr), m_Sub(m_Value(X), m_PtrToIntOrAddr(m_Deferred(Ptr))))) && X->getType() == Ty && Ty == Q.DL.getIndexType(Ptr->getType())) return X; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 550dfc57a348b..0728063385bc3 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -1052,9 +1052,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I, Value *InnerPtr; uint64_t GEPIndex; uint64_t PtrMaskImmediate; - if (match(I, m_Intrinsic( - m_PtrAdd(m_Value(InnerPtr), m_ConstantInt(GEPIndex)), - m_ConstantInt(PtrMaskImmediate)))) { + if (match(I, + m_Intrinsic( + m_PtrAdd(DL, m_Value(InnerPtr), m_ConstantInt(GEPIndex)), + m_ConstantInt(PtrMaskImmediate)))) { LHSKnown = computeKnownBits(InnerPtr, I, Depth + 1); if (!LHSKnown.isZero()) { diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index b158e0f626850..aba0216d6a760 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -2692,7 +2692,7 @@ static Instruction *canonicalizeGEPOfConstGEPI8(GetElementPtrInst &GEP, auto &DL = IC.getDataLayout(); Value *Base; const APInt *C1; - if (!match(Src, m_PtrAdd(m_Value(Base), m_APInt(C1)))) + if (!match(Src, m_PtrAdd(DL, m_Value(Base), m_APInt(C1)))) return nullptr; Value *VarIndex; const APInt *C2; diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 333cbb6ed1384..8362412b742f6 100644 --- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -1055,7 +1055,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { const APInt *BaseOffset; const bool ExtractBase = match(GEP->getPointerOperand(), - m_PtrAdd(m_Value(NewBase), m_APInt(BaseOffset))); + m_PtrAdd(*DL, m_Value(NewBase), m_APInt(BaseOffset))); const int64_t BaseByteOffset = ExtractBase ? BaseOffset->getSExtValue() : 0; diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp index 1142c559c97f8..962f4bc47cc57 100644 --- a/llvm/unittests/IR/PatternMatch.cpp +++ b/llvm/unittests/IR/PatternMatch.cpp @@ -2599,26 +2599,40 @@ TEST_F(PatternMatchTest, ConstExpr) { EXPECT_TRUE(match(V, m_ConstantExpr())); } -TEST_F(PatternMatchTest, PtrAdd) { +// PatternMatchTest parametrized by byte width. +class PatternMatchByteParamTest + : public PatternMatchTest, + public ::testing::WithParamInterface { +public: + PatternMatchByteParamTest() { + M->setDataLayout("b:" + std::to_string(GetParam())); + } +}; + +INSTANTIATE_TEST_SUITE_P(ByteWidths, PatternMatchByteParamTest, + ::testing::Values(8, 16, 32)); + +TEST_P(PatternMatchByteParamTest, PtrAdd) { + const DataLayout &DL = M->getDataLayout(); Type *PtrTy = PointerType::getUnqual(Ctx); Type *IdxTy = Type::getInt64Ty(Ctx); Constant *Null = Constant::getNullValue(PtrTy); Constant *Offset = ConstantInt::get(IdxTy, 42); Value *PtrAdd = IRB.CreatePtrAdd(Null, Offset); Value *OtherGEP = IRB.CreateGEP(IdxTy, Null, Offset); - Value *PtrAddConst = - ConstantExpr::getGetElementPtr(Type::getInt8Ty(Ctx), Null, Offset); + Value *PtrAddConst = ConstantExpr::getGetElementPtr( + Type::getIntNTy(Ctx, DL.getByteWidth()), Null, Offset); Value *A, *B; - EXPECT_TRUE(match(PtrAdd, m_PtrAdd(m_Value(A), m_Value(B)))); + EXPECT_TRUE(match(PtrAdd, m_PtrAdd(DL, m_Value(A), m_Value(B)))); EXPECT_EQ(A, Null); EXPECT_EQ(B, Offset); - EXPECT_TRUE(match(PtrAddConst, m_PtrAdd(m_Value(A), m_Value(B)))); + EXPECT_TRUE(match(PtrAddConst, m_PtrAdd(DL, m_Value(A), m_Value(B)))); EXPECT_EQ(A, Null); EXPECT_EQ(B, Offset); - EXPECT_FALSE(match(OtherGEP, m_PtrAdd(m_Value(A), m_Value(B)))); + EXPECT_FALSE(match(OtherGEP, m_PtrAdd(DL, m_Value(A), m_Value(B)))); } TEST_F(PatternMatchTest, ShiftOrSelf) {