Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -287,32 +287,43 @@ def CIR_BoolType :
def CIR_FuncType : CIR_Type<"Func", "func"> {
let summary = "CIR function type";
let description = [{
The `!cir.func` is a function type. It consists of a single return type, a
list of parameter types and can optionally be variadic.
The `!cir.func` is a function type. It consists of an optional return type,
a list of parameter types and can optionally be variadic.

Example:

```mlir
!cir.func<!bool ()>
!cir.func<!s32i (!s8i, !s8i)>
!cir.func<!s32i (!s32i, ...)>
!cir.func<()>
!cir.func<() -> bool>
!cir.func<(!s8i, !s8i)>
!cir.func<(!s8i, !s8i) -> !s32i>
!cir.func<(!s32i, ...) -> !s32i>
```
}];

let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs,
"mlir::Type":$returnType, "bool":$varArg);
"mlir::Type":$optionalReturnType, "bool":$varArg);
// Use a custom parser to handle the argument types and optional return
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Use a custom parser to handle the argument types and optional return
// Use a custom parser to handle the argument types and optional return.

let assemblyFormat = [{
`<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>`
`<` custom<FuncType>($optionalReturnType, $inputs, $varArg) `>`
}];

let builders = [
// Create a FuncType, converting the return type from C-style to
// MLIR-style. If the given return type is `cir::VoidType`, ignore it
// and create the FuncType with no return type, which is how MLIR
// represents function types.
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType,
CArg<"bool", "false">:$isVarArg), [{
return $_get(returnType.getContext(), inputs, returnType, isVarArg);
return $_get(returnType.getContext(), inputs,
mlir::isa<cir::VoidType>(returnType) ? nullptr : returnType,
isVarArg);
}]>
];

let genVerifyDecl = 1;

