Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,11 +485,36 @@ static bool allCallersPassValidPointerForArgument(
});
}

// Try to prove that all Calls to F do not modify the memory pointed to by Arg,
// using alias analysis local to each caller of F.
//
// Here F is Arg->getParent(). Every user of F is visited; for each one the
// memory location of the actual argument passed in Arg's position is checked
// against the call itself, using the alias analysis results of the function
// that contains the call (obtained through FAM).
//
// Returns true only if every user of F is a CallInst and none of those calls
// has the Mod bit set for the argument's location. Returns false as soon as a
// non-CallInst user or a possibly-modifying call is found.
static bool isArgUnmodifiedByAllCalls(Argument *Arg,
FunctionAnalysisManager &FAM) {
for (User *U : Arg->getParent()->users()) {

// Bail if we find an unexpected (non CallInst) use of the function.
// NOTE(review): this dyn_cast is to CallInst rather than CallBase, so a user
// that is an invoke would presumably also take the bail-out path below;
// confirm whether that is intended.
auto *Call = dyn_cast<CallInst>(U);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this checking for CallInst rather than CallBase? This should also work for invokes, right?

And I think then you could also replace it with an assert, as the function only being used as CallBase callee is a precondition of trying the transform in the first place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies, I was perhaps a bit trigger happy with committing this. I have put up #110335 to address this.

if (!Call)
return false;

// Location of the actual argument at Arg's index in this particular call.
// The trailing nullptr is the TargetLibraryInfo parameter; none is supplied.
MemoryLocation Loc =
MemoryLocation::getForArgument(Call, Arg->getArgNo(), nullptr);

// Alias analysis is requested for the function containing the call (the
// caller), not for the callee whose argument is being examined.
AAResults &AAR = FAM.getResult<AAManager>(*Call->getFunction());
// Bail as soon as we find a Call where Arg may be modified.
if (isModSet(AAR.getModRefInfo(Call, Loc)))
return false;
}

// All Users are Calls which do not modify the Arg.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment, the one on line 492, the one on line 488/499, and the function name all pretty much say the same thing. Maybe delete the ones in the function body

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed redundant comments in Address review comments 2

return true;
}

