Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions llvm/include/llvm/IR/AbstractCallSite.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class AbstractCallSite {

/// Return true if @p U is the use that defines the callee of this ACS.
bool isCallee(const Use *U) const {
if (isDirectCall())
if (!isCallbackCall())
return CB->isCallee(U);

assert(!CI.ParameterEncoding.empty() &&
Expand All @@ -154,7 +154,7 @@ class AbstractCallSite {

/// Return the number of parameters of the callee.
unsigned getNumArgOperands() const {
if (isDirectCall())
if (!isCallbackCall())
return CB->arg_size();
// Subtract 1 for the callee encoding.
return CI.ParameterEncoding.size() - 1;
Expand All @@ -169,7 +169,7 @@ class AbstractCallSite {
/// Return the operand index of the underlying instruction associated with
/// the function parameter number @p ArgNo or -1 if there is none.
int getCallArgOperandNo(unsigned ArgNo) const {
if (isDirectCall())
if (!isCallbackCall())
return ArgNo;
// Add 1 for the callee encoding.
return CI.ParameterEncoding[ArgNo + 1];
Expand All @@ -183,7 +183,7 @@ class AbstractCallSite {
/// Return the operand of the underlying instruction associated with the
/// function parameter number @p ArgNo or nullptr if there is none.
Value *getCallArgOperand(unsigned ArgNo) const {
if (isDirectCall())
if (!isCallbackCall())
return CB->getArgOperand(ArgNo);
// Add 1 for the callee encoding.
return CI.ParameterEncoding[ArgNo + 1] >= 0
Expand All @@ -210,7 +210,7 @@ class AbstractCallSite {

/// Return the pointer to function that is being called.
Value *getCalledOperand() const {
if (isDirectCall())
if (!isCallbackCall())
return CB->getCalledOperand();
return CB->getArgOperand(getCallArgOperandNoForCallee());
}
Expand Down
94 changes: 93 additions & 1 deletion llvm/unittests/IR/AbstractCallSiteTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/AbstractCallSite.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SourceMgr.h"
Expand Down Expand Up @@ -51,5 +52,96 @@ TEST(AbstractCallSite, CallbackCall) {
EXPECT_TRUE(ACS);
EXPECT_TRUE(ACS.isCallbackCall());
EXPECT_TRUE(ACS.isCallee(CallbackUse));
EXPECT_EQ(ACS.getCalleeUseForCallback(), *CallbackUse);
EXPECT_EQ(ACS.getCalledFunction(), Callback);

// The callback metadata {CallbackNo, Arg0No, ..., isVarArg} = {1, -1, true}
EXPECT_EQ(ACS.getCallArgOperandNoForCallee(), 1);
// Though the callback metadata only specifies ONE unfixed argument No, the
// callback callee is vararg, hence the third arg is also considered as
// another arg for the callback.
EXPECT_EQ(ACS.getNumArgOperands(), 2u);
Argument *Param0 = Callback->getArg(0), *Param1 = Callback->getArg(1);
ASSERT_TRUE(Param0 && Param1);
EXPECT_EQ(ACS.getCallArgOperandNo(*Param0), -1);
EXPECT_EQ(ACS.getCallArgOperandNo(*Param1), 2);
}

TEST(AbstractCallSite, DirectCall) {
LLVMContext C;

const char *IR = "declare void @bar(i32 %x, i32 %y)\n"
"define void @foo() {\n"
" call void @bar(i32 1, i32 2)\n"
" ret void\n"
"}\n";

std::unique_ptr<Module> M = parseIR(C, IR);
ASSERT_TRUE(M);

Function *Callee = M->getFunction("bar");
ASSERT_NE(Callee, nullptr);

const Use *DirectCallUse = Callee->getSingleUndroppableUse();
ASSERT_NE(DirectCallUse, nullptr);

AbstractCallSite ACS(DirectCallUse);
EXPECT_TRUE(ACS);
EXPECT_TRUE(ACS.isDirectCall());
EXPECT_TRUE(ACS.isCallee(DirectCallUse));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also check the other methods like getNumArgOperands and getCallArgOperand?

EXPECT_EQ(ACS.getCalledFunction(), Callee);
EXPECT_EQ(ACS.getNumArgOperands(), 2u);
Argument *ArgX = Callee->getArg(0);
ASSERT_NE(ArgX, nullptr);
Value *CAO1 = ACS.getCallArgOperand(*ArgX);
Value *CAO2 = ACS.getCallArgOperand(0);
ASSERT_NE(CAO2, nullptr);
// The two call arg operands should be the same object, since they are both
// the first argument of the call.
EXPECT_EQ(CAO2, CAO1);

ConstantInt *FirstArgInt = dyn_cast<ConstantInt>(CAO2);
ASSERT_NE(FirstArgInt, nullptr);
EXPECT_EQ(FirstArgInt->getZExtValue(), 1ull);

EXPECT_EQ(ACS.getCallArgOperandNo(*ArgX), 0);
EXPECT_EQ(ACS.getCallArgOperandNo(0), 0);
EXPECT_EQ(ACS.getCallArgOperandNo(1), 1);
}

TEST(AbstractCallSite, IndirectCall) {
LLVMContext C;

const char *IR = "define void @foo(ptr %0) {\n"
" call void %0(i32 1, i32 2)\n"
" ret void\n"
"}\n";

std::unique_ptr<Module> M = parseIR(C, IR);
ASSERT_TRUE(M);

Function *Fun = M->getFunction("foo");
ASSERT_NE(Fun, nullptr);

Argument *ArgAsCallee = Fun->getArg(0);
ASSERT_NE(ArgAsCallee, nullptr);

const Use *IndCallUse = ArgAsCallee->getSingleUndroppableUse();
ASSERT_NE(IndCallUse, nullptr);

AbstractCallSite ACS(IndCallUse);
EXPECT_TRUE(ACS);
EXPECT_TRUE(ACS.isIndirectCall());
EXPECT_TRUE(ACS.isCallee(IndCallUse));
EXPECT_EQ(ACS.getCalledFunction(), nullptr);
EXPECT_EQ(ACS.getCalledOperand(), ArgAsCallee);
EXPECT_EQ(ACS.getNumArgOperands(), 2u);
Value *CalledOperand = ACS.getCallArgOperand(0);
ASSERT_NE(CalledOperand, nullptr);
ConstantInt *FirstArgInt = dyn_cast<ConstantInt>(CalledOperand);
ASSERT_NE(FirstArgInt, nullptr);
EXPECT_EQ(FirstArgInt->getZExtValue(), 1ull);

EXPECT_EQ(ACS.getCallArgOperandNo(0), 0);
EXPECT_EQ(ACS.getCallArgOperandNo(1), 1);
}