let extraClassDeclaration = [{
/// Returns whether the function is variadic.
bool isVarArg() const { return getVarArg(); }
Expand All @@ -323,12 +334,17 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
/// Returns the number of arguments to the function.
unsigned getNumInputs() const { return getInputs().size(); }

/// Returns the result type of the function as an ArrayRef, enabling better
/// integration with generic MLIR utilities.
/// Get the C-style return type of the function, which is !cir.void if the
/// function returns nothing and the actual return type otherwise.
mlir::Type getReturnType() const;

/// Get the MLIR-style return type of the function, which is an empty
/// ArrayRef if the function returns nothing and a single-element ArrayRef
/// with the actual return type otherwise.
llvm::ArrayRef<mlir::Type> getReturnTypes() const;

/// Returns whether the function is returns void.
bool isVoid() const;
/// Does the function type return nothing?
bool hasVoidReturn() const;

/// Returns a clone of this function type with the given argument
/// and result types.
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ bool CIRGenTypes::isFuncTypeConvertible(const FunctionType *ft) {
mlir::Type CIRGenTypes::convertFunctionTypeInternal(QualType qft) {
assert(qft.isCanonical());
const FunctionType *ft = cast<FunctionType>(qft.getTypePtr());
// First, check whether we can build the full fucntion type. If the function
// First, check whether we can build the full function type. If the function
// type depends on an incomplete type (e.g. a struct or enum), we cannot lower
// the function type.
if (!isFuncTypeConvertible(ft)) {
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,10 @@ LogicalResult cir::FuncOp::verifyType() {
if (!isa<cir::FuncType>(type))
return emitOpError("requires '" + getFunctionTypeAttrName().str() +
"' attribute of function type");
if (auto rt = type.getReturnTypes();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anywhere that we check that rt.size() == 1 as a part of the verifier?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rt.size() > 1 is an impossible situation. FuncType stores only one mlir::Type return type. There is no possible code path where FuncType::getReturnTypes could create an llvm::ArrayRef<mlir::Type> with more than one member, even if FuncType's invariants were broken.

But adding a check for that here doesn't hurt. I'll do that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, just super weird/scary for me that FuncType has getReturnTypes here. BUT I consider the verifier for the purposes of sanity-checking anyway... I guess it would fail parsing, but one of the justifications for making the verifier so aggressive in LLVM at checking EVERYTHING is that it can sometimes come from other, side-channel sources like a handwritten text file, or one of the CAPIs.

ALSO-ALSO: If I'm mentally parsing this right: That check on line 424 is an absolute impossibility by the type system, right?

getFunctionType already gives us a cir::FuncType, so an isa will always be true, correct? Or am I missing something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getFunctionType already gives us a cir::FuncType, so an isa will always be true, correct?

Correct. I just noticed that myself, as I was looking at this function more closely.

In the incubator, this function has three checks. The first is a tautology, because it is checking if a cir::FuncType isa cir::FuncType. The second, which I had not upstreamed yet, is wrong. It checks that an ellipsis is not the only parameter, but functions with only an ellipsis are allowed in C++. The third check, the one added in this PR, that there isn't an explicit void return type, is redundant, because that is already checked in FuncType::verify.

It doesn't look like FuncOp::verifyType is called from anywhere (even in the incubator), and it doesn't do anything useful. So rather than adding a check for rt.size() < 2 to the function, I will look into getting rid of it entirely.

!rt.empty() && mlir::isa<cir::VoidType>(rt.front()))
return emitOpError("The return type for a function returning void should "
"be empty instead of an explicit !cir.void");
return success();
}

Expand Down
125 changes: 93 additions & 32 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
// CIR Custom Parser/Printer Signatures
//===----------------------------------------------------------------------===//

static mlir::ParseResult
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg);
static void printFuncTypeArgs(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
mlir::Type &optionalReturnTypes,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg);
static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnTypes,
mlir::ArrayRef<mlir::Type> params, bool isVarArg);

//===----------------------------------------------------------------------===//
// Get autogenerated stuff
Expand Down Expand Up @@ -282,40 +283,55 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
return get(llvm::to_vector(inputs), results[0], isVarArg());
}

mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
// A special parser is needed for function returning void to handle the missing
// type.
static mlir::ParseResult parseFuncTypeReturn(mlir::AsmParser &p,
mlir::Type &optionalReturnType) {
if (succeeded(p.parseOptionalArrow())) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (succeeded(p.parseOptionalArrow())) {
// If `->` is found, it must be followed by the return type.
if (succeeded(p.parseOptionalArrow()))
return p.parseType(optionalReturnType);

// `->` found. It must be followed by the return type.
return p.parseType(optionalReturnType);
}
// Function has `void` return in C++, no return in MLIR.
optionalReturnType = {};
return success();
}

// A special pretty-printer for function returning or not a result.
static void printFuncTypeReturn(mlir::AsmPrinter &p,
mlir::Type optionalReturnType) {
if (optionalReturnType)
p << " -> " << optionalReturnType;
}

static mlir::ParseResult
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
isVarArg = false;
// `(` `)`
if (succeeded(p.parseOptionalRParen()))
if (failed(p.parseLParen()))
return failure();
if (succeeded(p.parseOptionalRParen())) {
// `()` empty argument list
return mlir::success();

// `(` `...` `)`
if (succeeded(p.parseOptionalEllipsis())) {
isVarArg = true;
return p.parseRParen();
}

// type (`,` type)* (`,` `...`)?
mlir::Type type;
if (p.parseType(type))
return mlir::failure();
params.push_back(type);
while (succeeded(p.parseOptionalComma())) {
do {
if (succeeded(p.parseOptionalEllipsis())) {
// `...`, which must be the last thing in the list.
isVarArg = true;
return p.parseRParen();
break;
} else {
mlir::Type argType;
if (failed(p.parseType(argType)))
return failure();
params.push_back(argType);
}
if (p.parseType(type))
return mlir::failure();
params.push_back(type);
}

} while (succeeded(p.parseOptionalComma()));
return p.parseRParen();
}

void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
static void printFuncTypeArgs(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
p << '(';
llvm::interleaveComma(params, p,
[&p](mlir::Type type) { p.printType(type); });
if (isVarArg) {
Expand All @@ -326,11 +342,56 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
p << ')';
}

// Use a custom parser to handle the optional return and argument types without
// an optional anchor.
static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
mlir::Type &optionalReturnType,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
if (failed(parseFuncTypeArgs(p, params, isVarArg)))
return failure();
return parseFuncTypeReturn(p, optionalReturnType);
}

static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnType,
mlir::ArrayRef<mlir::Type> params, bool isVarArg) {
printFuncTypeArgs(p, params, isVarArg);
printFuncTypeReturn(p, optionalReturnType);
}

/// Get the C-style return type of the function, which is !cir.void if the
/// function returns nothing and the actual return type otherwise.
mlir::Type FuncType::getReturnType() const {
if (hasVoidReturn())
return cir::VoidType::get(getContext());
return getOptionalReturnType();
}

/// Get the MLIR-style return type of the function, which is an empty
/// ArrayRef if the function returns nothing and a single-element ArrayRef
/// with the actual return type otherwise.
llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType;
if (hasVoidReturn())
return {};
// Can't use getOptionalReturnType() here because llvm::ArrayRef hold a
// pointer to its elements and doesn't do lifetime extension. That would
// result in returning a pointer to a temporary that has gone out of scope.
return getImpl()->optionalReturnType;
}

bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
// Does the fuction type return nothing?
bool FuncType::hasVoidReturn() const { return !getOptionalReturnType(); }

mlir::LogicalResult
FuncType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
llvm::ArrayRef<mlir::Type> argTypes, mlir::Type returnType,
bool isVarArg) {
if (returnType && mlir::isa<cir::VoidType>(returnType)) {
emitError() << "!cir.func cannot have an explicit 'void' return type";
return mlir::failure();
}
return mlir::success();
}

//===----------------------------------------------------------------------===//
// BoolType
Expand Down
8 changes: 4 additions & 4 deletions clang/test/CIR/IR/func.cir
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@

module {
// void empty() { }
cir.func @empty() -> !cir.void {
cir.func @empty() {
cir.return
}
// CHECK: cir.func @empty() -> !cir.void {
// CHECK: cir.func @empty() {
// CHECK: cir.return
// CHECK: }

// void voidret() { return; }
cir.func @voidret() -> !cir.void {
cir.func @voidret() {
cir.return
}
// CHECK: cir.func @voidret() -> !cir.void {
// CHECK: cir.func @voidret() {
// CHECK: cir.return
// CHECK: }

Expand Down
12 changes: 6 additions & 6 deletions clang/test/CIR/IR/global.cir
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
cir.global @dp : !cir.ptr<!cir.double>
cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
cir.global @fp : !cir.ptr<!cir.func<()>>
cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
}

// CHECK: cir.global @c : !cir.int<s, 8>
Expand Down Expand Up @@ -64,6 +64,6 @@ module attributes {cir.triple = "x86_64-unknown-linux-gnu"} {
// CHECK: cir.global @ip = #cir.ptr<null> : !cir.ptr<!cir.int<s, 32>>
// CHECK: cir.global @dp : !cir.ptr<!cir.double>
// CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>
// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
4 changes: 2 additions & 2 deletions clang/test/CIR/func-simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o - | FileCheck %s

void empty() { }
// CHECK: cir.func @empty() -> !cir.void {
// CHECK: cir.func @empty() {
// CHECK: cir.return
// CHECK: }

void voidret() { return; }
// CHECK: cir.func @voidret() -> !cir.void {
// CHECK: cir.func @voidret() {
// CHECK: cir.return
// CHECK: }

Expand Down
6 changes: 3 additions & 3 deletions clang/test/CIR/global-var-simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ char **cpp;
// CHECK: cir.global @cpp : !cir.ptr<!cir.ptr<!cir.int<s, 8>>>

void (*fp)();
// CHECK: cir.global @fp : !cir.ptr<!cir.func<!cir.void ()>>
// CHECK: cir.global @fp : !cir.ptr<!cir.func<()>>

int (*fpii)(int) = 0;
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<!cir.int<s, 32> (!cir.int<s, 32>)>>
// CHECK: cir.global @fpii = #cir.ptr<null> : !cir.ptr<!cir.func<(!cir.int<s, 32>) -> !cir.int<s, 32>>>

void (*fpvar)(int, ...);
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<!cir.void (!cir.int<s, 32>, ...)>>
// CHECK: cir.global @fpvar : !cir.ptr<!cir.func<(!cir.int<s, 32>, ...)>>
Loading