Skip to content

Commit 568b515

Browse files
authored
[CIR] Remove the !cir.void return type for functions returning void (#1203)
C/C++ functions returning void had an explicit !cir.void return type while not having any returned value, which was breaking a lot of MLIR invariants when the CIR dialect is used in a greater context, for example with the inliner. Now, a C/C++ function returning void has not return type and no return values, which does not break the MLIR invariant about the same number of return types and returned values. This change keeps the same parsing/pretty-printed syntax as before for compatibility.
1 parent d1d43f0 commit 568b515

File tree

11 files changed

+211
-43
lines changed

11 files changed

+211
-43
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3474,8 +3474,6 @@ def FuncOp : CIR_Op<"func", [
34743474
/// Returns the results types that the callable region produces when
34753475
/// executed.
34763476
llvm::ArrayRef<mlir::Type> getCallableResults() {
3477-
if (::llvm::isa<cir::VoidType>(getFunctionType().getReturnType()))
3478-
return {};
34793477
return getFunctionType().getReturnTypes();
34803478
}
34813479

@@ -3492,10 +3490,15 @@ def FuncOp : CIR_Op<"func", [
34923490
}
34933491

34943492
/// Returns the argument types of this function.
3495-
llvm::ArrayRef<mlir::Type> getArgumentTypes() { return getFunctionType().getInputs(); }
3493+
llvm::ArrayRef<mlir::Type> getArgumentTypes() {
3494+
return getFunctionType().getInputs();
3495+
}
34963496

3497-
/// Returns the result types of this function.
3498-
llvm::ArrayRef<mlir::Type> getResultTypes() { return getFunctionType().getReturnTypes(); }
3497+
/// Returns 0 or 1 result type of this function (0 in the case of a function
3498+
/// returing void)
3499+
llvm::ArrayRef<mlir::Type> getResultTypes() {
3500+
return getFunctionType().getReturnTypes();
3501+
}
34993502

35003503
/// Hook for OpTrait::FunctionOpInterfaceTrait, called after verifying that
35013504
/// the 'type' attribute is present and checks if it holds a function type.

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -379,22 +379,27 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
379379

380380
```mlir
381381
!cir.func<!bool ()>
382+
!cir.func<!cir.void ()>
382383
!cir.func<!s32i (!s8i, !s8i)>
383384
!cir.func<!s32i (!s32i, ...)>
384385
```
385386
}];
386387

387-
let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, "mlir::Type":$returnType,
388+
let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, ArrayRefParameter<"mlir::Type">:$returnTypes,
388389
"bool":$varArg);
389390
let assemblyFormat = [{
390-
`<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>`
391+
`<` custom<FuncType>($returnTypes, $inputs, $varArg) `>`
391392
}];
392393

393394
let builders = [
395+
// Construct with an actual return type or explicit !cir.void
394396
TypeBuilderWithInferredContext<(ins
395397
"llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType,
396398
CArg<"bool", "false">:$isVarArg), [{
397-
return $_get(returnType.getContext(), inputs, returnType, isVarArg);
399+
return $_get(returnType.getContext(), inputs,
400+
::mlir::isa<::cir::VoidType>(returnType) ? llvm::ArrayRef<mlir::Type>{}
401+
: llvm::ArrayRef{returnType},
402+
isVarArg);
398403
}]>
399404
];
400405

@@ -408,11 +413,11 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
408413
/// Returns the number of arguments to the function.
409414
unsigned getNumInputs() const { return getInputs().size(); }
410415

411-
/// Returns the result type of the function as an ArrayRef, enabling better
412-
/// integration with generic MLIR utilities.
413-
llvm::ArrayRef<mlir::Type> getReturnTypes() const;
416+
/// Returns the result type of the function as an actual return type or
417+
/// explicit !cir.void
418+
mlir::Type getReturnType() const;
414419

415-
/// Returns whether the function is returns void.
420+
/// Returns whether the function returns void.
416421
bool isVoid() const;
417422

418423
/// Returns a clone of this function type with the given argument

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ mlir::Type CIRGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
271271
assert(QFT.isCanonical());
272272
const Type *Ty = QFT.getTypePtr();
273273
const FunctionType *FT = cast<FunctionType>(QFT.getTypePtr());
274-
// First, check whether we can build the full fucntion type. If the function
274+
// First, check whether we can build the full function type. If the function
275275
// type depends on an incomplete type (e.g. a struct or enum), we cannot lower
276276
// the function type.
277277
assert(isFuncTypeConvertible(FT) && "NYI");

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2224,6 +2224,26 @@ void cir::FuncOp::build(OpBuilder &builder, OperationState &result,
22242224
getResAttrsAttrName(result.name));
22252225
}
22262226

