-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[InstCombine] Add support for GEPs in simplifyNonNullOperand
#128365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-transforms Author: Yingwei Zheng (dtcxzyw) Changes — Alive2: https://alive2.llvm.org/ce/z/2KE8zG Full diff: https://github.com/llvm/llvm-project/pull/128365.diff 8 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 54f777ab20a7a..63f2fd0a733ce 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3996,8 +3996,12 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
if (V->getType()->isPointerTy()) {
// Simplify the nonnull operand if the parameter is known to be nonnull.
// Otherwise, try to infer nonnull for it.
- if (Call.paramHasNonNullAttr(ArgNo, /*AllowUndefOrPoison=*/true)) {
- if (Value *Res = simplifyNonNullOperand(V)) {
+ bool HasDereferenceable = Call.getParamDereferenceableBytes(ArgNo) > 0;
+ if (Call.paramHasAttr(ArgNo, Attribute::NonNull) ||
+ (HasDereferenceable &&
+ !NullPointerIsDefined(Call.getFunction(),
+ V->getType()->getPointerAddressSpace()))) {
+ if (Value *Res = simplifyNonNullOperand(V, HasDereferenceable)) {
replaceOperand(Call, ArgNo, Res);
Changed = true;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 71c80d4c401f8..5b2af39e69f2c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -457,7 +457,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
/// Simplify \p V given that it is known to be non-null.
/// Returns the simplified value if possible, otherwise returns nullptr.
- Value *simplifyNonNullOperand(Value *V);
+ /// If \p HasDereferenceable is true, the simplification will not perform
+ /// same object checks.
+ Value *simplifyNonNullOperand(Value *V, bool HasDereferenceable,
+ unsigned Depth = 0);
public:
/// Create and insert the idiom we use to indicate a block is unreachable
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 89fc1051b18dc..622884ea1eb46 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -982,8 +982,9 @@ static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) {
return false;
}
-/// TODO: Recursively simplify nonnull value to handle one-use inbounds GEPs.
-Value *InstCombinerImpl::simplifyNonNullOperand(Value *V) {
+Value *InstCombinerImpl::simplifyNonNullOperand(Value *V,
+ bool HasDereferenceable,
+ unsigned Depth) {
if (auto *Sel = dyn_cast<SelectInst>(V)) {
if (isa<ConstantPointerNull>(Sel->getOperand(1)))
return Sel->getOperand(2);
@@ -992,6 +993,23 @@ Value *InstCombinerImpl::simplifyNonNullOperand(Value *V) {
return Sel->getOperand(1);
}
+ if (!V->hasOneUse())
+ return nullptr;
+
+ if (Depth == 1)
+ return nullptr;
+
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
+ if (HasDereferenceable || GEP->isInBounds()) {
+ if (auto *Res = simplifyNonNullOperand(GEP->getPointerOperand(),
+ HasDereferenceable, Depth + 1)) {
+ replaceOperand(*GEP, 0, Res);
+ addToWorklist(GEP);
+ return nullptr;
+ }
+ }
+ }
+
return nullptr;
}
@@ -1076,7 +1094,7 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
}
if (!NullPointerIsDefined(LI.getFunction(), LI.getPointerAddressSpace()))
- if (Value *V = simplifyNonNullOperand(Op))
+ if (Value *V = simplifyNonNullOperand(Op, /*HasDereferenceable=*/true))
return replaceOperand(LI, 0, V);
return nullptr;
@@ -1444,7 +1462,7 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
return eraseInstFromFunction(SI);
if (!NullPointerIsDefined(SI.getFunction(), SI.getPointerAddressSpace()))
- if (Value *V = simplifyNonNullOperand(Ptr))
+ if (Value *V = simplifyNonNullOperand(Ptr, /*HasDereferenceable=*/true))
return replaceOperand(SI, 1, V);
return nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index b7748f59a0cfc..81b057c10b484 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3593,10 +3593,12 @@ Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) {
Function *F = RI.getFunction();
Type *RetTy = RetVal->getType();
if (RetTy->isPointerTy()) {
+ bool HasDereferenceable =
+ F->getAttributes().getRetDereferenceableBytes() > 0;
if (F->hasRetAttribute(Attribute::NonNull) ||
- (F->getAttributes().getRetDereferenceableBytes() > 0 &&
+ (HasDereferenceable &&
!NullPointerIsDefined(F, RetTy->getPointerAddressSpace()))) {
- if (Value *V = simplifyNonNullOperand(RetVal))
+ if (Value *V = simplifyNonNullOperand(RetVal, HasDereferenceable))
return replaceOperand(RI, 0, V);
}
}
diff --git a/llvm/test/Transforms/InstCombine/load.ll b/llvm/test/Transforms/InstCombine/load.ll
index 6c087aa87845f..a5ad1e0c21526 100644
--- a/llvm/test/Transforms/InstCombine/load.ll
+++ b/llvm/test/Transforms/InstCombine/load.ll
@@ -439,3 +439,15 @@ define i4 @test_vector_load_i4_non_byte_sized() {
%res0 = load i4, ptr %ptr0, align 1
ret i4 %res0
}
+
+define i32 @load_select_with_null_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @load_select_with_null_gep(
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[SEL:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT: [[RES:%.*]] = load i32, ptr [[GEP]], align 4
+; CHECK-NEXT: ret i32 [[RES]]
+;
+ %sel = select i1 %cond, ptr %p, ptr null
+ %gep = getelementptr i8, ptr %sel, i64 %off
+ %res = load i32, ptr %gep, align 4
+ ret i32 %res
+}
diff --git a/llvm/test/Transforms/InstCombine/nonnull-select.ll b/llvm/test/Transforms/InstCombine/nonnull-select.ll
index cc000b4c88164..929919f9c42c7 100644
--- a/llvm/test/Transforms/InstCombine/nonnull-select.ll
+++ b/llvm/test/Transforms/InstCombine/nonnull-select.ll
@@ -86,4 +86,102 @@ define void @nonnull_noundef_call2(i1 %cond, ptr %p) {
ret void
}
+define void @nonnull_call_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_call_gep(
+; CHECK-NEXT: [[PTR:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[OFF:%.*]]
+; CHECK-NEXT: call void @f(ptr nonnull [[GEP]])
+; CHECK-NEXT: ret void
+;
+ %ptr = select i1 %cond, ptr null, ptr %p
+ %gep = getelementptr i8, ptr %ptr, i64 %off
+ call void @f(ptr nonnull %gep)
+ ret void
+}
+
+define void @nonnull_call_gep_multiuse(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_call_gep_multiuse(
+; CHECK-NEXT: [[PTR:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 [[OFF:%.*]]
+; CHECK-NEXT: call void @f(ptr nonnull [[GEP]])
+; CHECK-NEXT: call void @f(ptr [[GEP]])
+; CHECK-NEXT: ret void
+;
+ %ptr = select i1 %cond, ptr null, ptr %p
+ %gep = getelementptr inbounds i8, ptr %ptr, i64 %off
+ call void @f(ptr nonnull %gep)
+ call void @f(ptr %gep)
+ ret void
+}
+
+define void @all_nonnull_call_gep_multiuse(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @all_nonnull_call_gep_multiuse(
+; CHECK-NEXT: [[PTR:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 [[OFF:%.*]]
+; CHECK-NEXT: call void @f(ptr nonnull [[GEP]])
+; CHECK-NEXT: call void @f(ptr nonnull [[GEP]])
+; CHECK-NEXT: ret void
+;
+ %ptr = select i1 %cond, ptr null, ptr %p
+ %gep = getelementptr inbounds i8, ptr %ptr, i64 %off
+ call void @f(ptr nonnull %gep)
+ call void @f(ptr nonnull %gep)
+ ret void
+}
+
+define void @nonnull_call_gep_inbounds(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_call_gep_inbounds(
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT: call void @f(ptr nonnull [[GEP]])
+; CHECK-NEXT: ret void
+;
+ %ptr = select i1 %cond, ptr null, ptr %p
+ %gep = getelementptr inbounds i8, ptr %ptr, i64 %off
+ call void @f(ptr nonnull %gep)
+ ret void
+}
+
+define void @nonnull_dereferenceable_call_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_dereferenceable_call_gep(
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[PTR:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT: call void @f(ptr dereferenceable(1) [[GEP]])
+; CHECK-NEXT: ret void
+;
+ %ptr = select i1 %cond, ptr null, ptr %p
+ %gep = getelementptr i8, ptr %ptr, i64 %off
+ call void @f(ptr dereferenceable(1) %gep)
+ ret void
+}
+
+define nonnull ptr @nonnull_ret_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_ret_gep(
+; CHECK-NEXT: [[PTR:%.*]] = select i1 [[COND:%.*]], ptr null, ptr [[P:%.*]]
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[OFF:%.*]]
+; CHECK-NEXT: ret ptr [[GEP]]
+;
+ %ptr = select i1 %cond, ptr null, ptr %p
+ %gep = getelementptr i8, ptr %ptr, i64 %off
+ ret ptr %gep
+}
+
+define nonnull ptr @nonnull_ret_gep_inbounds(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_ret_gep_inbounds(
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT: ret ptr [[GEP]]
+;
+ %ptr = select i1 %cond, ptr null, ptr %p
+ %gep = getelementptr inbounds i8, ptr %ptr, i64 %off
+ ret ptr %gep
+}
+
+define dereferenceable(1) ptr @nonnull_dereferenceable_ret_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @nonnull_dereferenceable_ret_gep(
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[PTR:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT: ret ptr [[GEP]]
+;
+ %ptr = select i1 %cond, ptr null, ptr %p
+ %gep = getelementptr i8, ptr %ptr, i64 %off
+ ret ptr %gep
+}
+
declare void @f(ptr)
diff --git a/llvm/test/Transforms/InstCombine/store.ll b/llvm/test/Transforms/InstCombine/store.ll
index 0a2b0a5ee7987..daa40da1828b5 100644
--- a/llvm/test/Transforms/InstCombine/store.ll
+++ b/llvm/test/Transforms/InstCombine/store.ll
@@ -387,6 +387,18 @@ define void @store_select_with_unknown(i1 %cond, ptr %p, ptr %p2) {
ret void
}
+define void @store_select_with_null_gep(i1 %cond, ptr %p, i64 %off) {
+; CHECK-LABEL: @store_select_with_null_gep(
+; CHECK-NEXT: [[GEP:%.*]] = getelementptr i8, ptr [[SEL:%.*]], i64 [[OFF:%.*]]
+; CHECK-NEXT: store i32 0, ptr [[GEP]], align 4
+; CHECK-NEXT: ret void
+;
+ %sel = select i1 %cond, ptr %p, ptr null
+ %gep = getelementptr i8, ptr %sel, i64 %off
+ store i32 0, ptr %gep, align 4
+ ret void
+}
+
!0 = !{!4, !4, i64 0}
!1 = !{!"omnipotent char", !2}
!2 = !{!"Simple C/C++ TBAA"}
diff --git a/llvm/test/Transforms/PhaseOrdering/memset-combine.ll b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
index d1de11258ed91..b1a5881bcaa9c 100644
--- a/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
+++ b/llvm/test/Transforms/PhaseOrdering/memset-combine.ll
@@ -6,10 +6,8 @@
define void @merge_memset(ptr %p, i1 %cond) {
; CHECK-LABEL: define void @merge_memset(
; CHECK-SAME: ptr [[P:%.*]], i1 [[COND:%.*]]) {
-; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND]], ptr null, ptr [[P]]
-; CHECK-NEXT: tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(4096) [[P]], i8 0, i64 4096, i1 false)
-; CHECK-NEXT: [[OFF:%.*]] = getelementptr inbounds nuw i8, ptr [[SEL]], i64 4096
-; CHECK-NEXT: tail call void @llvm.memset.p0.i64(ptr noundef nonnull align 1 dereferenceable(768) [[OFF]], i8 0, i64 768, i1 false)
+; CHECK-NEXT: [[OFF:%.*]] = getelementptr inbounds nuw i8, ptr [[P]], i64 4096
+; CHECK-NEXT: call void @llvm.memset.p0.i64(ptr align 1 [[P]], i8 0, i64 4864, i1 false)
; CHECK-NEXT: ret void
;
%sel = select i1 %cond, ptr null, ptr %p
|
nikic
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| // Otherwise, try to infer nonnull for it. | ||
| if (Call.paramHasNonNullAttr(ArgNo, /*AllowUndefOrPoison=*/true)) { | ||
| if (Value *Res = simplifyNonNullOperand(V)) { | ||
| bool HasDereferenceable = Call.getParamDereferenceableBytes(ArgNo) > 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively, could add a HasDereferenceable out parameter to paramHasNonNullAttr.
|
Google is seeing some miscompiles that track back to this PR. I'm working on a sharable test case, but just wanted to put that out there in case someone else sees something similar. |
I got an automatically reduced version of the code @Sterling-Augustine is talking about here. I'm not quite sure the interesting bits weren't distorted during reduction, but maybe this will help figure out where the problem is: https://gcc.godbolt.org/z/YYcvG69Tf |
I found UB in the original C++ code, so maybe there's no miscompilation, just wrong user code. Nevertheless, please double-check that the input I provided is handled as intended. |
On second thought, the original code may have been correct. |
|
@dtcxzyw we (at google) have analyzed the code that is now broken by this change and decided it was correct. Can you please take a look at the reduced version @alexfh posted in #128365 (comment) and let us know if anything sticks out to you? |
|
As far as I can tell, the transform on the provided IR is correct. Did you check your code under |
It turned out to be even more subtle: the optimization removed the last access to a global variable, which triggered an elimination of another global variable together with its dynamic initializer, which served the purpose of registering a certain type to a global registry. So ultimately, no issue with this commit. Thanks for checking! |
Alive2: https://alive2.llvm.org/ce/z/2KE8zG