Skip to content

Commit 3e396ab

Browse files
committed
[CIR] Function type return type improvements
When a C or C++ function has a return type of `void`, the function type is now represented in MLIR as having no return type rather than having a return type of `!cir.void`. This avoids breaking MLIR invariants that require the number of return types and the number of return values to match. Change the assembly format for `cir::FuncType` from having a leading return type to having a trailing return type. In other words, change ``` !cir.func<!returnType (!argTypes)> ``` to ``` !cir.func<(!argTypes) -> !returnType)> ``` Unless the function returns `void`, in which case change ``` !cir.func<!cir.void (!argTypes)> ``` to ``` !cir.func<(!argTypes)> ```
1 parent ad94af9 commit 3e396ab

File tree

8 files changed

+141
-60
lines changed

8 files changed

+141
-60
lines changed

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -287,32 +287,43 @@ def CIR_BoolType :
287287
def CIR_FuncType : CIR_Type<"Func", "func"> {
288288
let summary = "CIR function type";
289289
let description = [{
290-
The `!cir.func` is a function type. It consists of a single return type, a
291-
list of parameter types and can optionally be variadic.
290+
The `!cir.func` is a function type. It consists of an optional return type,
291+
a list of parameter types and can optionally be variadic.
292292

293293
Example:
294294

295295
```mlir
296-
!cir.func<!bool ()>
297-
!cir.func<!s32i (!s8i, !s8i)>
298-
!cir.func<!s32i (!s32i, ...)>
296+
!cir.func<()>
297+
!cir.func<() -> bool>
298+
!cir.func<(!s8i, !s8i)>
299+
!cir.func<(!s8i, !s8i) -> !s32i>
300+
!cir.func<(!s32i, ...) -> !s32i>
299301
```
300302
}];
301303

302304
let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs,
303-
"mlir::Type":$returnType, "bool":$varArg);
305+
"mlir::Type":$optionalReturnType, "bool":$varArg);
306+
// Use a custom parser to handle the argument types and optional return
304307
let assemblyFormat = [{
305-
`<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>`
308+
`<` custom<FuncType>($optionalReturnType, $inputs, $varArg) `>`
306309
}];
307310

308311
let builders = [
312+
// Create a FuncType, converting the return type from C-style to
313+
// MLIR-style. If the given return type is `cir::VoidType`, ignore it
314+
// and create the FuncType with no return type, which is how MLIR
315+
// represents function types.
309316
TypeBuilderWithInferredContext<(ins
310317
"llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType,
311318
CArg<"bool", "false">:$isVarArg), [{
312-
return $_get(returnType.getContext(), inputs, returnType, isVarArg);
319+
return $_get(returnType.getContext(), inputs,
320+
mlir::isa<cir::VoidType>(returnType) ? nullptr : returnType,
321+
isVarArg);
313322
}]>
314323
];
315324