2227+
// A specific version of function_interface_impl::parseFunctionSignature able to
2228+
// handle the "-> !void" special fake return type.
2229+
static ParseResult
2230+
parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
2231+
SmallVectorImpl<OpAsmParser::Argument> &arguments,
2232+
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
2233+
SmallVectorImpl<DictionaryAttr> &resultAttrs) {
2234+
if (function_interface_impl::parseFunctionArgumentList(parser, allowVariadic,
2235+
arguments, isVariadic))
2236+
return failure();
2237+
if (succeeded(parser.parseOptionalArrow())) {
2238+
if (parser.parseOptionalExclamationKeyword("!void").succeeded())
2239+
// This is just an empty return type and attribute.
2240+
return success();
2241+
return function_interface_impl::parseFunctionResultList(parser, resultTypes,
2242+
resultAttrs);
2243+
}
2244+
return success();
2245+
}
2246+
22272247
ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
22282248
llvm::SMLoc loc = parser.getCurrentLocation();
22292249

@@ -2284,9 +2304,8 @@ ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
22842304

22852305
// Parse the function signature.
22862306
bool isVariadic = false;
2287-
if (function_interface_impl::parseFunctionSignature(
2288-
parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes,
2289-
resultAttrs))
2307+
if (parseFunctionSignature(parser, /*allowVariadic=*/true, arguments,
2308+
isVariadic, resultTypes, resultAttrs))
22902309
return failure();
22912310

22922311
for (auto &arg : arguments)
@@ -2489,13 +2508,8 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
24892508
p.printSymbolName(getSymName());
24902509
auto fnType = getFunctionType();
24912510
llvm::SmallVector<Type, 1> resultTypes;
2492-
if (!fnType.isVoid())
2493-
function_interface_impl::printFunctionSignature(
2494-
p, *this, fnType.getInputs(), fnType.isVarArg(),
2495-
fnType.getReturnTypes());
2496-
else
2497-
function_interface_impl::printFunctionSignature(
2498-
p, *this, fnType.getInputs(), fnType.isVarArg(), {});
2511+
function_interface_impl::printFunctionSignature(
2512+
p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());
24992513

25002514
if (mlir::ArrayAttr annotations = getAnnotationsAttr()) {
25012515
p << ' ';
@@ -2564,6 +2578,11 @@ LogicalResult cir::FuncOp::verifyType() {
25642578
if (!getNoProto() && type.isVarArg() && type.getNumInputs() == 0)
25652579
return emitError()
25662580
<< "prototyped function must have at least one non-variadic input";
2581+
if (auto rt = type.getReturnTypes();
2582+
!rt.empty() && mlir::isa<cir::VoidType>(rt.front()))
2583+
return emitOpError("The return type for a function returning void should "
2584+
"be empty instead of an explicit !cir.void");
2585+
25672586
return success();
25682587
}
25692588

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/ADT/TypeSwitch.h"
3434
#include "llvm/Support/ErrorHandling.h"
3535
#include "llvm/Support/MathExtras.h"
36+
#include <cassert>
3637
#include <optional>
3738

3839
using cir::MissingFeatures;
@@ -42,13 +43,16 @@ using cir::MissingFeatures;
4243
//===----------------------------------------------------------------------===//
4344

4445
static mlir::ParseResult
45-
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
46-
bool &isVarArg);
47-
static void printFuncTypeArgs(mlir::AsmPrinter &p,
48-
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
46+
parseFuncType(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &returnTypes,
47+
llvm::SmallVector<mlir::Type> &params, bool &isVarArg);
48+
49+
static void printFuncType(mlir::AsmPrinter &p,
50+
mlir::ArrayRef<mlir::Type> returnTypes,
51+
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
4952

5053
static mlir::ParseResult parsePointerAddrSpace(mlir::AsmParser &p,
5154
mlir::Attribute &addrSpaceAttr);
55+
5256
static void printPointerAddrSpace(mlir::AsmPrinter &p,
5357
mlir::Attribute addrSpaceAttr);
5458

@@ -913,9 +917,46 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
913917
return get(llvm::to_vector(inputs), results[0], isVarArg());
914918
}
915919

916-
mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
917-
llvm::SmallVector<mlir::Type> &params,
918-
bool &isVarArg) {
920+
// A special parser is needed for function returning void to consume the "!void"
921+
// returned type in the case there is no alias defined.
922+
static mlir::ParseResult
923+
parseFuncTypeReturn(mlir::AsmParser &p,
924+
llvm::SmallVector<mlir::Type> &returnTypes) {
925+
if (p.parseOptionalExclamationKeyword("!void").succeeded())
926+
// !void means no return type.
927+
return p.parseLParen();
928+
if (succeeded(p.parseOptionalLParen()))
929+
// If we have already a '(', the function has no return type
930+
return mlir::success();
931+
932+
mlir::Type type;
933+
auto result = p.parseOptionalType(type);
934+
if (!result.has_value())
935+
return mlir::failure();
936+
if (failed(*result) || isa<cir::VoidType>(type))
937+
// No return type specified.
938+
return p.parseLParen();
939+
// Otherwise use the actual type.
940+
returnTypes.push_back(type);
941+
return p.parseLParen();
942+
}
943+
944+
// A special pretty-printer for function returning void to emit a "!void"
945+
// returned type. Note that there is no real type used here since it does not
946+
// appear in the IR and thus the alias might not be defined and cannot be
947+
// referred to. This is why this is a pure syntactic-sugar string which is used.
948+
static void printFuncTypeReturn(mlir::AsmPrinter &p,
949+
mlir::ArrayRef<mlir::Type> returnTypes) {
950+
if (returnTypes.empty())
951+
// Pretty-print no return type as "!void"
952+
p << "!void ";
953+
else
954+
p << returnTypes << ' ';
955+
}
956+
957+
static mlir::ParseResult
958+
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
959+
bool &isVarArg) {
919960
isVarArg = false;
920961
// `(` `)`
921962
if (succeeded(p.parseOptionalRParen()))
@@ -945,8 +986,10 @@ mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
945986
return p.parseRParen();
946987
}
947988