/// Determine that this argument is safe to promote, and find the argument
/// parts it can be promoted into.
static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
unsigned MaxElements, bool IsRecursive,
SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec,
FunctionAnalysisManager &FAM) {
// Quick exit for unused arguments
if (Arg->use_empty())
return true;
Expand Down Expand Up @@ -716,10 +741,16 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
return true;

// Okay, now we know that the argument is only used by load instructions, and
// it is safe to unconditionally perform all of them. Use alias analysis to
// check to see if the pointer is guaranteed to not be modified from entry of
// the function to each of the load instructions.
// it is safe to unconditionally perform all of them.

// If we can determine that no call to the Function modifies the memory region
// accessed through Arg, through alias analysis using actual arguments in the
// callers, we know that it is guaranteed to be safe to promote the argument.
if (isArgUnmodifiedByAllCalls(Arg, FAM))
return true;

// Otherwise, use alias analysis to check if the pointer is guaranteed to not
// be modified from entry of the function to each of the load instructions.
for (LoadInst *Load : Loads) {
// Check to see if the load is invalidated from the start of the block to
// the load itself.
Expand Down Expand Up @@ -846,7 +877,8 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
// If we can promote the pointer to its value.
SmallVector<OffsetAndArgPart, 4> ArgParts;

if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
if (findArgParts(PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts,
FAM)) {
SmallVector<Type *, 4> Types;
for (const auto &Pair : ArgParts)
Types.push_back(Pair.second.Ty);
Expand Down
29 changes: 12 additions & 17 deletions llvm/test/Transforms/ArgumentPromotion/actual-arguments.ll
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,14 @@ define internal i32 @test_cannot_promote_3(ptr %p, ptr nocapture readonly %test_
ret i32 %sum
}

; FIXME: We should perform ArgPromotion here!
;
; This is called only by @caller_safe_args_1, from which we can prove that
; %test_c does not alias %p for any Call to the function, so we can promote it.
;
define internal i32 @test_can_promote_1(ptr %p, ptr nocapture readonly %test_c) {
; CHECK-LABEL: define {{[^@]+}}@test_can_promote_1
; CHECK-SAME: (ptr [[P:%.*]], ptr nocapture readonly [[TEST_C:%.*]]) {
; CHECK-NEXT: [[TEST_C_VAL:%.*]] = load i32, ptr [[TEST_C]], align 4
; CHECK-NEXT: [[RES:%.*]] = call i32 @callee(ptr [[P]], i32 [[TEST_C_VAL]])
; CHECK-NEXT: [[LTEST_C:%.*]] = load i32, ptr [[TEST_C]], align 4
; CHECK-NEXT: [[SUM:%.*]] = add i32 [[LTEST_C]], [[RES]]
; CHECK-SAME: (ptr [[P:%.*]], i32 [[TEST_C_0_VAL:%.*]]) {
; CHECK-NEXT: [[RES:%.*]] = call i32 @callee(ptr [[P]], i32 [[TEST_C_0_VAL]])
; CHECK-NEXT: [[SUM:%.*]] = add i32 [[TEST_C_0_VAL]], [[RES]]
; CHECK-NEXT: ret i32 [[SUM]]
;
%res = call i32 @callee(ptr %p, ptr %test_c)
Expand All @@ -91,19 +87,15 @@ define internal i32 @test_can_promote_1(ptr %p, ptr nocapture readonly %test_c)
ret i32 %sum
}

; FIXME: We should perform ArgPromotion here!
;
; This is called by multiple callers (@caller_safe_args_1, @caller_safe_args_2),
; from which we can prove that %test_c does not alias %p for any Call to the
; function, so we can promote it.
;
define internal i32 @test_can_promote_2(ptr %p, ptr nocapture readonly %test_c) {
; CHECK-LABEL: define {{[^@]+}}@test_can_promote_2
; CHECK-SAME: (ptr [[P:%.*]], ptr nocapture readonly [[TEST_C:%.*]]) {
; CHECK-NEXT: [[TEST_C_VAL:%.*]] = load i32, ptr [[TEST_C]], align 4
; CHECK-NEXT: [[RES:%.*]] = call i32 @callee(ptr [[P]], i32 [[TEST_C_VAL]])
; CHECK-NEXT: [[LTEST_C:%.*]] = load i32, ptr [[TEST_C]], align 4
; CHECK-NEXT: [[SUM:%.*]] = add i32 [[LTEST_C]], [[RES]]
; CHECK-SAME: (ptr [[P:%.*]], i32 [[TEST_C_0_VAL:%.*]]) {
; CHECK-NEXT: [[RES:%.*]] = call i32 @callee(ptr [[P]], i32 [[TEST_C_0_VAL]])
; CHECK-NEXT: [[SUM:%.*]] = add i32 [[TEST_C_0_VAL]], [[RES]]
; CHECK-NEXT: ret i32 [[SUM]]
;
%res = call i32 @callee(ptr %p, ptr %test_c)
Expand Down Expand Up @@ -186,8 +178,10 @@ define i32 @caller_safe_args_1(i64 %n) {
; CHECK-NEXT: [[CALLER_C:%.*]] = alloca i32, align 4
; CHECK-NEXT: store i32 5, ptr [[CALLER_C]], align 4
; CHECK-NEXT: [[RES1:%.*]] = call i32 @test_cannot_promote_3(ptr [[P]], ptr [[CALLER_C]])
; CHECK-NEXT: [[RES2:%.*]] = call i32 @test_can_promote_1(ptr [[P]], ptr [[CALLER_C]])
; CHECK-NEXT: [[RES3:%.*]] = call i32 @test_can_promote_2(ptr [[P]], ptr [[CALLER_C]])
; CHECK-NEXT: [[CALLER_C_VAL:%.*]] = load i32, ptr [[CALLER_C]], align 4
; CHECK-NEXT: [[RES2:%.*]] = call i32 @test_can_promote_1(ptr [[P]], i32 [[CALLER_C_VAL]])
; CHECK-NEXT: [[CALLER_C_VAL1:%.*]] = load i32, ptr [[CALLER_C]], align 4
; CHECK-NEXT: [[RES3:%.*]] = call i32 @test_can_promote_2(ptr [[P]], i32 [[CALLER_C_VAL1]])
; CHECK-NEXT: [[RES12:%.*]] = add i32 [[RES1]], [[RES2]]
; CHECK-NEXT: [[RES:%.*]] = add i32 [[RES12]], [[RES3]]
; CHECK-NEXT: ret i32 [[RES]]
Expand Down Expand Up @@ -215,7 +209,8 @@ define i32 @caller_safe_args_2(i64 %n, ptr %p) {
; CHECK-NEXT: call void @memset(ptr [[P]], i64 0, i64 [[N]])
; CHECK-NEXT: [[CALLER_C:%.*]] = alloca i32, align 4
; CHECK-NEXT: store i32 5, ptr [[CALLER_C]], align 4
; CHECK-NEXT: [[RES:%.*]] = call i32 @test_can_promote_2(ptr [[P]], ptr [[CALLER_C]])
; CHECK-NEXT: [[CALLER_C_VAL:%.*]] = load i32, ptr [[CALLER_C]], align 4
; CHECK-NEXT: [[RES:%.*]] = call i32 @test_can_promote_2(ptr [[P]], i32 [[CALLER_C_VAL]])
; CHECK-NEXT: ret i32 [[RES]]
;
call void @memset(ptr %p, i64 0, i64 %n)
Expand Down