Skip to content

Commit 9ad23c9

Browse files
committed
[CIR] Add support for indirect calls
1 parent 802d8d9 commit 9ad23c9

File tree

12 files changed

+224
-34
lines changed

12 files changed

+224
-34
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,14 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
225225
callee.getFunctionType().getReturnType(), operands);
226226
}
227227

228+
cir::CallOp createIndirectCallOp(mlir::Location loc,
229+
mlir::Value indirectTarget,
230+
cir::FuncType funcType,
231+
mlir::ValueRange operands) {
232+
return create<cir::CallOp>(loc, indirectTarget, funcType.getReturnType(),
233+
operands);
234+
}
235+
228236
//===--------------------------------------------------------------------===//
229237
// Cast/Conversion Operators
230238
//===--------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,13 +1796,8 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
17961796
DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
17971797
let extraClassDeclaration = [{
17981798
/// Get the argument operands to the called function.
1799-
mlir::OperandRange getArgOperands() {
1800-
return getArgs();
1801-
}
1802-
1803-
mlir::MutableOperandRange getArgOperandsMutable() {
1804-
return getArgsMutable();
1805-
}
1799+
mlir::OperandRange getArgOperands();
1800+
mlir::MutableOperandRange getArgOperandsMutable();
18061801

18071802
/// Return the callee of this operation
18081803
mlir::CallInterfaceCallable getCallableForCallee() {
@@ -1824,6 +1819,9 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
18241819
::mlir::Attribute removeArgAttrsAttr() { return {}; }
18251820
::mlir::Attribute removeResAttrsAttr() { return {}; }
18261821

1822+
bool isIndirect() { return !getCallee(); }
1823+
mlir::Value getIndirectCall();
1824+
18271825
void setArg(unsigned index, mlir::Value value) {
18281826
setOperand(index, value);
18291827
}
@@ -1837,16 +1835,24 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
18371835
// the upstreaming process moves on. The verifiers is also missing for now,
18381836
// will add in the future.
18391837

1840-
dag commonArgs = (ins FlatSymbolRefAttr:$callee,
1841-
Variadic<CIR_AnyType>:$args);
1838+
dag commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
1839+
Variadic<CIR_AnyType>:$args);
18421840
}
18431841

18441842
def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
18451843
let summary = "call a function";
18461844
let description = [{
1847-
The `cir.call` operation represents a direct call to a function that is
1848-
within the same symbol scope as the call. The callee is encoded as a symbol
1849-
reference attribute named `callee`.
1845+
The `cir.call` operation represents a function call. It could represent
1846+
either a direct call or an indirect call.
1847+
1848+
If the operation represents a direct call, the callee should be defined
1849+
within the same symbol scope as the call. The `callee` attribute contains a
1850+
symbol reference to the callee function. All operands of this operation are
1851+
arguments to the callee function.
1852+
1853+
If the operation represents an indirect call, the `callee` attribute is
1854+
empty. The first operand of this operation must be a pointer to the callee
1855+
function. All the rest operands are arguments to the callee function.
18501856

18511857
Example:
18521858

@@ -1859,13 +1865,23 @@ def CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
18591865
let arguments = commonArgs;
18601866

18611867
let builders = [OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
1862-
"mlir::Type":$resType,
1863-
"mlir::ValueRange":$operands), [{
1868+
"mlir::Type":$resType,
1869+
"mlir::ValueRange":$operands),
1870+
[{
18641871
$_state.addOperands(operands);
18651872
$_state.addAttribute("callee", callee);
18661873
if (resType && !isa<VoidType>(resType))
18671874
$_state.addTypes(resType);
1868-
}]>];
1875+
}]>,
1876+
OpBuilder<(ins "mlir::Value":$callee, "mlir::Type":$resType,
1877+
"mlir::ValueRange":$operands),
1878+
[{
1879+
$_state.addOperands(callee);
1880+
$_state.addOperands(operands);
1881+
if (resType && !isa<VoidType>(resType))
1882+
$_state.addTypes(resType);
1883+
}]>,
1884+
];
18691885
}
18701886

