Skip to content

Commit 32779cd

Browse files
authored
[CIR] Add proper handling for no prototype function calls (#150553)
This adds standard-conforming handling for calls to functions that were declared in C source in the no prototype form.
1 parent 0a4c652 commit 32779cd

File tree

11 files changed

+264
-31
lines changed

11 files changed

+264
-31
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1946,6 +1946,10 @@ def CIR_FuncOp : CIR_Op<"func", [
19461946
The function linkage information is specified by `linkage`, as defined by
19471947
`GlobalLinkageKind` attribute.
19481948

1949+
The `no_proto` keyword is used to identify functions that were declared
1950+
without a prototype and, consequently, may contain calls with invalid
1951+
arguments and undefined behavior.
1952+
19491953
Example:
19501954

19511955
```mlir
@@ -1964,6 +1968,7 @@ def CIR_FuncOp : CIR_Op<"func", [
19641968
let arguments = (ins SymbolNameAttr:$sym_name,
19651969
CIR_VisibilityAttr:$global_visibility,
19661970
TypeAttrOf<CIR_FuncType>:$function_type,
1971+
UnitAttr:$no_proto,
19671972
UnitAttr:$dso_local,
19681973
DefaultValuedAttr<CIR_GlobalLinkageKind,
19691974
"cir::GlobalLinkageKind::ExternalLinkage">:$linkage,
@@ -2005,13 +2010,6 @@ def CIR_FuncOp : CIR_Op<"func", [
20052010
return getFunctionType().getReturnTypes();
20062011
}
20072012

2008-
// TODO(cir): this should be an operand attribute, but for now we just hard-
2009-
// wire this as a function. Will later add a $no_proto argument to this op.
2010-
bool getNoProto() {
2011-
assert(!cir::MissingFeatures::opFuncNoProto());
2012-
return false;
2013-
}
2014-
20152013
//===------------------------------------------------------------------===//
20162014
// SymbolOpInterface Methods
20172015
//===------------------------------------------------------------------===//

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,16 @@ struct MissingFeatures {
7373
// FuncOp handling
7474
static bool opFuncOpenCLKernelMetadata() { return false; }
7575
static bool opFuncAstDeclAttr() { return false; }
76+
static bool opFuncAttributesForDefinition() { return false; }
7677
static bool opFuncCallingConv() { return false; }
77-
static bool opFuncExtraAttrs() { return false; }
78-
static bool opFuncNoProto() { return false; }
7978
static bool opFuncCPUAndFeaturesAttributes() { return false; }
80-
static bool opFuncSection() { return false; }
81-
static bool opFuncMultipleReturnVals() { return false; }
82-
static bool opFuncAttributesForDefinition() { return false; }
79+
static bool opFuncExceptions() { return false; }
80+
static bool opFuncExtraAttrs() { return false; }
8381
static bool opFuncMaybeHandleStaticInExternC() { return false; }
82+
static bool opFuncMultipleReturnVals() { return false; }
83+
static bool opFuncOperandBundles() { return false; }
84+
static bool opFuncParameterAttributes() { return false; }
85+
static bool opFuncSection() { return false; }
8486
static bool setLLVMFunctionFEnvAttributes() { return false; }
8587
static bool setFunctionAttributes() { return false; }
8688

@@ -96,7 +98,6 @@ struct MissingFeatures {
9698
static bool opCallReturn() { return false; }
9799
static bool opCallArgEvaluationOrder() { return false; }
98100
static bool opCallCallConv() { return false; }
99-
static bool opCallNoPrototypeFunc() { return false; }
100101
static bool opCallMustTail() { return false; }
101102
static bool opCallVirtual() { return false; }
102103
static bool opCallInAlloca() { return false; }
@@ -109,6 +110,7 @@ struct MissingFeatures {
109110
static bool opCallCIRGenFuncInfoExtParamInfo() { return false; }
110111
static bool opCallLandingPad() { return false; }
111112
static bool opCallContinueBlock() { return false; }
113+
static bool opCallChain() { return false; }
112114

113115
// CXXNewExpr
114116
static bool exprNewNullCheck() { return false; }

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,14 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &funcInfo,
582582
cir::FuncOp directFuncOp;
583583
if (auto fnOp = dyn_cast<cir::FuncOp>(calleePtr)) {
584584
directFuncOp = fnOp;
585+
} else if (auto getGlobalOp = mlir::dyn_cast<cir::GetGlobalOp>(calleePtr)) {
586+
// FIXME(cir): This peephole optimization avoids indirect calls for
587+
// builtins. This should be fixed in the builtin declaration instead by
588+
// not emitting an unnecessary get_global in the first place.
589+
// However, this is also used for no-prototype functions.
590+
mlir::Operation *globalOp = cgm.getGlobalValue(getGlobalOp.getName());
591+
assert(globalOp && "undefined global function");
592+
directFuncOp = mlir::cast<cir::FuncOp>(globalOp);
585593
} else {
586594
[[maybe_unused]] mlir::ValueTypeRange<mlir::ResultRange> resultTypes =
587595
calleePtr->getResultTypes();

clang/lib/CIR/CodeGen/CIRGenCall.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ class CIRGenCallee {
116116
assert(isOrdinary());
117117
return reinterpret_cast<mlir::Operation *>(kindOrFunctionPtr);
118118
}
119+
120+
void setFunctionPointer(mlir::Operation *functionPtr) {
121+
assert(isOrdinary());
122+
kindOrFunctionPtr = SpecialKind(reinterpret_cast<uintptr_t>(functionPtr));
123+
}
119124
};
120125

121126
/// Type for representing both the decl and type of parameters to a function.

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,7 +1280,7 @@ RValue CIRGenFunction::getUndefRValue(QualType ty) {
12801280
}
12811281

12821282
RValue CIRGenFunction::emitCall(clang::QualType calleeTy,
1283-
const CIRGenCallee &callee,
1283+
const CIRGenCallee &origCallee,
12841284
const clang::CallExpr *e,
12851285
ReturnValueSlot returnValue) {
12861286
// Get the actual function type. The callee type will always be a pointer to
@@ -1291,6 +1291,8 @@ RValue CIRGenFunction::emitCall(clang::QualType calleeTy,
12911291
calleeTy = getContext().getCanonicalType(calleeTy);
12921292
auto pointeeTy = cast<PointerType>(calleeTy)->getPointeeType();
12931293

1294+
CIRGenCallee callee = origCallee;
1295+
12941296
if (getLangOpts().CPlusPlus)
12951297
assert(!cir::MissingFeatures::sanitizers());
12961298

@@ -1307,7 +1309,44 @@ RValue CIRGenFunction::emitCall(clang::QualType calleeTy,
13071309
const CIRGenFunctionInfo &funcInfo =
13081310
cgm.getTypes().arrangeFreeFunctionCall(args, fnType);
13091311

1310-
assert(!cir::MissingFeatures::opCallNoPrototypeFunc());
1312+
// C99 6.5.2.2p6:
1313+
// If the expression that denotes the called function has a type that does
1314+
// not include a prototype, [the default argument promotions are performed].
1315+
// If the number of arguments does not equal the number of parameters, the
1316+
// behavior is undefined. If the function is defined with a type that
1317+
// includes a prototype, and either the prototype ends with an ellipsis (,
1318+
// ...) or the types of the arguments after promotion are not compatible
1319+
// with the types of the parameters, the behavior is undefined. If the
1320+
// function is defined with a type that does not include a prototype, and
1321+
// the types of the arguments after promotion are not compatible with those
1322+
// of the parameters after promotion, the behavior is undefined [except in
1323+
// some trivial cases].
1324+
// That is, in the general case, we should assume that a call through an
1325+
// unprototyped function type works like a *non-variadic* call. The way we
1326+
// make this work is to cast to the exact type of the promoted arguments.
1327+
if (isa<FunctionNoProtoType>(fnType)) {
1328+
assert(!cir::MissingFeatures::opCallChain());
1329+
assert(!cir::MissingFeatures::addressSpace());
1330+
cir::FuncType calleeTy = getTypes().getFunctionType(funcInfo);
1331+
// get non-variadic function type
1332+
calleeTy = cir::FuncType::get(calleeTy.getInputs(),
1333+
calleeTy.getReturnType(), false);
1334+
auto calleePtrTy = cir::PointerType::get(calleeTy);
1335+
1336+
mlir::Operation *fn = callee.getFunctionPointer();
1337+
mlir::Value addr;
1338+
if (auto funcOp = mlir::dyn_cast<cir::FuncOp>(fn)) {
1339+
addr = builder.create<cir::GetGlobalOp>(
1340+
getLoc(e->getSourceRange()),
1341+
cir::PointerType::get(funcOp.getFunctionType()), funcOp.getSymName());
1342+
} else {
1343+
addr = fn->getResult(0);
1344+
}
1345+
1346+
fn = builder.createBitcast(addr, calleePtrTy).getDefiningOp();
1347+
callee.setFunctionPointer(fn);
1348+
}
1349+
13111350
assert(!cir::MissingFeatures::opCallFnInfoOpts());
13121351
assert(!cir::MissingFeatures::hip());
13131352
assert(!cir::MissingFeatures::opCallMustTail());

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,60 @@ cir::GlobalLinkageKind CIRGenModule::getCIRLinkageForDeclarator(
11031103
return cir::GlobalLinkageKind::ExternalLinkage;
11041104
}
11051105

1106+
/// This function is called when we implement a function with no prototype, e.g.
1107+
/// "int foo() {}". If there are existing call uses of the old function in the
1108+
/// module, this adjusts them to call the new function directly.
1109+
///
1110+
/// This is not just a cleanup: the always_inline pass requires direct calls to
1111+
/// functions to be able to inline them. If there is a bitcast in the way, it
1112+
/// won't inline them. Instcombine normally deletes these calls, but it isn't
1113+
/// run at -O0.
1114+
void CIRGenModule::replaceUsesOfNonProtoTypeWithRealFunction(
1115+
mlir::Operation *old, cir::FuncOp newFn) {
1116+
// If we're redefining a global as a function, don't transform it.
1117+
auto oldFn = mlir::dyn_cast<cir::FuncOp>(old);
1118+
if (!oldFn)
1119+
return;
1120+
1121+
// TODO(cir): this RAUW ignores the features below.
1122+
assert(!cir::MissingFeatures::opFuncExceptions());
1123+
assert(!cir::MissingFeatures::opFuncParameterAttributes());
1124+
assert(!cir::MissingFeatures::opFuncOperandBundles());
1125+
if (oldFn->getAttrs().size() <= 1)
1126+
errorNYI(old->getLoc(),
1127+
"replaceUsesOfNonProtoTypeWithRealFunction: Attribute forwarding");
1128+
1129+
// Mark new function as originated from a no-proto declaration.
1130+
newFn.setNoProto(oldFn.getNoProto());
1131+
1132+
// Iterate through all calls of the no-proto function.
1133+
std::optional<mlir::SymbolTable::UseRange> symUses =
1134+
oldFn.getSymbolUses(oldFn->getParentOp());
1135+
for (const mlir::SymbolTable::SymbolUse &use : symUses.value()) {
1136+
mlir::OpBuilder::InsertionGuard guard(builder);
1137+
1138+
if (auto noProtoCallOp = mlir::dyn_cast<cir::CallOp>(use.getUser())) {
1139+
builder.setInsertionPoint(noProtoCallOp);
1140+
1141+
// Patch call type with the real function type.
1142+
cir::CallOp realCallOp = builder.createCallOp(
1143+
noProtoCallOp.getLoc(), newFn, noProtoCallOp.getOperands());
1144+
1145+
// Replace old no proto call with fixed call.
1146+
noProtoCallOp.replaceAllUsesWith(realCallOp);
1147+
noProtoCallOp.erase();
1148+
} else if (auto getGlobalOp =
1149+
mlir::dyn_cast<cir::GetGlobalOp>(use.getUser())) {
1150+
// Replace type
1151+
getGlobalOp.getAddr().setType(
1152+
cir::PointerType::get(newFn.getFunctionType()));
1153+
} else {
1154+
errorNYI(use.getUser()->getLoc(),
1155+
"replaceUsesOfNonProtoTypeWithRealFunction: unexpected use");
1156+
}
1157+
}
1158+
}
1159+
11061160
cir::GlobalLinkageKind
11071161
CIRGenModule::getCIRLinkageVarDefinition(const VarDecl *vd, bool isConstant) {
11081162
assert(!isConstant && "constant variables NYI");
@@ -1701,8 +1755,7 @@ cir::FuncOp CIRGenModule::getOrCreateCIRFunction(
17011755
// Lookup the entry, lazily creating it if necessary.
17021756
mlir::Operation *entry = getGlobalValue(mangledName);
17031757
if (entry) {
1704-
if (!isa<cir::FuncOp>(entry))
1705-
errorNYI(d->getSourceRange(), "getOrCreateCIRFunction: non-FuncOp");
1758+
assert(mlir::isa<cir::FuncOp>(entry));
17061759

17071760
assert(!cir::MissingFeatures::weakRefReference());
17081761

@@ -1738,6 +1791,30 @@ cir::FuncOp CIRGenModule::getOrCreateCIRFunction(
17381791
invalidLoc ? theModule->getLoc() : getLoc(funcDecl->getSourceRange()),
17391792
mangledName, mlir::cast<cir::FuncType>(funcType), funcDecl);
17401793

1794+
// If we already created a function with the same mangled name (but different
1795+
// type) before, take its name and add it to the list of functions to be
1796+
// replaced with F at the end of CodeGen.
1797+
//
1798+
// This happens if there is a prototype for a function (e.g. "int f()") and
1799+
// then a definition of a different type (e.g. "int f(int x)").
1800+
if (entry) {
1801+
1802+
// Fetch a generic symbol-defining operation and its uses.
1803+
auto symbolOp = mlir::cast<mlir::SymbolOpInterface>(entry);
1804+
1805+
// This might be an implementation of a function without a prototype, in
1806+
// which case, try to do special replacement of calls which match the new
1807+
// prototype. The really key thing here is that we also potentially drop
1808+
// arguments from the call site so as to make a direct call, which makes the
1809+
// inliner happier and suppresses a number of optimizer warnings (!) about
1810+
// dropping arguments.
1811+
if (symbolOp.getSymbolUses(symbolOp->getParentOp()))
1812+
replaceUsesOfNonProtoTypeWithRealFunction(entry, funcOp);
1813+
1814+
// Obliterate no-proto declaration.
1815+
entry->erase();
1816+
}
1817+
17411818
if (d)
17421819
setFunctionAttributes(gd, funcOp, /*isIncompleteFunction=*/false, isThunk);
17431820

@@ -1814,7 +1891,9 @@ CIRGenModule::createCIRFunction(mlir::Location loc, StringRef name,
18141891
func = builder.create<cir::FuncOp>(loc, name, funcType);
18151892

18161893
assert(!cir::MissingFeatures::opFuncAstDeclAttr());
1817-
assert(!cir::MissingFeatures::opFuncNoProto());
1894+
1895+
if (funcDecl && !funcDecl->hasPrototype())
1896+
func.setNoProto(true);
18181897

18191898
assert(func.isDeclaration() && "expected empty body");
18201899

clang/lib/CIR/CodeGen/CIRGenModule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,9 @@ class CIRGenModule : public CIRGenTypeCache {
313313

314314
static void setInitializer(cir::GlobalOp &op, mlir::Attribute value);
315315

316+
void replaceUsesOfNonProtoTypeWithRealFunction(mlir::Operation *old,
317+
cir::FuncOp newFn);
318+
316319
cir::FuncOp
317320
getOrCreateCIRFunction(llvm::StringRef mangledName, mlir::Type funcType,
318321
clang::GlobalDecl gd, bool forVTable,

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,10 +1470,14 @@ ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
14701470
llvm::SMLoc loc = parser.getCurrentLocation();
14711471
mlir::Builder &builder = parser.getBuilder();
14721472

1473+
mlir::StringAttr noProtoNameAttr = getNoProtoAttrName(state.name);
14731474
mlir::StringAttr visNameAttr = getSymVisibilityAttrName(state.name);
14741475
mlir::StringAttr visibilityNameAttr = getGlobalVisibilityAttrName(state.name);
14751476
mlir::StringAttr dsoLocalNameAttr = getDsoLocalAttrName(state.name);
14761477

1478+
if (parser.parseOptionalKeyword(noProtoNameAttr).succeeded())
1479+
state.addAttribute(noProtoNameAttr, parser.getBuilder().getUnitAttr());
1480+
14771481
// Default to external linkage if no keyword is provided.
14781482
state.addAttribute(getLinkageAttrNameString(),
14791483
GlobalLinkageKindAttr::get(
@@ -1578,6 +1582,9 @@ mlir::Region *cir::FuncOp::getCallableRegion() {
15781582
}
15791583

15801584
void cir::FuncOp::print(OpAsmPrinter &p) {
1585+
if (getNoProto())
1586+
p << " no_proto";
1587+
15811588
if (getComdat())
15821589
p << " comdat";
15831590

clang/test/CIR/CodeGen/call.c

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ struct S {
1111
};
1212

1313
void f1(struct S);
14-
void f2() {
14+
void f2(void) {
1515
struct S s;
1616
f1(s);
1717
}
@@ -28,8 +28,8 @@ void f2() {
2828
// OGCG: %[[S:.+]] = load i64, ptr %{{.+}}, align 4
2929
// OGCG-NEXT: call void @f1(i64 %[[S]])
3030

31-
struct S f3();
32-
void f4() {
31+
struct S f3(void);
32+
void f4(void) {
3333
struct S s = f3();
3434
}
3535

@@ -38,21 +38,21 @@ void f4() {
3838
// CIR-NEXT: cir.store align(4) %[[S]], %{{.+}} : !rec_S, !cir.ptr<!rec_S>
3939

4040
// LLVM-LABEL: define{{.*}} void @f4() {
41-
// LLVM: %[[S:.+]] = call %struct.S (...) @f3()
41+
// LLVM: %[[S:.+]] = call %struct.S @f3()
4242
// LLVM-NEXT: store %struct.S %[[S]], ptr %{{.+}}, align 4
4343

4444
// OGCG-LABEL: define{{.*}} void @f4() #0 {
45-
// OGCG: %[[S:.+]] = call i64 (...) @f3()
45+
// OGCG: %[[S:.+]] = call i64 @f3()
4646
// OGCG-NEXT: store i64 %[[S]], ptr %{{.+}}, align 4
4747

4848
struct Big {
4949
int data[10];
5050
};
5151

5252
void f5(struct Big);
53-
struct Big f6();
53+
struct Big f6(void);
5454

55-
void f7() {
55+
void f7(void) {
5656
struct Big b;
5757
f5(b);
5858
}
@@ -69,7 +69,7 @@ void f7() {
6969
// OGCG: %[[B:.+]] = alloca %struct.Big, align 8
7070
// OGCG-NEXT: call void @f5(ptr noundef byval(%struct.Big) align 8 %[[B]])
7171

72-
void f8() {
72+
void f8(void) {
7373
struct Big b = f6();
7474
}
7575

@@ -78,14 +78,14 @@ void f8() {
7878
// CIR: cir.store align(4) %[[B]], %{{.+}} : !rec_Big, !cir.ptr<!rec_Big>
7979

8080
// LLVM-LABEL: define{{.*}} void @f8() {
81-
// LLVM: %[[B:.+]] = call %struct.Big (...) @f6()
81+
// LLVM: %[[B:.+]] = call %struct.Big @f6()
8282
// LLVM-NEXT: store %struct.Big %[[B]], ptr %{{.+}}, align 4
8383

8484
// OGCG-LABEL: define{{.*}} void @f8() #0 {
8585
// OGCG: %[[B:.+]] = alloca %struct.Big, align 4
86-
// OGCG-NEXT: call void (ptr, ...) @f6(ptr dead_on_unwind writable sret(%struct.Big) align 4 %[[B]])
86+
// OGCG-NEXT: call void @f6(ptr dead_on_unwind writable sret(%struct.Big) align 4 %[[B]])
8787

88-
void f9() {
88+
void f9(void) {
8989
f1(f3());
9090
}
9191

@@ -98,14 +98,14 @@ void f9() {
9898

9999
// LLVM-LABEL: define{{.*}} void @f9() {
100100
// LLVM: %[[SLOT:.+]] = alloca %struct.S, i64 1, align 4
101-
// LLVM-NEXT: %[[RET:.+]] = call %struct.S (...) @f3()
101+
// LLVM-NEXT: %[[RET:.+]] = call %struct.S @f3()
102102
// LLVM-NEXT: store %struct.S %[[RET]], ptr %[[SLOT]], align 4
103103
// LLVM-NEXT: %[[ARG:.+]] = load %struct.S, ptr %[[SLOT]], align 4
104104
// LLVM-NEXT: call void @f1(%struct.S %[[ARG]])
105105

106106
// OGCG-LABEL: define{{.*}} void @f9() #0 {
107107
// OGCG: %[[SLOT:.+]] = alloca %struct.S, align 4
108-
// OGCG-NEXT: %[[RET:.+]] = call i64 (...) @f3()
108+
// OGCG-NEXT: %[[RET:.+]] = call i64 @f3()
109109
// OGCG-NEXT: store i64 %[[RET]], ptr %[[SLOT]], align 4
110110
// OGCG-NEXT: %[[ARG:.+]] = load i64, ptr %[[SLOT]], align 4
111111
// OGCG-NEXT: call void @f1(i64 %[[ARG]])

0 commit comments

Comments
 (0)