325+
let genVerifyDecl = 1;
326+
316327
let extraClassDeclaration = [{
317328
/// Returns whether the function is variadic.
318329
bool isVarArg() const { return getVarArg(); }
@@ -323,12 +334,17 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
323334
/// Returns the number of arguments to the function.
324335
unsigned getNumInputs() const { return getInputs().size(); }
325336

326-
/// Returns the result type of the function as an ArrayRef, enabling better
327-
/// integration with generic MLIR utilities.
337+
/// Get the C-style return type of the function, which is !cir.void if the
338+
/// function returns nothing and the actual return type otherwise.
339+
mlir::Type getReturnType() const;
340+
341+
/// Get the MLIR-style return type of the function, which is an empty
342+
/// ArrayRef if the function returns nothing and a single-element ArrayRef
343+
/// with the actual return type otherwise.
328344
llvm::ArrayRef<mlir::Type> getReturnTypes() const;
329345

330-
/// Returns whether the function is returns void.
331-
bool isVoid() const;
346+
/// Does the function type return nothing?
347+
bool hasVoidReturn() const;
332348

333349
/// Returns a clone of this function type with the given argument
334350
/// and result types.

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ bool CIRGenTypes::isFuncTypeConvertible(const FunctionType *ft) {
6060
mlir::Type CIRGenTypes::convertFunctionTypeInternal(QualType qft) {
6161
assert(qft.isCanonical());
6262
const FunctionType *ft = cast<FunctionType>(qft.getTypePtr());
63-
// First, check whether we can build the full fucntion type. If the function
63+
// First, check whether we can build the full function type. If the function
6464
// type depends on an incomplete type (e.g. a struct or enum), we cannot lower
6565
// the function type.
6666
if (!isFuncTypeConvertible(ft)) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,10 @@ LogicalResult cir::FuncOp::verifyType() {
424424
if (!isa<cir::FuncType>(type))
425425
return emitOpError("requires '" + getFunctionTypeAttrName().str() +
426426
"' attribute of function type");
427+
if (auto rt = type.getReturnTypes();
428+
!rt.empty() && mlir::isa<cir::VoidType>(rt.front()))
429+
return emitOpError("The return type for a function returning void should "
430+
"be empty instead of an explicit !cir.void");
427431
return success();
428432
}
429433

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 93 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
// CIR Custom Parser/Printer Signatures
2121
//===----------------------------------------------------------------------===//
2222

23-
static mlir::ParseResult
24-
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
25-
bool &isVarArg);
26-
static void printFuncTypeArgs(mlir::AsmPrinter &p,
27-
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
23+
static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
24+
mlir::Type &optionalReturnTypes,
25+
llvm::SmallVector<mlir::Type> &params,
26+
bool &isVarArg);
27+
static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnTypes,
28+
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
2829

2930
//===----------------------------------------------------------------------===//
3031
// Get autogenerated stuff
@@ -282,40 +283,55 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
282283
return get(llvm::to_vector(inputs), results[0], isVarArg());
283284
}
284285

285-
mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
286-
llvm::SmallVector<mlir::Type> &params,
287-
bool &isVarArg) {
286+
// A special parser is needed for function returning void to handle the missing
287+
// type.
288+
static mlir::ParseResult parseFuncTypeReturn(mlir::AsmParser &p,
289+
mlir::Type &optionalReturnType) {
290+
if (succeeded(p.parseOptionalArrow())) {
291+
// `->` found. It must be followed by the return type.
292+
return p.parseType(optionalReturnType);
293+
}
294+
// Function has `void` return in C++, no return in MLIR.
295+
optionalReturnType = {};
296+
return success();
297+
}
298+
299+
// A special pretty-printer for function returning or not a result.
300+
static void printFuncTypeReturn(mlir::AsmPrinter &p,
301+
mlir::Type optionalReturnType) {
302+
if (optionalReturnType)
303+
p << " -> " << optionalReturnType;
304+
}
305+
306+
static mlir::ParseResult
307+
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
308+
bool &isVarArg) {
288309
isVarArg = false;
289-
// `(` `)`
290-
if (succeeded(p.parseOptionalRParen()))
310+
if (failed(p.parseLParen()))
311+
return failure();
312+
if (succeeded(p.parseOptionalRParen())) {
313+
// `()` empty argument list
291314
return mlir::success();
292-
293-
// `(` `...` `)`
294-
if (succeeded(p.parseOptionalEllipsis())) {
295-
isVarArg = true;
296-
return p.parseRParen();
297315
}
298-
299-
// type (`,` type)* (`,` `...`)?
300-
mlir::Type type;
301-
if (p.parseType(type))
302-
return mlir::failure();
303-
params.push_back(type);
304-
while (succeeded(p.parseOptionalComma())) {
316+
do {
305317
if (succeeded(p.parseOptionalEllipsis())) {
318+
// `...`, which must be the last thing in the list.
306319
isVarArg = true;
307-
return p.parseRParen();
320+
break;
321+
} else {
322+
mlir::Type argType;
323+
if (failed(p.parseType(argType)))
324+
return failure();
325+
params.push_back(argType);
308326
}
309-
if (p.parseType(type))
310-
return mlir::failure();
311-
params.push_back(type);
312-
}
313-
327+
} while (succeeded(p.parseOptionalComma()));
314328
return p.parseRParen();
315329
}
316330

317-
void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
318-
bool isVarArg) {
331+
static void printFuncTypeArgs(mlir::AsmPrinter &p,
332+
mlir::ArrayRef<mlir::Type> params,
333+
bool isVarArg) {
334+
p << '(';
319335
llvm::interleaveComma(params, p,
320336
[&p](mlir::Type type) { p.printType(type); });
321337
if (isVarArg) {
@@ -326,11 +342,56 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
326342
p << ')';
327343
}
328344

345+
// Use a custom parser to handle the optional return and argument types without
346+
// an optional anchor.
347+
static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
348+
mlir::Type &optionalReturnType,
349+
llvm::SmallVector<mlir::Type> &params,
350+
bool &isVarArg) {
351+
if (failed(parseFuncTypeArgs(p, params, isVarArg)))
352+
return failure();
353+
return parseFuncTypeReturn(p, optionalReturnType);
354+
}
355+
356+
static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnType,
357+
mlir::ArrayRef<mlir::Type> params, bool isVarArg) {
358+
printFuncTypeArgs(p, params, isVarArg);
359+
printFuncTypeReturn(p, optionalReturnType);
360+
}
361+
362+
/// Get the C-style return type of the function, which is !cir.void if the
363+
/// function returns nothing and the actual return type otherwise.
364+
mlir::Type FuncType::getReturnType() const {
365+
if (hasVoidReturn())
366+
return cir::VoidType::get(getContext());
367+
return getOptionalReturnType();
368+
}
369+
370+
/// Get the MLIR-style return type of the function, which is an empty
371+
/// ArrayRef if the function returns nothing and a single-element ArrayRef
372+
/// with the actual return type otherwise.
329373
llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
330-
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType;
374+
if (hasVoidReturn())
375+
return {};
376+
// Can't use getOptionalReturnType() here because llvm::ArrayRef hold a
377+
// pointer to its elements and doesn't do lifetime extension. That would
378+
// result in returning a pointer to a temporary that has gone out of scope.
379+
return getImpl()->optionalReturnType;
331380
}
332381

