Skip to content

Commit 1d85c1f

Browse files
Camsynaokblast
authored andcommitted
[AbstractCallSite] Handle Indirect Calls Properly (llvm#163003)
AbstractCallSite handles three types of calls (direct, indirect, and callback). This patch fixes the handling of indirect calls in some methods, which incorrectly assumed that non-direct calls are always callback calls. Moreover, this PR adds 2 unit tests for direct call type and indirect call type. The aforementioned misassumption leads to the following problem: --- ## Problem When the underlying call is **indirect**, some APIs of `AbstractCallSite` behave unexpectedly. E.g., `AbstractCallSite::getCalledFunction()` currently triggers an **assertion failure**, instead of returning `nullptr` as documented: ```cpp /// Return the function being called if this is a direct call, otherwise /// return null (if it's an indirect call). Function *getCalledFunction() const; ``` Actual unexpected assertion failure: ``` AbstractCallSite.h:197: int llvm::AbstractCallSite::getCallArgOperandNoForCallee() const: Assertion `isCallbackCall()' failed. ``` This is because `AbstractCallSite` mistakenly entered the branch that handles Callback Calls as its guard condition (`!isDirectCall()`) does not take into account the case of indirect calls
1 parent fea9fd2 commit 1d85c1f

File tree

2 files changed

+98
-6
lines changed

2 files changed

+98
-6
lines changed

llvm/include/llvm/IR/AbstractCallSite.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class AbstractCallSite {
137137

138138
/// Return true if @p U is the use that defines the callee of this ACS.
139139
bool isCallee(const Use *U) const {
140-
if (isDirectCall())
140+
if (!isCallbackCall())
141141
return CB->isCallee(U);
142142

143143
assert(!CI.ParameterEncoding.empty() &&
@@ -154,7 +154,7 @@ class AbstractCallSite {
154154

155155
/// Return the number of parameters of the callee.
156156
unsigned getNumArgOperands() const {
157-
if (isDirectCall())
157+
if (!isCallbackCall())
158158
return CB->arg_size();
159159
// Subtract 1 for the callee encoding.
160160
return CI.ParameterEncoding.size() - 1;
@@ -169,7 +169,7 @@ class AbstractCallSite {
169169
/// Return the operand index of the underlying instruction associated with
170170
/// the function parameter number @p ArgNo or -1 if there is none.
171171
int getCallArgOperandNo(unsigned ArgNo) const {
172-
if (isDirectCall())
172+
if (!isCallbackCall())
173173
return ArgNo;
174174
// Add 1 for the callee encoding.
175175
return CI.ParameterEncoding[ArgNo + 1];
@@ -183,7 +183,7 @@ class AbstractCallSite {
183183
/// Return the operand of the underlying instruction associated with the
184184
/// function parameter number @p ArgNo or nullptr if there is none.
185185
Value *getCallArgOperand(unsigned ArgNo) const {
186-
if (isDirectCall())
186+
if (!isCallbackCall())
187187
return CB->getArgOperand(ArgNo);
188188
// Add 1 for the callee encoding.
189189
return CI.ParameterEncoding[ArgNo + 1] >= 0
@@ -210,7 +210,7 @@ class AbstractCallSite {
210210

211211
/// Return the pointer to function that is being called.
212212
Value *getCalledOperand() const {
213-
if (isDirectCall())
213+
if (!isCallbackCall())
214214
return CB->getCalledOperand();
215215
return CB->getArgOperand(getCallArgOperandNoForCallee());
216216
}

llvm/unittests/IR/AbstractCallSiteTest.cpp

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "llvm/AsmParser/Parser.h"
109
#include "llvm/IR/AbstractCallSite.h"
10+
#include "llvm/AsmParser/Parser.h"
11+
#include "llvm/IR/Argument.h"
1112
#include "llvm/IR/Function.h"
1213
#include "llvm/IR/Module.h"
1314
#include "llvm/Support/SourceMgr.h"
@@ -51,5 +52,96 @@ TEST(AbstractCallSite, CallbackCall) {
5152
EXPECT_TRUE(ACS);
5253
EXPECT_TRUE(ACS.isCallbackCall());
5354
EXPECT_TRUE(ACS.isCallee(CallbackUse));
55+
EXPECT_EQ(ACS.getCalleeUseForCallback(), *CallbackUse);
5456
EXPECT_EQ(ACS.getCalledFunction(), Callback);
57+
58+
// The callback metadata {CallbackNo, Arg0No, ..., isVarArg} = {1, -1, true}
59+
EXPECT_EQ(ACS.getCallArgOperandNoForCallee(), 1);
60+
// Though the callback metadata only specifies ONE unfixed argument No, the
61+
// callback callee is vararg, hence the third arg is also considered as
62+
// another arg for the callback.
63+
EXPECT_EQ(ACS.getNumArgOperands(), 2u);
64+
Argument *Param0 = Callback->getArg(0), *Param1 = Callback->getArg(1);
65+
ASSERT_TRUE(Param0 && Param1);
66+
EXPECT_EQ(ACS.getCallArgOperandNo(*Param0), -1);
67+
EXPECT_EQ(ACS.getCallArgOperandNo(*Param1), 2);
68+
}
69+
70+
TEST(AbstractCallSite, DirectCall) {
71+
LLVMContext C;
72+
73+
const char *IR = "declare void @bar(i32 %x, i32 %y)\n"
74+
"define void @foo() {\n"
75+
" call void @bar(i32 1, i32 2)\n"
76+
" ret void\n"
77+
"}\n";
78+
79+
std::unique_ptr<Module> M = parseIR(C, IR);
80+
ASSERT_TRUE(M);
81+
82+
Function *Callee = M->getFunction("bar");
83+
ASSERT_NE(Callee, nullptr);
84+
85+
const Use *DirectCallUse = Callee->getSingleUndroppableUse();
86+
ASSERT_NE(DirectCallUse, nullptr);
87+
88+
AbstractCallSite ACS(DirectCallUse);
89+
EXPECT_TRUE(ACS);
90+
EXPECT_TRUE(ACS.isDirectCall());
91+
EXPECT_TRUE(ACS.isCallee(DirectCallUse));
92+
EXPECT_EQ(ACS.getCalledFunction(), Callee);
93+
EXPECT_EQ(ACS.getNumArgOperands(), 2u);
94+
Argument *ArgX = Callee->getArg(0);
95+
ASSERT_NE(ArgX, nullptr);
96+
Value *CAO1 = ACS.getCallArgOperand(*ArgX);
97+
Value *CAO2 = ACS.getCallArgOperand(0);
98+
ASSERT_NE(CAO2, nullptr);
99+
// The two call arg operands should be the same object, since they are both
100+
// the first argument of the call.
101+
EXPECT_EQ(CAO2, CAO1);
102+
103+
ConstantInt *FirstArgInt = dyn_cast<ConstantInt>(CAO2);
104+
ASSERT_NE(FirstArgInt, nullptr);
105+
EXPECT_EQ(FirstArgInt->getZExtValue(), 1ull);
106+
107+
EXPECT_EQ(ACS.getCallArgOperandNo(*ArgX), 0);
108+
EXPECT_EQ(ACS.getCallArgOperandNo(0), 0);
109+
EXPECT_EQ(ACS.getCallArgOperandNo(1), 1);
110+
}
111+
112+
TEST(AbstractCallSite, IndirectCall) {
113+
LLVMContext C;
114+
115+
const char *IR = "define void @foo(ptr %0) {\n"
116+
" call void %0(i32 1, i32 2)\n"
117+
" ret void\n"
118+
"}\n";
119+
120+
std::unique_ptr<Module> M = parseIR(C, IR);
121+
ASSERT_TRUE(M);
122+
123+
Function *Fun = M->getFunction("foo");
124+
ASSERT_NE(Fun, nullptr);
125+
126+
Argument *ArgAsCallee = Fun->getArg(0);
127+
ASSERT_NE(ArgAsCallee, nullptr);
128+
129+
const Use *IndCallUse = ArgAsCallee->getSingleUndroppableUse();
130+
ASSERT_NE(IndCallUse, nullptr);
131+
132+
AbstractCallSite ACS(IndCallUse);
133+
EXPECT_TRUE(ACS);
134+
EXPECT_TRUE(ACS.isIndirectCall());
135+
EXPECT_TRUE(ACS.isCallee(IndCallUse));
136+
EXPECT_EQ(ACS.getCalledFunction(), nullptr);
137+
EXPECT_EQ(ACS.getCalledOperand(), ArgAsCallee);
138+
EXPECT_EQ(ACS.getNumArgOperands(), 2u);
139+
Value *CalledOperand = ACS.getCallArgOperand(0);
140+
ASSERT_NE(CalledOperand, nullptr);
141+
ConstantInt *FirstArgInt = dyn_cast<ConstantInt>(CalledOperand);
142+
ASSERT_NE(FirstArgInt, nullptr);
143+
EXPECT_EQ(FirstArgInt->getZExtValue(), 1ull);
144+
145+
EXPECT_EQ(ACS.getCallArgOperandNo(0), 0);
146+
EXPECT_EQ(ACS.getCallArgOperandNo(1), 1);
55147
}

0 commit comments

Comments
 (0)