3333#include " llvm/ADT/TypeSwitch.h"
3434#include " llvm/Support/ErrorHandling.h"
3535#include " llvm/Support/MathExtras.h"
36+ #include < cassert>
3637#include < optional>
3738
3839using cir::MissingFeatures;
@@ -41,12 +42,13 @@ using cir::MissingFeatures;
4142// CIR Custom Parser/Printer Signatures
4243// ===----------------------------------------------------------------------===//
4344
44- static mlir::ParseResult
45- parseFuncTypeArgs (mlir::AsmParser &p, llvm::SmallVector<mlir::Type> ¶ms,
46- bool &isVarArg);
47- static void printFuncTypeArgs (mlir::AsmPrinter &p,
48- mlir::ArrayRef<mlir::Type> params, bool isVarArg);
45+ static mlir::ParseResult parseFuncType (mlir::AsmParser &p,
46+ mlir::Type &optionalReturnTypes,
47+ llvm::SmallVector<mlir::Type> ¶ms,
48+ bool &isVarArg);
4949
50+ static void printFuncType (mlir::AsmPrinter &p, mlir::Type optionalReturnTypes,
51+ mlir::ArrayRef<mlir::Type> params, bool isVarArg);
5052static mlir::ParseResult parsePointerAddrSpace (mlir::AsmParser &p,
5153 mlir::Attribute &addrSpaceAttr);
5254static void printPointerAddrSpace (mlir::AsmPrinter &p,
@@ -913,9 +915,38 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
913915 return get (llvm::to_vector (inputs), results[0 ], isVarArg ());
914916}
915917
916- mlir::ParseResult parseFuncTypeArgs (mlir::AsmParser &p,
917- llvm::SmallVector<mlir::Type> ¶ms,
918- bool &isVarArg) {
918+ // A special parser is needed for function returning void to handle the missing
919+ // type.
920+ static mlir::ParseResult parseFuncTypeReturn (mlir::AsmParser &p,
921+ mlir::Type &optionalReturnType) {
922+ if (succeeded (p.parseOptionalLParen ())) {
923+ // If we have already a '(', the function has no return type
924+ optionalReturnType = {};
925+ return mlir::success ();
926+ }
927+ mlir::Type type;
928+ if (p.parseType (type))
929+ return mlir::failure ();
930+ if (isa<cir::VoidType>(type))
931+ // An explicit !cir.void means also no return type.
932+ optionalReturnType = {};
933+ else
934+ // Otherwise use the actual type.
935+ optionalReturnType = type;
936+ return p.parseLParen ();
937+ }
938+
939+ // A special pretty-printer for function returning or not a result.
940+ static void printFuncTypeReturn (mlir::AsmPrinter &p,
941+ mlir::Type optionalReturnType) {
942+ if (optionalReturnType)
943+ p << optionalReturnType << ' ' ;
944+ p << ' (' ;
945+ }
946+
947+ static mlir::ParseResult
948+ parseFuncTypeArgs (mlir::AsmParser &p, llvm::SmallVector<mlir::Type> ¶ms,
949+ bool &isVarArg) {
919950 isVarArg = false ;
920951 // `(` `)`
921952 if (succeeded (p.parseOptionalRParen ()))
@@ -945,8 +976,9 @@ mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
945976 return p.parseRParen ();
946977}
947978
948- void printFuncTypeArgs (mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
949- bool isVarArg) {
979+ static void printFuncTypeArgs (mlir::AsmPrinter &p,
980+ mlir::ArrayRef<mlir::Type> params,
981+ bool isVarArg) {
950982 llvm::interleaveComma (params, p,
951983 [&p](mlir::Type type) { p.printType (type); });
952984 if (isVarArg) {
@@ -957,11 +989,49 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
957989 p << ' )' ;
958990}
959991
992+ // Use a custom parser to handle the optional return and argument types without
993+ // an optional anchor.
994+ static mlir::ParseResult parseFuncType (mlir::AsmParser &p,
995+ mlir::Type &optionalReturnTypes,
996+ llvm::SmallVector<mlir::Type> ¶ms,
997+ bool &isVarArg) {
998+ if (failed (parseFuncTypeReturn (p, optionalReturnTypes)))
999+ return failure ();
1000+ return parseFuncTypeArgs (p, params, isVarArg);
1001+ }
1002+
1003+ static void printFuncType (mlir::AsmPrinter &p, mlir::Type optionalReturnTypes,
1004+ mlir::ArrayRef<mlir::Type> params, bool isVarArg) {
1005+ printFuncTypeReturn (p, optionalReturnTypes);
1006+ printFuncTypeArgs (p, params, isVarArg);
1007+ }
1008+
1009+ // Return the actual return type or an explicit !cir.void if the function does
1010+ // not return anything
1011+ mlir::Type FuncType::getReturnType () const {
1012+ if (isVoid ())
1013+ return cir::VoidType::get (getContext ());
1014+ return static_cast <detail::FuncTypeStorage *>(getImpl ())->optionalReturnType ;
1015+ }
1016+
1017+ // / Returns the result type of the function as an ArrayRef, enabling better
1018+ // / integration with generic MLIR utilities.
9601019llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes () const {
961- return static_cast <detail::FuncTypeStorage *>(getImpl ())->returnType ;
1020+ if (isVoid ())
1021+ return {};
1022+ return static_cast <detail::FuncTypeStorage *>(getImpl ())->optionalReturnType ;
9621023}
9631024
964- bool FuncType::isVoid () const { return mlir::isa<VoidType>(getReturnType ()); }
1025+ // Whether the function returns void
1026+ bool FuncType::isVoid () const {
1027+ auto rt =
1028+ static_cast <detail::FuncTypeStorage *>(getImpl ())->optionalReturnType ;
1029+ assert (!rt ||
1030+ !mlir::isa<cir::VoidType>(rt) &&
1031+ " The return type for a function returning void should be empty "
1032+ " instead of a real !cir.void" );
1033+ return !rt;
1034+ }
9651035
9661036// ===----------------------------------------------------------------------===//
9671037// MethodType Definitions
0 commit comments