948-
void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
949-
bool isVarArg) {
989+
static void printFuncTypeArgs(mlir::AsmPrinter &p,
990+
mlir::ArrayRef<mlir::Type> params,
991+
bool isVarArg) {
992+
p << '(';
950993
llvm::interleaveComma(params, p,
951994
[&p](mlir::Type type) { p.printType(type); });
952995
if (isVarArg) {
@@ -957,11 +1000,37 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
9571000
p << ')';
9581001
}
9591002

960-
llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
961-
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType;
1003+
static mlir::ParseResult
1004+
parseFuncType(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &returnTypes,
1005+
llvm::SmallVector<mlir::Type> &params, bool &isVarArg) {
1006+
if (failed(parseFuncTypeReturn(p, returnTypes)))
1007+
return failure();
1008+
return parseFuncTypeArgs(p, params, isVarArg);
1009+
}
1010+
1011+
static void printFuncType(mlir::AsmPrinter &p,
1012+
mlir::ArrayRef<mlir::Type> returnTypes,
1013+
mlir::ArrayRef<mlir::Type> params, bool isVarArg) {
1014+
printFuncTypeReturn(p, returnTypes);
1015+
printFuncTypeArgs(p, params, isVarArg);
9621016
}
9631017

964-
bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
1018+
// Return the actual return type or an explicit !cir.void if the function does
1019+
// not return anything
1020+
mlir::Type FuncType::getReturnType() const {
1021+
if (isVoid())
1022+
return cir::VoidType::get(getContext());
1023+
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnTypes.front();
1024+
}
1025+
1026+
bool FuncType::isVoid() const {
1027+
auto rt = static_cast<detail::FuncTypeStorage *>(getImpl())->returnTypes;
1028+
assert(rt.empty() ||
1029+
!mlir::isa<cir::VoidType>(rt.front()) &&
1030+
"The return type for a function returning void should be empty "
1031+
"instead of a real !cir.void");
1032+
return rt.empty();
1033+
}
9651034

9661035
//===----------------------------------------------------------------------===//
9671036
// MethodType Definitions

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerTypes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ FuncType LowerTypes::getFunctionType(const LowerFunctionInfo &FI) {
109109
}
110110
}
111111

112-
return FuncType::get(getMLIRContext(), ArgTypes, resultType, FI.isVariadic());
112+
return FuncType::get(ArgTypes, resultType, FI.isVariadic());
113113
}
114114