18711887
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ struct MissingFeatures {
9393
static bool opCallChainCall() { return false; }
9494
static bool opCallNoPrototypeFunc() { return false; }
9595
static bool opCallMustTail() { return false; }
96-
static bool opCallIndirect() { return false; }
9796
static bool opCallVirtual() { return false; }
9897
static bool opCallInAlloca() { return false; }
9998
static bool opCallAttrs() { return false; }

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,27 @@ CIRGenFunctionInfo::create(CanQualType resultType,
3939
return fi;
4040
}
4141

42+
cir::FuncType CIRGenTypes::getFunctionType(const CIRGenFunctionInfo &info) {
43+
[[maybe_unused]] bool inserted = functionsBeingProcessed.insert(&info).second;
44+
assert(inserted && "Recursively being processed?");
45+
46+
mlir::Type resultType = convertType(info.getReturnType());
47+
SmallVector<mlir::Type, 8> argTypes;
48+
argTypes.reserve(info.getNumRequiredArgs());
49+
50+
// Add in all of the required arguments.
51+
for (const CIRGenFunctionInfoArgInfo &argInfo : info.requiredArguments())
52+
argTypes.push_back(convertType(argInfo.type));
53+
54+
[[maybe_unused]] bool erased = functionsBeingProcessed.erase(&info);
55+
assert(erased && "Not in set?");
56+
57+
assert(!cir::MissingFeatures::opCallVariadic());
58+
return cir::FuncType::get(argTypes,
59+
(resultType ? resultType : builder.getVoidTy()),
60+
/*isVarArg=*/false);
61+
}
62+
4263
CIRGenCallee CIRGenCallee::prepareConcreteCallee(CIRGenFunction &cgf) const {
4364
assert(!cir::MissingFeatures::opCallVirtual());
4465
return *this;
@@ -75,6 +96,7 @@ CIRGenTypes::arrangeFreeFunctionCall(const CallArgList &args,
7596

7697
static cir::CIRCallOpInterface
7798
emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
99+
cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
78100
cir::FuncOp directFuncOp,
79101
const SmallVectorImpl<mlir::Value> &cirCallArgs) {
80102
CIRGenBuilderTy &builder = cgf.getBuilder();
@@ -83,7 +105,13 @@ emitCallLikeOp(CIRGenFunction &cgf, mlir::Location callLoc,
83105
assert(!cir::MissingFeatures::invokeOp());
84106

85107
assert(builder.getInsertionBlock() && "expected valid basic block");
86-
assert(!cir::MissingFeatures::opCallIndirect());
108+
109+
if (indirectFuncTy) {
110+
// TODO(cir): Set calling convention for indirect calls.
111+
assert(!cir::MissingFeatures::opCallCallConv());
112+
return builder.createIndirectCallOp(callLoc, indirectFuncVal,
113+
indirectFuncTy, cirCallArgs);
114+
}
87115

88116
return builder.createCallOp(callLoc, directFuncOp, cirCallArgs);
89117
}
@@ -95,6 +123,7 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
95123
cir::CIRCallOpInterface *callOp,
96124
mlir::Location loc) {
97125
QualType retTy = funcInfo.getReturnType();
126+
cir::FuncType cirFuncTy = getTypes().getFunctionType(funcInfo);
98127

99128
SmallVector<mlir::Value, 16> cirCallArgs(args.size());
100129

@@ -145,12 +174,26 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
145174

146175
assert(!cir::MissingFeatures::invokeOp());
147176

148-
auto directFuncOp = dyn_cast<cir::FuncOp>(calleePtr);
149-
assert(!cir::MissingFeatures::opCallIndirect());
177+
cir::FuncType indirectFuncTy;
178+
mlir::Value indirectFuncVal;
179+
cir::FuncOp directFuncOp;
180+
if (auto fnOp = dyn_cast<cir::FuncOp>(calleePtr))
181+
directFuncOp = fnOp;
182+
else {
183+
[[maybe_unused]] auto resultTypes = calleePtr->getResultTypes();
184+
[[maybe_unused]] auto funcPtrTy =
185+
mlir::dyn_cast<cir::PointerType>(resultTypes.front());
186+
assert(funcPtrTy && mlir::isa<cir::FuncType>(funcPtrTy.getPointee()) &&
187+
"expected pointer to function");
188+
189+
indirectFuncTy = cirFuncTy;
190+
indirectFuncVal = calleePtr->getResult(0);
191+
}
192+
150193
assert(!cir::MissingFeatures::opCallAttrs());
151194

152-
cir::CIRCallOpInterface theCall =
153-
emitCallLikeOp(*this, loc, directFuncOp, cirCallArgs);
195+
cir::CIRCallOpInterface theCall = emitCallLikeOp(
196+
*this, loc, indirectFuncTy, indirectFuncVal, directFuncOp, cirCallArgs);
154197

155198
if (callOp)
156199
*callOp = theCall;
@@ -250,7 +293,7 @@ void CIRGenFunction::emitCallArgs(
250293

251294
auto maybeEmitImplicitObjectSize = [&](size_t i, const Expr *arg,
252295
RValue emittedArg) {
253-
if (callee.hasFunctionDecl() || i >= callee.getNumParams())
296+
if (!callee.hasFunctionDecl() || i >= callee.getNumParams())
254297
return;
255298
auto *ps = callee.getParamDecl(i)->getAttr<PassObjectSizeAttr>();
256299
if (!ps)

clang/lib/CIR/CodeGen/CIRGenCall.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,20 @@ class CIRGenFunction;
2525

2626
/// Abstract information about a function or function prototype.
2727
class CIRGenCalleeInfo {
28+
const clang::FunctionProtoType *calleeProtoTy;
2829
clang::GlobalDecl calleeDecl;
2930

3031
public:
31-
explicit CIRGenCalleeInfo() : calleeDecl() {}
32+
explicit CIRGenCalleeInfo() : calleeProtoTy(nullptr), calleeDecl() {}
33+
CIRGenCalleeInfo(const clang::FunctionProtoType *calleeProtoTy,
34+
clang::GlobalDecl calleeDecl)
35+
: calleeProtoTy(calleeProtoTy), calleeDecl(calleeDecl) {}
3236
CIRGenCalleeInfo(clang::GlobalDecl calleeDecl) : calleeDecl(calleeDecl) {}
37+
38+
const clang::FunctionProtoType *getCalleeFunctionProtoType() const {
39+
return calleeProtoTy;
40+
}
41+
clang::GlobalDecl getCalleeDecl() const { return calleeDecl; }
3342
};
3443

3544
class CIRGenCallee {

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -915,8 +915,28 @@ CIRGenCallee CIRGenFunction::emitCallee(const clang::Expr *e) {
915915
return emitDirectCallee(cgm, funcDecl);
916916
}
917917

918-
cgm.errorNYI(e->getSourceRange(), "Unsupported callee kind");
919-
return {};
918+
assert(!cir::MissingFeatures::opCallPseudoDtor());
919+
920+
// Otherwise, we have an indirect reference.
921+
mlir::Value calleePtr;
922+
QualType functionType;
923+
if (const auto *ptrType = e->getType()->getAs<clang::PointerType>()) {
924+
calleePtr = emitScalarExpr(e);
925+
functionType = ptrType->getPointeeType();
926+
} else {
927+
functionType = e->getType();
928+
calleePtr = emitLValue(e).getPointer();
929+
}
930+
assert(functionType->isFunctionType());
931+
932+
GlobalDecl gd;
933+
if (const auto *vd =
934+
dyn_cast_or_null<VarDecl>(e->getReferencedDeclOfCallee()))
935+
gd = GlobalDecl(vd);
936+
937+
CIRGenCalleeInfo calleeInfo(functionType->getAs<FunctionProtoType>(), gd);
938+
CIRGenCallee callee(calleeInfo, calleePtr.getDefiningOp());
939+
return callee;
920940
}
921941

922942
RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *e,

clang/lib/CIR/CodeGen/CIRGenFunctionInfo.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define LLVM_CLANG_CIR_CIRGENFUNCTIONINFO_H
1717

1818
#include "clang/AST/CanonicalType.h"
19+
#include "clang/CIR/MissingFeatures.h"
1920
#include "llvm/ADT/FoldingSet.h"
2021
#include "llvm/Support/TrailingObjects.h"
2122

@@ -67,6 +68,13 @@ class CIRGenFunctionInfo final
6768
return llvm::ArrayRef<ArgInfo>(arg_begin(), numArgs);
6869
}
6970

71+
llvm::MutableArrayRef<ArgInfo> requiredArguments() {
72+
return llvm::MutableArrayRef<ArgInfo>(arg_begin(), getNumRequiredArgs());
73+
}
74+
llvm::ArrayRef<ArgInfo> requiredArguments() const {
75+
return llvm::ArrayRef<ArgInfo>(arg_begin(), getNumRequiredArgs());
76+
}
77+
7078
const_arg_iterator arg_begin() const { return getArgsBuffer() + 1; }
7179
const_arg_iterator arg_end() const { return getArgsBuffer() + 1 + numArgs; }
7280
arg_iterator arg_begin() { return getArgsBuffer() + 1; }
@@ -75,6 +83,11 @@ class CIRGenFunctionInfo final
7583
unsigned arg_size() const { return numArgs; }
7684

7785
CanQualType getReturnType() const { return getArgsBuffer()[0].type; }
86+
87+
unsigned getNumRequiredArgs() const {
88+
assert(!cir::MissingFeatures::opCallVariadic());
89+
return arg_size();
90+
}
7891
};
7992

8093
} // namespace clang::CIRGen

clang/lib/CIR/CodeGen/CIRGenTypes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ class CIRGenTypes {
117117
// TODO: convert this comment to account for MLIR's equivalence
118118
mlir::Type convertTypeForMem(clang::QualType, bool forBitField = false);
119119

120+
/// Get the CIR function type for \arg Info.
121+
cir::FuncType getFunctionType(const CIRGenFunctionInfo &info);
122+
120123
/// Return whether a type can be zero-initialized (in the C++ sense) with an
121124
/// LLVM zeroinitializer.
122125
bool isZeroInitializable(clang::QualType ty);

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -464,15 +464,35 @@ OpFoldResult cir::CastOp::fold(FoldAdaptor adaptor) {
464464
// CallOp
465465
//===----------------------------------------------------------------------===//
466466

467+
mlir::OperandRange cir::CallOp::getArgOperands() {
468+
if (isIndirect())
469+
return getArgs().drop_front(1);
470+
return getArgs();
471+
}
472+
473+
mlir::MutableOperandRange cir::CallOp::getArgOperandsMutable() {
474+
mlir::MutableOperandRange args = getArgsMutable();
475+
if (isIndirect())
476+
return args.slice(1, args.size() - 1);
477+
return args;
478+
}
479+
480+
mlir::Value cir::CallOp::getIndirectCall() {
481+
assert(isIndirect());
482+
return getOperand(0);
483+
}
484+
467485
/// Return the operand at index 'i'.
468486
Value cir::CallOp::getArgOperand(unsigned i) {
469-
assert(!cir::MissingFeatures::opCallIndirect());
487+
if (isIndirect())
488+
++i;
470489
return getOperand(i);
471490
}
472491

473492
/// Return the number of operands.
474493
unsigned cir::CallOp::getNumArgOperands() {
475-
assert(!cir::MissingFeatures::opCallIndirect());
494+
if (isIndirect())
495+
return this->getOperation()->getNumOperands() - 1;
476496
return this->getOperation()->getNumOperands();
477497
}
478498

@@ -483,9 +503,15 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
483503
mlir::FlatSymbolRefAttr calleeAttr;
484504
llvm::ArrayRef<mlir::Type> allResultTypes;
485505

506+
// If we cannot parse a string callee, it means this is an indirect call.
486507
if (!parser.parseOptionalAttribute(calleeAttr, "callee", result.attributes)
487-
.has_value())
488-
return mlir::failure();
508+
.has_value()) {
509+
OpAsmParser::UnresolvedOperand indirectVal;
510+
// Do not resolve right now, since we need to figure out the type
511+
if (parser.parseOperand(indirectVal).failed())
512+
return failure();
513+
ops.push_back(indirectVal);
514+
}
489515

490516
if (parser.parseLParen())
491517
return mlir::failure();
@@ -517,13 +543,21 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
517543

518544
static void printCallCommon(mlir::Operation *op,
519545
mlir::FlatSymbolRefAttr calleeSym,
546+
mlir::Value indirectCallee,
520547
mlir::OpAsmPrinter &printer) {
521548
printer << ' ';
522549

523550
auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
524551
auto ops = callLikeOp.getArgOperands();
525552

526-
printer.printAttributeWithoutType(calleeSym);
553+
if (calleeSym) {
554+
// Direct calls
555+
printer.printAttributeWithoutType(calleeSym);
556+
} else {
557+
// Indirect calls
558+
assert(indirectCallee);
559+
printer << indirectCallee;
560+
}
527561
printer << "(" << ops << ")";
528562

529563
printer.printOptionalAttrDict(op->getAttrs(), {"callee"});
@@ -539,15 +573,16 @@ mlir::ParseResult cir::CallOp::parse(mlir::OpAsmParser &parser,
539573
}
540574

541575
void cir::CallOp::print(mlir::OpAsmPrinter &p) {
542-
printCallCommon(*this, getCalleeAttr(), p);
576+
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
577+
printCallCommon(*this, getCalleeAttr(), indirectCallee, p);
543578
}
544579

545580
static LogicalResult
546581
verifyCallCommInSymbolUses(mlir::Operation *op,
547582
SymbolTableCollection &symbolTable) {
548583
auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>("callee");
549584
if (!fnAttr)
550-
return mlir::failure();
585+
return mlir::success();
551586

552587
auto fn = symbolTable.lookupNearestSymbolFrom<cir::FuncOp>(op, fnAttr);
553588
if (!fn)

0 commit comments

Comments
 (0)