333-
bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
382+
// Does the fuction type return nothing?
383+
bool FuncType::hasVoidReturn() const { return !getOptionalReturnType(); }
384+
385+
mlir::LogicalResult
386+
FuncType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
387+
llvm::ArrayRef<mlir::Type> argTypes, mlir::Type returnType,
388+
bool isVarArg) {
389+
if (returnType && mlir::isa<cir::VoidType>(returnType)) {
390+
emitError() << "!cir.func cannot have an explicit 'void' return type";
391+
return mlir::failure();
392+
}
393+
return mlir::success();
394+
}
334395

335396
//===----------------------------------------------------------------------===//
336397
// BoolType

clang/test/CIR/IR/func.cir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@
22

33
module {
44
// void empty() { }
5-
cir.func @empty() -> !cir.void {
5+
cir.func @empty() {
66
cir.return
77
}
8-
// CHECK: cir.func @empty() -> !cir.void {
8+
// CHECK: cir.func @empty() {
99
// CHECK: cir.return
1010
// CHECK: }
1111

1212
// void voidret() { return; }
13-
cir.func @voidret() -> !cir.void {
13+
cir.func @voidret() {
1414
cir.return
1515
}
16-
// CHECK: cir.func @voidret() -> !cir.void {
16+
// CHECK: cir.func @voidret() {
1717
// CHECK: cir.return
1818
// CHECK: }
1919

clang/test/CIR/IR/global.cir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
3030
cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
3131
cir.global @dp : !cir.ptr<!cir.double>
3232
cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
33-
cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
34-
cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
35-
cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
33+
cir.global @fp : !cir.ptr<!cir.func<()>>
34+
cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
35+
cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
3636
}
3737

3838
// CHECK: cir.global @c : !cir.int<s, 8>
@@ -64,6 +64,6 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
6464
// CHECK: cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
6565
// CHECK: cir.global @dp : !cir.ptr<!cir.double>
6666
// CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
67-
// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
68-
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
69-
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
67+
// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>
68+
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
69+
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>

clang/test/CIR/func-simple.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s
33

44
void empty() { }
5-
// CHECK: cir.func @empty() -> !cir.void {
5+
// CHECK: cir.func @empty() {
66
// CHECK: cir.return
77
// CHECK: }
88

99
void voidret() { return; }
10-
// CHECK: cir.func @voidret() -> !cir.void {
10+
// CHECK: cir.func @voidret() {
1111
// CHECK: cir.return
1212
// CHECK: }
1313

clang/test/CIR/global-var-simple.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ char **cpp;
9292
// CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
9393

9494
void (*fp)();
95-
// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
95+
// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>
9696

9797
int (*fpii)(int) = 0;
98-
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
98+
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
9999

100100
void (*fpvar)(int, ...);
101-
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
101+
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>

0 commit comments

Comments
 (0)