115115
/// Convert a CIR type to its ABI-specific default form.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: cir-opt %s | FileCheck %s
2+
// Exercise different ways to encode a function returning void
3+
!s32i = !cir.int<s, 32>
4+
!fnptr1 = !cir.ptr<!cir.func<!cir.void(!s32i)>>
5+
// Note there is no !void alias defined
6+
!fnptr2 = !cir.ptr<!cir.func<!void(!s32i)>>
7+
!fnptr3 = !cir.ptr<!cir.func<(!s32i)>>
8+
module {
9+
cir.func @ind1(%fnptr: !fnptr1, %a : !s32i) {
10+
// CHECK: cir.func @ind1(%arg0: !cir.ptr<!cir.func<!void (!s32i)>>, %arg1: !s32i) {
11+
cir.return
12+
}
13+
14+
cir.func @ind2(%fnptr: !fnptr2, %a : !s32i) {
15+
// CHECK: cir.func @ind2(%arg0: !cir.ptr<!cir.func<!void (!s32i)>>, %arg1: !s32i) {
16+
cir.return
17+
}
18+
cir.func @ind3(%fnptr: !fnptr3, %a : !s32i) {
19+
// CHECK: cir.func @ind3(%arg0: !cir.ptr<!cir.func<!void (!s32i)>>, %arg1: !s32i) {
20+
cir.return
21+
}
22+
cir.func @f1() -> !cir.void {
23+
// CHECK: cir.func @f1() {
24+
cir.return
25+
}
26+
// Note there is no !void alias defined
27+
cir.func @f2() -> !void {
28+
// CHECK: cir.func @f2() {
29+
cir.return
30+
}
31+
cir.func @f3() {
32+
// CHECK: cir.func @f3() {
33+
cir.return
34+
}
35+
}

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,9 @@ class AsmParser {
922922
/// Parse an optional keyword or string.
923923
virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
924924

925+
/// Parse the given exclamation-prefixed keyword if present.
926+
virtual ParseResult parseOptionalExclamationKeyword(StringRef keyword) = 0;
927+
925928
//===--------------------------------------------------------------------===//
926929
// Attribute/Type Parsing
927930
//===--------------------------------------------------------------------===//

mlir/include/mlir/Interfaces/FunctionImplementation.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,28 @@ parseFunctionSignature(OpAsmParser &parser, bool allowVariadic,
6464
bool &isVariadic, SmallVectorImpl<Type> &resultTypes,
6565
SmallVectorImpl<DictionaryAttr> &resultAttrs);
6666

67+
/// Parse a function argument list using `parser`. The `allowVariadic` argument
68+
/// indicates whether functions with variadic arguments are supported. The
69+
/// trailing arguments are populated by this function with names, types,
70+
/// attributes and locations of the arguments.
71+
ParseResult
72+
parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
73+
SmallVectorImpl<OpAsmParser::Argument> &arguments,
74+
bool &isVariadic);
75+
76+
/// Parse a function result list using `parser`.
77+
///
78+
/// function-result-list ::= function-result-list-parens
79+
/// | non-function-type
80+
/// function-result-list-parens ::= `(` `)`
81+
/// | `(` function-result-list-no-parens `)`
82+
/// function-result-list-no-parens ::= function-result (`,` function-result)*
83+
/// function-result ::= type attribute-dict?
84+
///
85+
ParseResult
86+
parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
87+
SmallVectorImpl<DictionaryAttr> &resultAttrs);
88+
6789
/// Parser implementation for function-like operations. Uses
6890
/// `funcTypeBuilder` to construct the custom function type given lists of
6991
/// input and output types. The parser sets the `typeAttrName` attribute to the

mlir/lib/AsmParser/AsmParserImpl.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,19 @@ class AsmParserImpl : public BaseT {
396396
return parseOptionalString(result);
397397
}
398398

399+
/// Parse the given exclamation-prefixed keyword if present.
400+
ParseResult parseOptionalExclamationKeyword(StringRef keyword) override {
401+
if (parser.getToken().isCodeCompletion())
402+
return parser.codeCompleteOptionalTokens(keyword);
403+
404+
// Check that the current token has the same spelling.
405+
if (!parser.getToken().is(Token::Kind::exclamation_identifier) ||
406+
parser.getTokenSpelling() != keyword)
407+
return failure();
408+
parser.consumeToken();
409+
return success();
410+
}
411+
399412
//===--------------------------------------------------------------------===//
400413
// Attribute Parsing
401414
//===--------------------------------------------------------------------===//

0 commit comments

Comments
 (0)