diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 761af3ced802..84f30a5ac2fe 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -3471,8 +3471,6 @@ def FuncOp : CIR_Op<"func", [ /// Returns the results types that the callable region produces when /// executed. llvm::ArrayRef getCallableResults() { - if (::llvm::isa(getFunctionType().getReturnType())) - return {}; return getFunctionType().getReturnTypes(); } @@ -3489,10 +3487,15 @@ def FuncOp : CIR_Op<"func", [ } /// Returns the argument types of this function. - llvm::ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } + llvm::ArrayRef getArgumentTypes() { + return getFunctionType().getInputs(); + } - /// Returns the result types of this function. - llvm::ArrayRef getResultTypes() { return getFunctionType().getReturnTypes(); } + /// Returns 0 or 1 result type of this function (0 in the case of a function + /// returing void) + llvm::ArrayRef getResultTypes() { + return getFunctionType().getReturnTypes(); + } /// Hook for OpTrait::FunctionOpInterfaceTrait, called after verifying that /// the 'type' attribute is present and checks if it holds a function type. diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index c805b6887cf3..d3f49716301d 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -379,22 +379,27 @@ def CIR_FuncType : CIR_Type<"Func", "func"> { ```mlir !cir.func + !cir.func !cir.func !cir.func ``` }]; - let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, "mlir::Type":$returnType, + let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, ArrayRefParameter<"mlir::Type">:$returnTypes, "bool":$varArg); let assemblyFormat = [{ - `<` $returnType ` ` `(` custom($inputs, $varArg) `>` + `<` custom($returnTypes, $inputs, $varArg) `>` }]; let builders = [ + // Construct with an actual return type or explicit !cir.void TypeBuilderWithInferredContext<(ins "llvm::ArrayRef":$inputs, "mlir::Type":$returnType, CArg<"bool", "false">:$isVarArg), [{ - return $_get(returnType.getContext(), inputs, returnType, isVarArg); + return $_get(returnType.getContext(), inputs, + ::mlir::isa<::cir::VoidType>(returnType) ? llvm::ArrayRef{} + : llvm::ArrayRef{returnType}, + isVarArg); }]> ]; @@ -408,11 +413,11 @@ def CIR_FuncType : CIR_Type<"Func", "func"> { /// Returns the number of arguments to the function. unsigned getNumInputs() const { return getInputs().size(); } - /// Returns the result type of the function as an ArrayRef, enabling better - /// integration with generic MLIR utilities. - llvm::ArrayRef getReturnTypes() const; + /// Returns the result type of the function as an actual return type or + /// explicit !cir.void + mlir::Type getReturnType() const; - /// Returns whether the function is returns void. + /// Returns whether the function returns void. bool isVoid() const; /// Returns a clone of this function type with the given argument diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp index 5483a0f805a5..d3c6814e37ff 100644 --- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp @@ -271,7 +271,7 @@ mlir::Type CIRGenTypes::ConvertFunctionTypeInternal(QualType QFT) { assert(QFT.isCanonical()); const Type *Ty = QFT.getTypePtr(); const FunctionType *FT = cast(QFT.getTypePtr()); - // First, check whether we can build the full fucntion type. If the function + // First, check whether we can build the full function type. If the function // type depends on an incomplete type (e.g. a struct or enum), we cannot lower // the function type. assert(isFuncTypeConvertible(FT) && "NYI"); diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index dcd993582971..f9674e6e4965 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -2218,6 +2218,26 @@ void cir::FuncOp::build(OpBuilder &builder, OperationState &result, getResAttrsAttrName(result.name)); } +// A specific version of function_interface_impl::parseFunctionSignature able to +// handle the "-> !void" special fake return type. +static ParseResult +parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, + bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { + if (function_interface_impl::parseFunctionArgumentList(parser, allowVariadic, + arguments, isVariadic)) + return failure(); + if (succeeded(parser.parseOptionalArrow())) { + if (parser.parseOptionalExclamationKeyword("!void").succeeded()) + // This is just an empty return type and attribute. + return success(); + return function_interface_impl::parseFunctionResultList(parser, resultTypes, + resultAttrs); + } + return success(); +} + ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) { llvm::SMLoc loc = parser.getCurrentLocation(); @@ -2278,9 +2298,8 @@ ParseResult cir::FuncOp::parse(OpAsmParser &parser, OperationState &state) { // Parse the function signature. bool isVariadic = false; - if (function_interface_impl::parseFunctionSignature( - parser, /*allowVariadic=*/true, arguments, isVariadic, resultTypes, - resultAttrs)) + if (parseFunctionSignature(parser, /*allowVariadic=*/true, arguments, + isVariadic, resultTypes, resultAttrs)) return failure(); for (auto &arg : arguments) @@ -2483,13 +2502,8 @@ void cir::FuncOp::print(OpAsmPrinter &p) { p.printSymbolName(getSymName()); auto fnType = getFunctionType(); llvm::SmallVector resultTypes; - if (!fnType.isVoid()) - function_interface_impl::printFunctionSignature( - p, *this, fnType.getInputs(), fnType.isVarArg(), - fnType.getReturnTypes()); - else - function_interface_impl::printFunctionSignature( - p, *this, fnType.getInputs(), fnType.isVarArg(), {}); + function_interface_impl::printFunctionSignature( + p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes()); if (mlir::ArrayAttr annotations = getAnnotationsAttr()) { p << ' '; @@ -2558,6 +2572,11 @@ LogicalResult cir::FuncOp::verifyType() { if (!getNoProto() && type.isVarArg() && type.getNumInputs() == 0) return emitError() << "prototyped function must have at least one non-variadic input"; + if (auto rt = type.getReturnTypes(); + !rt.empty() && mlir::isa(rt.front())) + return emitOpError("The return type for a function returning void should " + "be empty instead of an explicit !cir.void"); + return success(); } diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index df89584fd3a9..2b17b048f6c6 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -33,6 +33,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" +#include #include using cir::MissingFeatures; @@ -42,13 +43,16 @@ using cir::MissingFeatures; //===----------------------------------------------------------------------===// static mlir::ParseResult -parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector ¶ms, - bool &isVarArg); -static void printFuncTypeArgs(mlir::AsmPrinter &p, - mlir::ArrayRef params, bool isVarArg); +parseFuncType(mlir::AsmParser &p, llvm::SmallVector &returnTypes, + llvm::SmallVector ¶ms, bool &isVarArg); + +static void printFuncType(mlir::AsmPrinter &p, + mlir::ArrayRef returnTypes, + mlir::ArrayRef params, bool isVarArg); static mlir::ParseResult parsePointerAddrSpace(mlir::AsmParser &p, mlir::Attribute &addrSpaceAttr); + static void printPointerAddrSpace(mlir::AsmPrinter &p, mlir::Attribute addrSpaceAttr); @@ -913,9 +917,46 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const { return get(llvm::to_vector(inputs), results[0], isVarArg()); } -mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p, - llvm::SmallVector ¶ms, - bool &isVarArg) { +// A special parser is needed for function returning void to consume the "!void" +// returned type in the case there is no alias defined. +static mlir::ParseResult +parseFuncTypeReturn(mlir::AsmParser &p, + llvm::SmallVector &returnTypes) { + if (p.parseOptionalExclamationKeyword("!void").succeeded()) + // !void means no return type. + return p.parseLParen(); + if (succeeded(p.parseOptionalLParen())) + // If we have already a '(', the function has no return type + return mlir::success(); + + mlir::Type type; + auto result = p.parseOptionalType(type); + if (!result.has_value()) + return mlir::failure(); + if (failed(*result) || isa(type)) + // No return type specified. + return p.parseLParen(); + // Otherwise use the actual type. + returnTypes.push_back(type); + return p.parseLParen(); +} + +// A special pretty-printer for function returning void to emit a "!void" +// returned type. Note that there is no real type used here since it does not +// appear in the IR and thus the alias might not be defined and cannot be +// referred to. This is why this is a pure syntactic-sugar string which is used. +static void printFuncTypeReturn(mlir::AsmPrinter &p, + mlir::ArrayRef returnTypes) { + if (returnTypes.empty()) + // Pretty-print no return type as "!void" + p << "!void "; + else + p << returnTypes << ' '; +} + +static mlir::ParseResult +parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector ¶ms, + bool &isVarArg) { isVarArg = false; // `(` `)` if (succeeded(p.parseOptionalRParen())) @@ -945,8 +986,10 @@ mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p, return p.parseRParen(); } -void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef params, - bool isVarArg) { +static void printFuncTypeArgs(mlir::AsmPrinter &p, + mlir::ArrayRef params, + bool isVarArg) { + p << '('; llvm::interleaveComma(params, p, [&p](mlir::Type type) { p.printType(type); }); if (isVarArg) { @@ -957,11 +1000,37 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef params, p << ')'; } -llvm::ArrayRef FuncType::getReturnTypes() const { - return static_cast(getImpl())->returnType; +static mlir::ParseResult +parseFuncType(mlir::AsmParser &p, llvm::SmallVector &returnTypes, + llvm::SmallVector ¶ms, bool &isVarArg) { + if (failed(parseFuncTypeReturn(p, returnTypes))) + return failure(); + return parseFuncTypeArgs(p, params, isVarArg); +} + +static void printFuncType(mlir::AsmPrinter &p, + mlir::ArrayRef returnTypes, + mlir::ArrayRef params, bool isVarArg) { + printFuncTypeReturn(p, returnTypes); + printFuncTypeArgs(p, params, isVarArg); } -bool FuncType::isVoid() const { return mlir::isa(getReturnType()); } +// Return the actual return type or an explicit !cir.void if the function does +// not return anything +mlir::Type FuncType::getReturnType() const { + if (isVoid()) + return cir::VoidType::get(getContext()); + return static_cast(getImpl())->returnTypes.front(); +} + +bool FuncType::isVoid() const { + auto rt = static_cast(getImpl())->returnTypes; + assert(rt.empty() || + !mlir::isa(rt.front()) && + "The return type for a function returning void should be empty " + "instead of a real !cir.void"); + return rt.empty(); +} //===----------------------------------------------------------------------===// // MethodType Definitions diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerTypes.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerTypes.cpp index 0c2233ef84c9..d655ae9023dd 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerTypes.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerTypes.cpp @@ -109,7 +109,7 @@ FuncType LowerTypes::getFunctionType(const LowerFunctionInfo &FI) { } } - return FuncType::get(getMLIRContext(), ArgTypes, resultType, FI.isVariadic()); + return FuncType::get(ArgTypes, resultType, FI.isVariadic()); } /// Convert a CIR type to its ABI-specific default form. diff --git a/clang/test/CIR/IR/being_and_nothingness.cir b/clang/test/CIR/IR/being_and_nothingness.cir new file mode 100644 index 000000000000..311acb4893dc --- /dev/null +++ b/clang/test/CIR/IR/being_and_nothingness.cir @@ -0,0 +1,35 @@ +// RUN: cir-opt %s | FileCheck %s +// Exercise different ways to encode a function returning void +!s32i = !cir.int +!fnptr1 = !cir.ptr> +// Note there is no !void alias defined +!fnptr2 = !cir.ptr> +!fnptr3 = !cir.ptr> +module { + cir.func @ind1(%fnptr: !fnptr1, %a : !s32i) { + // CHECK: cir.func @ind1(%arg0: !cir.ptr>, %arg1: !s32i) { + cir.return + } + + cir.func @ind2(%fnptr: !fnptr2, %a : !s32i) { + // CHECK: cir.func @ind2(%arg0: !cir.ptr>, %arg1: !s32i) { + cir.return + } + cir.func @ind3(%fnptr: !fnptr3, %a : !s32i) { + // CHECK: cir.func @ind3(%arg0: !cir.ptr>, %arg1: !s32i) { + cir.return + } + cir.func @f1() -> !cir.void { + // CHECK: cir.func @f1() { + cir.return + } + // Note there is no !void alias defined + cir.func @f2() -> !void { + // CHECK: cir.func @f2() { + cir.return + } + cir.func @f3() { + // CHECK: cir.func @f3() { + cir.return + } +} diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index a7222794f320..c7b223f5be87 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -922,6 +922,9 @@ class AsmParser { /// Parse an optional keyword or string. virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; + /// Parse the given exclamation-prefixed keyword if present. + virtual ParseResult parseOptionalExclamationKeyword(StringRef keyword) = 0; + //===--------------------------------------------------------------------===// // Attribute/Type Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h index a5e6963e4e66..110025bc89f5 100644 --- a/mlir/include/mlir/Interfaces/FunctionImplementation.h +++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h @@ -64,6 +64,28 @@ parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, bool &isVariadic, SmallVectorImpl &resultTypes, SmallVectorImpl &resultAttrs); +/// Parse a function argument list using `parser`. The `allowVariadic` argument +/// indicates whether functions with variadic arguments are supported. The +/// trailing arguments are populated by this function with names, types, +/// attributes and locations of the arguments. +ParseResult +parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, + bool &isVariadic); + +/// Parse a function result list using `parser`. +/// +/// function-result-list ::= function-result-list-parens +/// | non-function-type +/// function-result-list-parens ::= `(` `)` +/// | `(` function-result-list-no-parens `)` +/// function-result-list-no-parens ::= function-result (`,` function-result)* +/// function-result ::= type attribute-dict? +/// +ParseResult +parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs); + /// Parser implementation for function-like operations. Uses /// `funcTypeBuilder` to construct the custom function type given lists of /// input and output types. The parser sets the `typeAttrName` attribute to the diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index d5b72d63813a..8c7ce16fe54d 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -396,6 +396,19 @@ class AsmParserImpl : public BaseT { return parseOptionalString(result); } + /// Parse the given exclamation-prefixed keyword if present. + ParseResult parseOptionalExclamationKeyword(StringRef keyword) override { + if (parser.getToken().isCodeCompletion()) + return parser.codeCompleteOptionalTokens(keyword); + + // Check that the current token has the same spelling. + if (!parser.getToken().is(Token::Kind::exclamation_identifier) || + parser.getTokenSpelling() != keyword) + return failure(); + parser.consumeToken(); + return success(); + } + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/FunctionImplementation.cpp b/mlir/lib/Interfaces/FunctionImplementation.cpp index 988feee665fe..9922e3c28eab 100644 --- a/mlir/lib/Interfaces/FunctionImplementation.cpp +++ b/mlir/lib/Interfaces/FunctionImplementation.cpp @@ -13,10 +13,9 @@ using namespace mlir; -static ParseResult -parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, - SmallVectorImpl &arguments, - bool &isVariadic) { +ParseResult function_interface_impl::parseFunctionArgumentList( + OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, bool &isVariadic) { // Parse the function arguments. The argument list either has to consistently // have ssa-id's followed by types, or just be a type list. It isn't ok to @@ -79,9 +78,9 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, /// function-result-list-no-parens ::= function-result (`,` function-result)* /// function-result ::= type attribute-dict? /// -static ParseResult -parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, - SmallVectorImpl &resultAttrs) { +ParseResult function_interface_impl::parseFunctionResultList( + OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { if (failed(parser.parseOptionalLParen())) { // We already know that there is no `(`, so parse a type. // Because there is no `(`, it cannot be a function type.