Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 37 additions & 71 deletions llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@ static cl::opt<unsigned> StrNCmpInlineThreshold(
cl::desc("The maximum length of a constant string for a builtin string cmp "
"call eligible for inlining. The default value is 3."));

static cl::opt<unsigned>
MemChrInlineThreshold("memchr-inline-threshold", cl::init(3), cl::Hidden,
cl::desc("The maximum length of a constant string to "
"inline a memchr call."));
static cl::opt<unsigned> MemChrInlineThreshold(
"memchr-inline-threshold", cl::init(6), cl::Hidden,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of increasing this threshold to 6?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially set the threshold to 6 to match the test cases in memchr.ll, which test the inlining optimization for lengths up to 5 bytes. Would you recommend keeping it at 3 to be more conservative?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest keeping it at 3 unless you have performance data from llvm-test-suite/SPEC that justifies a higher value.

cl::desc("Size threshold for inlining memchr calls"));

/// Match a pattern for a bitwise funnel/rotate operation that partially guards
/// against undefined behavior by branching around the funnel-shift/rotation
Expand Down Expand Up @@ -1106,79 +1105,46 @@ void StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
}
}

/// Convert memchr with a small constant string into a switch
static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU,
const DataLayout &DL) {
if (isa<Constant>(Call->getArgOperand(1)))
Value *Ptr = Call->getArgOperand(0);
Value *Val = Call->getArgOperand(1);
Value *Len = Call->getArgOperand(2);

// If length is not a constant, we can't do the optimization
auto *LenC = dyn_cast<ConstantInt>(Len);
if (!LenC)
return false;

StringRef Str;
Value *Base = Call->getArgOperand(0);
if (!getConstantStringInfo(Base, Str, /*TrimAtNul=*/false))
return false;

uint64_t N = Str.size();
if (auto *ConstInt = dyn_cast<ConstantInt>(Call->getArgOperand(2))) {
uint64_t Val = ConstInt->getZExtValue();
// Ignore the case that n is larger than the size of string.
if (Val > N)
return false;
N = Val;
} else
return false;

if (N > MemChrInlineThreshold)
return false;

BasicBlock *BB = Call->getParent();
BasicBlock *BBNext = SplitBlock(BB, Call, DTU);
IRBuilder<> IRB(BB);
IntegerType *ByteTy = IRB.getInt8Ty();
BB->getTerminator()->eraseFromParent();
SwitchInst *SI = IRB.CreateSwitch(
IRB.CreateTrunc(Call->getArgOperand(1), ByteTy), BBNext, N);
Type *IndexTy = DL.getIndexType(Call->getType());
SmallVector<DominatorTree::UpdateType, 8> Updates;

BasicBlock *BBSuccess = BasicBlock::Create(
Call->getContext(), "memchr.success", BB->getParent(), BBNext);
IRB.SetInsertPoint(BBSuccess);
PHINode *IndexPHI = IRB.CreatePHI(IndexTy, N, "memchr.idx");
Value *FirstOccursLocation = IRB.CreateInBoundsPtrAdd(Base, IndexPHI);
IRB.CreateBr(BBNext);
if (DTU)
Updates.push_back({DominatorTree::Insert, BBSuccess, BBNext});

SmallPtrSet<ConstantInt *, 4> Cases;
for (uint64_t I = 0; I < N; ++I) {
ConstantInt *CaseVal = ConstantInt::get(ByteTy, Str[I]);
if (!Cases.insert(CaseVal).second)
continue;

BasicBlock *BBCase = BasicBlock::Create(Call->getContext(), "memchr.case",
BB->getParent(), BBSuccess);
SI->addCase(CaseVal, BBCase);
IRB.SetInsertPoint(BBCase);
IndexPHI->addIncoming(ConstantInt::get(IndexTy, I), BBCase);
IRB.CreateBr(BBSuccess);
if (DTU) {
Updates.push_back({DominatorTree::Insert, BB, BBCase});
Updates.push_back({DominatorTree::Insert, BBCase, BBSuccess});

uint64_t Length = LenC->getZExtValue();

// Check if this is a small memchr we should inline
if (Length <= MemChrInlineThreshold) {
IRBuilder<> IRB(Call);

// Truncate the search value to i8
Value *ByteVal = IRB.CreateTrunc(Val, IRB.getInt8Ty());

// Initialize result to null
Value *Result = ConstantPointerNull::get(cast<PointerType>(Call->getType()));

// For each byte up to Length
for (unsigned i = 0; i < Length; i++) {
Value *CurPtr = i == 0 ? Ptr :
IRB.CreateGEP(IRB.getInt8Ty(), Ptr,
ConstantInt::get(DL.getIndexType(Call->getType()), i));
Value *CurByte = IRB.CreateLoad(IRB.getInt8Ty(), CurPtr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is dangerous to put all of these load instructions unconditionally in a single basic block.
Consider the following case:

const char *arr = "a";
return memchr(arr, 'a', 3);

It is well-defined because it returns at the first occurrence.
After this transformation, the call becomes three unconditional loads, which causes an out-of-bounds
memory access.
See also https://en.cppreference.com/w/c/string/byte/memchr.

To avoid UB, we should transform this call into an if chain, so that each byte is only loaded when no earlier byte matched.

BTW, your implementation returns the last occurrence of ch :(

Value *CmpRes = IRB.CreateICmpEQ(CurByte, ByteVal);
Result = IRB.CreateSelect(CmpRes, CurPtr, Result);
}

// Replace the call with our expanded version
Call->replaceAllUsesWith(Result);
Call->eraseFromParent();
return true;
}

PHINode *PHI =
PHINode::Create(Call->getType(), 2, Call->getName(), BBNext->begin());
PHI->addIncoming(Constant::getNullValue(Call->getType()), BB);
PHI->addIncoming(FirstOccursLocation, BBSuccess);

Call->replaceAllUsesWith(PHI);
Call->eraseFromParent();

if (DTU)
DTU->applyUpdates(Updates);

return true;
return false;
}

static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
Expand Down
113 changes: 50 additions & 63 deletions llvm/test/Transforms/AggressiveInstCombine/memchr.ll
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,18 @@ declare ptr @memchr(ptr, i32, i64)
define i1 @test_memchr_null(i32 %x) {
; CHECK-LABEL: define i1 @test_memchr_null(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*]]:
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = trunc i32 [[X]] to i8
; CHECK-NEXT: switch i8 [[TMP0]], label %[[ENTRY_SPLIT:.*]] [
; CHECK-NEXT: i8 48, label %[[MEMCHR_CASE:.*]]
; CHECK-NEXT: i8 49, label %[[MEMCHR_CASE1:.*]]
; CHECK-NEXT: i8 0, label %[[MEMCHR_CASE2:.*]]
; CHECK-NEXT: i8 50, label %[[MEMCHR_CASE3:.*]]
; CHECK-NEXT: ]
; CHECK: [[MEMCHR_CASE]]:
; CHECK-NEXT: br label %[[MEMCHR_SUCCESS:.*]]
; CHECK: [[MEMCHR_CASE1]]:
; CHECK-NEXT: br label %[[MEMCHR_SUCCESS]]
; CHECK: [[MEMCHR_CASE2]]:
; CHECK-NEXT: br label %[[MEMCHR_SUCCESS]]
; CHECK: [[MEMCHR_CASE3]]:
; CHECK-NEXT: br label %[[MEMCHR_SUCCESS]]
; CHECK: [[MEMCHR_SUCCESS]]:
; CHECK-NEXT: [[MEMCHR_IDX:%.*]] = phi i64 [ 0, %[[MEMCHR_CASE]] ], [ 1, %[[MEMCHR_CASE1]] ], [ 2, %[[MEMCHR_CASE2]] ], [ 3, %[[MEMCHR_CASE3]] ]
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr @str, i64 [[MEMCHR_IDX]]
; CHECK-NEXT: br label %[[ENTRY_SPLIT]]
; CHECK: [[ENTRY_SPLIT]]:
; CHECK-NEXT: [[MEMCHR4:%.*]] = phi ptr [ null, %[[ENTRY]] ], [ [[TMP1]], %[[MEMCHR_SUCCESS]] ]
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 48, [[TMP0]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], ptr @str, ptr null
; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 49, [[TMP0]]
; CHECK-NEXT: [[TMP4:%.*]] = select i1 [[TMP3]], ptr getelementptr (i8, ptr @str, i64 1), ptr [[TMP2]]
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i8 0, [[TMP0]]
; CHECK-NEXT: [[TMP6:%.*]] = select i1 [[TMP5]], ptr getelementptr (i8, ptr @str, i64 2), ptr [[TMP4]]
; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i8 50, [[TMP0]]
; CHECK-NEXT: [[TMP8:%.*]] = select i1 [[TMP7]], ptr getelementptr (i8, ptr @str, i64 3), ptr [[TMP6]]
; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i8 0, [[TMP0]]
; CHECK-NEXT: [[MEMCHR4:%.*]] = select i1 [[TMP9]], ptr getelementptr (i8, ptr @str, i64 4), ptr [[TMP8]]
; CHECK-NEXT: [[ISNULL:%.*]] = icmp eq ptr [[MEMCHR4]], null
; CHECK-NEXT: ret i1 [[ISNULL]]
;
Expand All @@ -43,28 +33,18 @@ entry:
define ptr @test_memchr(i32 %x) {
; CHECK-LABEL: define ptr @test_memchr(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*]]:
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = trunc i32 [[X]] to i8
; CHECK-NEXT: switch i8 [[TMP0]], label %[[ENTRY_SPLIT:.*]] [
; CHECK-NEXT: i8 48, label %[[MEMCHR_CASE:.*]]
; CHECK-NEXT: i8 49, label %[[MEMCHR_CASE1:.*]]
; CHECK-NEXT: i8 0, label %[[MEMCHR_CASE2:.*]]
; CHECK-NEXT: i8 50, label %[[MEMCHR_CASE3:.*]]
; CHECK-NEXT: ]
; CHECK: [[MEMCHR_CASE]]:
; CHECK-NEXT: br label %[[MEMCHR_SUCCESS:.*]]
; CHECK: [[MEMCHR_CASE1]]:
; CHECK-NEXT: br label %[[MEMCHR_SUCCESS]]
; CHECK: [[MEMCHR_CASE2]]:
; CHECK-NEXT: br label %[[MEMCHR_SUCCESS]]
; CHECK: [[MEMCHR_CASE3]]:
; CHECK-NEXT: br label %[[MEMCHR_SUCCESS]]
; CHECK: [[MEMCHR_SUCCESS]]:
; CHECK-NEXT: [[MEMCHR_IDX:%.*]] = phi i64 [ 0, %[[MEMCHR_CASE]] ], [ 1, %[[MEMCHR_CASE1]] ], [ 2, %[[MEMCHR_CASE2]] ], [ 3, %[[MEMCHR_CASE3]] ]
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr @str, i64 [[MEMCHR_IDX]]
; CHECK-NEXT: br label %[[ENTRY_SPLIT]]
; CHECK: [[ENTRY_SPLIT]]:
; CHECK-NEXT: [[MEMCHR4:%.*]] = phi ptr [ null, %[[ENTRY]] ], [ [[TMP1]], %[[MEMCHR_SUCCESS]] ]
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 48, [[TMP0]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], ptr @str, ptr null
; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 49, [[TMP0]]
; CHECK-NEXT: [[TMP4:%.*]] = select i1 [[TMP3]], ptr getelementptr (i8, ptr @str, i64 1), ptr [[TMP2]]
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i8 0, [[TMP0]]
; CHECK-NEXT: [[TMP6:%.*]] = select i1 [[TMP5]], ptr getelementptr (i8, ptr @str, i64 2), ptr [[TMP4]]
; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i8 50, [[TMP0]]
; CHECK-NEXT: [[TMP8:%.*]] = select i1 [[TMP7]], ptr getelementptr (i8, ptr @str, i64 3), ptr [[TMP6]]
; CHECK-NEXT: [[TMP9:%.*]] = icmp eq i8 0, [[TMP0]]
; CHECK-NEXT: [[MEMCHR4:%.*]] = select i1 [[TMP9]], ptr getelementptr (i8, ptr @str, i64 4), ptr [[TMP8]]
; CHECK-NEXT: ret ptr [[MEMCHR4]]
;
entry:
Expand All @@ -75,25 +55,14 @@ entry:
define ptr @test_memchr_smaller_n(i32 %x) {
; CHECK-LABEL: define ptr @test_memchr_smaller_n(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*]]:
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = trunc i32 [[X]] to i8
; CHECK-NEXT: switch i8 [[TMP0]], label %[[ENTRY_SPLIT:.*]] [
; CHECK-NEXT: i8 48, label %[[MEMCHR_CASE:.*]]
; CHECK-NEXT: i8 49, label %[[MEMCHR_CASE1:.*]]
; CHECK-NEXT: i8 0, label %[[MEMCHR_CASE2:.*]]
; CHECK-NEXT: ]
; CHECK: [[MEMCHR_CASE]]:
; CHECK-NEXT: br label %[[MEMCHR_SUCCESS:.*]]
; CHECK: [[MEMCHR_CASE1]]:
; CHECK-NEXT: br label %[[MEMCHR_SUCCESS]]
; CHECK: [[MEMCHR_CASE2]]:
; CHECK-NEXT: br label %[[MEMCHR_SUCCESS]]
; CHECK: [[MEMCHR_SUCCESS]]:
; CHECK-NEXT: [[MEMCHR_IDX:%.*]] = phi i64 [ 0, %[[MEMCHR_CASE]] ], [ 1, %[[MEMCHR_CASE1]] ], [ 2, %[[MEMCHR_CASE2]] ]
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr @str, i64 [[MEMCHR_IDX]]
; CHECK-NEXT: br label %[[ENTRY_SPLIT]]
; CHECK: [[ENTRY_SPLIT]]:
; CHECK-NEXT: [[MEMCHR3:%.*]] = phi ptr [ null, %[[ENTRY]] ], [ [[TMP1]], %[[MEMCHR_SUCCESS]] ]
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 48, [[TMP0]]
; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], ptr @str, ptr null
; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 49, [[TMP0]]
; CHECK-NEXT: [[TMP4:%.*]] = select i1 [[TMP3]], ptr getelementptr (i8, ptr @str, i64 1), ptr [[TMP2]]
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i8 0, [[TMP0]]
; CHECK-NEXT: [[MEMCHR3:%.*]] = select i1 [[TMP5]], ptr getelementptr (i8, ptr @str, i64 2), ptr [[TMP4]]
; CHECK-NEXT: ret ptr [[MEMCHR3]]
;
entry:
Expand All @@ -119,7 +88,26 @@ define ptr @test_memchr_non_constant(i32 %x, ptr %str) {
; CHECK-LABEL: define ptr @test_memchr_non_constant(
; CHECK-SAME: i32 [[X:%.*]], ptr [[STR:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[MEMCHR:%.*]] = call ptr @memchr(ptr [[STR]], i32 [[X]], i64 5)
; CHECK-NEXT: [[TMP0:%.*]] = trunc i32 [[X]] to i8
; CHECK-NEXT: [[TMP1:%.*]] = load i8, ptr [[STR]], align 1
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], ptr [[STR]], ptr null
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[STR]], i64 1
; CHECK-NEXT: [[TMP5:%.*]] = load i8, ptr [[TMP4]], align 1
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i8 [[TMP5]], [[TMP0]]
; CHECK-NEXT: [[TMP7:%.*]] = select i1 [[TMP6]], ptr [[TMP4]], ptr [[TMP3]]
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[STR]], i64 2
; CHECK-NEXT: [[TMP9:%.*]] = load i8, ptr [[TMP8]], align 1
; CHECK-NEXT: [[TMP10:%.*]] = icmp eq i8 [[TMP9]], [[TMP0]]
; CHECK-NEXT: [[TMP11:%.*]] = select i1 [[TMP10]], ptr [[TMP8]], ptr [[TMP7]]
; CHECK-NEXT: [[TMP12:%.*]] = getelementptr i8, ptr [[STR]], i64 3
; CHECK-NEXT: [[TMP13:%.*]] = load i8, ptr [[TMP12]], align 1
; CHECK-NEXT: [[TMP14:%.*]] = icmp eq i8 [[TMP13]], [[TMP0]]
; CHECK-NEXT: [[TMP15:%.*]] = select i1 [[TMP14]], ptr [[TMP12]], ptr [[TMP11]]
; CHECK-NEXT: [[TMP16:%.*]] = getelementptr i8, ptr [[STR]], i64 4
; CHECK-NEXT: [[TMP17:%.*]] = load i8, ptr [[TMP16]], align 1
; CHECK-NEXT: [[TMP18:%.*]] = icmp eq i8 [[TMP17]], [[TMP0]]
; CHECK-NEXT: [[MEMCHR:%.*]] = select i1 [[TMP18]], ptr [[TMP16]], ptr [[TMP15]]
; CHECK-NEXT: ret ptr [[MEMCHR]]
;
entry:
Expand All @@ -130,8 +118,7 @@ entry:
define ptr @test_memchr_constant_ch() {
; CHECK-LABEL: define ptr @test_memchr_constant_ch() {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[MEMCHR:%.*]] = call ptr @memchr(ptr @str, i32 49, i64 5)
; CHECK-NEXT: ret ptr [[MEMCHR]]
; CHECK-NEXT: ret ptr getelementptr (i8, ptr @str, i64 1)
;
entry:
%memchr = call ptr @memchr(ptr @str, i32 49, i64 5)
Expand Down
23 changes: 23 additions & 0 deletions llvm/test/Transforms/AggressiveInstCombine/test_memchr_small.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=aggressive-instcombine --memchr-inline-threshold=2 < %s | FileCheck %s

declare ptr @memchr(ptr, i32, i64)

; memchr on a non-constant pointer with a constant length of 2, run with
; --memchr-inline-threshold=2 (see the RUN line). The assertions below expect
; the call to be replaced by a trunc of the search value to i8, one i8 load per
; byte, and a chain of selects that starts from null.
; NOTE(review): both loads are expected unconditionally, and the select chain
; lets a match at offset 1 override a match at offset 0. Confirm this agrees
; with memchr's first-occurrence semantics before relying on these assertions.
define ptr @test_memchr_small(ptr %p, i32 %val) {
; CHECK-LABEL: define ptr @test_memchr_small(
; CHECK-SAME: ptr [[P:%.*]], i32 [[VAL:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = trunc i32 [[VAL]] to i8
; CHECK-NEXT: [[TMP1:%.*]] = load i8, ptr [[P]], align 1
; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP1]], [[TMP0]]
; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], ptr [[P]], ptr null
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr i8, ptr [[P]], i64 1
; CHECK-NEXT: [[TMP5:%.*]] = load i8, ptr [[TMP4]], align 1
; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i8 [[TMP5]], [[TMP0]]
; CHECK-NEXT: [[TMP7:%.*]] = select i1 [[TMP6]], ptr [[TMP4]], ptr [[TMP3]]
; CHECK-NEXT: ret ptr [[TMP7]]
;
entry:
%res = call ptr @memchr(ptr %p, i32 %val, i64 2)
ret ptr %res
}
Loading