Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def CIR_Dialect : Dialect {
static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; }

void registerAttributes();
void registerTypes();
Expand Down
92 changes: 91 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2728,7 +2728,7 @@ def CIR_LLVMIntrinsicCallOp : CIR_Op<"call_llvm_intrinsic"> {
}

//===----------------------------------------------------------------------===//
// CallOp
// CallOp and TryCallOp
//===----------------------------------------------------------------------===//

def CIR_SideEffect : CIR_I32EnumAttr<
Expand Down Expand Up @@ -2855,6 +2855,96 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
];
}

def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
Terminator
]> {
let summary = "try_call operation";
let description = [{
Similar to `cir.call` but requires two destination blocks,
one which is used if the call returns without throwing an
exception (the "normal" destination) and another which is used
if an exception is thrown (the "unwind" destination).

This operation is used only after the CFG flatterning pass.

Example:

```mlir
// Before CFG flattening
cir.try {
%call = cir.call @division(%a, %b) : () -> !s32i
cir.yield
} catch all {
cir.yield
}

// After CFG flattening
%call = cir.try_call @division(%a, %b) ^normalDest, ^unwindDest
: (f32, f32) -> f32
^normalDest:
cir.br ^afterTryBlock
^unwindDest:
%exception_ptr, %type_id = cir.eh.inflight_exception
cir.br ^catchHandlerBlock(%exception_ptr : !cir.ptr<!void>)
^catchHandlerBlock:
...
```
}];

let arguments = commonArgs;
let results = (outs Optional<CIR_AnyType>:$result);
let successors = (successor
AnySuccessor:$normalDest,
AnySuccessor:$unwindDest
);

let skipDefaultBuilders = 1;
let hasLLVMLowering = false;

let builders = [
OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
"mlir::Type":$resType,
"mlir::Block *":$normalDest,
"mlir::Block *":$unwindDest,
CArg<"mlir::ValueRange", "{}">:$callOperands,
CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
$_state.addOperands(callOperands);

if (callee)
$_state.addAttribute("callee", callee);
if (resType && !isa<VoidType>(resType))
$_state.addTypes(resType);

$_state.addAttribute("side_effect",
SideEffectAttr::get($_builder.getContext(), sideEffect));

// Handle branches
$_state.addSuccessors(normalDest);
$_state.addSuccessors(unwindDest);
}]>,
OpBuilder<(ins "mlir::Value":$ind_target,
"FuncType":$fn_type,
"mlir::Block *":$normalDest,
"mlir::Block *":$unwindDest,
CArg<"mlir::ValueRange", "{}">:$callOperands,
CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
finalCallOperands.append(callOperands.begin(), callOperands.end());
$_state.addOperands(finalCallOperands);

if (!fn_type.hasVoidReturn())
$_state.addTypes(fn_type.getReturnType());

$_state.addAttribute("side_effect",
SideEffectAttr::get($_builder.getContext(), sideEffect));

// Handle branches
$_state.addSuccessors(normalDest);
$_state.addSuccessors(unwindDest);
}]>
];
}

//===----------------------------------------------------------------------===//
// AwaitOp
//===----------------------------------------------------------------------===//
Expand Down
104 changes: 97 additions & 7 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -721,8 +721,28 @@ unsigned cir::CallOp::getNumArgOperands() {
return this->getOperation()->getNumOperands();
}

static mlir::ParseResult
parseTryCallDestinations(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
mlir::Block *normalDestSuccessor;
if (parser.parseSuccessor(normalDestSuccessor))
return mlir::failure();

if (parser.parseComma())
return mlir::failure();

mlir::Block *unwindDestSuccessor;
if (parser.parseSuccessor(unwindDestSuccessor))
return mlir::failure();

result.addSuccessors(normalDestSuccessor);
result.addSuccessors(unwindDestSuccessor);
return mlir::success();
}

static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
mlir::OperationState &result,
bool hasDestinationBlocks = false) {
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
llvm::SMLoc opsLoc;
mlir::FlatSymbolRefAttr calleeAttr;
Expand All @@ -749,6 +769,11 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
if (parser.parseRParen())
return mlir::failure();

if (hasDestinationBlocks &&
parseTryCallDestinations(parser, result).failed()) {
return ::mlir::failure();
}

if (parser.parseOptionalKeyword("nothrow").succeeded())
result.addAttribute(CIRDialect::getNoThrowAttrName(),
mlir::UnitAttr::get(parser.getContext()));
Expand Down Expand Up @@ -788,7 +813,9 @@ static void printCallCommon(mlir::Operation *op,
mlir::FlatSymbolRefAttr calleeSym,
mlir::Value indirectCallee,
mlir::OpAsmPrinter &printer, bool isNothrow,
cir::SideEffect sideEffect) {
cir::SideEffect sideEffect,
mlir::Block *normalDest = nullptr,
mlir::Block *unwindDest = nullptr) {
printer << ' ';

auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
Expand All @@ -802,8 +829,18 @@ static void printCallCommon(mlir::Operation *op,
assert(indirectCallee);
printer << indirectCallee;
}

printer << "(" << ops << ")";

if (normalDest) {
assert(unwindDest && "expected two successors");
auto tryCall = cast<cir::TryCallOp>(op);
printer << ' ' << tryCall.getNormalDest();
printer << ",";
printer << ' ';
printer << tryCall.getUnwindDest();
}

if (isNothrow)
printer << " nothrow";

Expand All @@ -813,11 +850,11 @@ static void printCallCommon(mlir::Operation *op,
printer << ")";
}

printer.printOptionalAttrDict(op->getAttrs(),
{CIRDialect::getCalleeAttrName(),
CIRDialect::getNoThrowAttrName(),
CIRDialect::getSideEffectAttrName()});

llvm::SmallVector<::llvm::StringRef> elidedAttrs = {
CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
CIRDialect::getSideEffectAttrName(),
CIRDialect::getOperandSegmentSizesAttrName()};
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
printer << " : ";
printer.printFunctionalType(op->getOperands().getTypes(),
op->getResultTypes());
Expand Down Expand Up @@ -898,6 +935,59 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyCallCommInSymbolUses(*this, symbolTable);
}

//===----------------------------------------------------------------------===//
// TryCallOp
//===----------------------------------------------------------------------===//

mlir::OperandRange cir::TryCallOp::getArgOperands() {
if (isIndirect())
return getArgs().drop_front(1);
return getArgs();
}

mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
mlir::MutableOperandRange args = getArgsMutable();
if (isIndirect())
return args.slice(1, args.size() - 1);
return args;
}

mlir::Value cir::TryCallOp::getIndirectCall() {
assert(isIndirect());
return getOperand(0);
}

/// Return the operand at index 'i'.
Value cir::TryCallOp::getArgOperand(unsigned i) {
if (isIndirect())
++i;
return getOperand(i);
}

/// Return the number of operands.
unsigned cir::TryCallOp::getNumArgOperands() {
if (isIndirect())
return this->getOperation()->getNumOperands() - 1;
return this->getOperation()->getNumOperands();
}

LogicalResult
cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyCallCommInSymbolUses(*this, symbolTable);
}

mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
}

void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
cir::SideEffect sideEffect = getSideEffect();
printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
sideEffect, getNormalDest(), getUnwindDest());
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
Expand Down
35 changes: 35 additions & 0 deletions clang/test/CIR/IR/try-call.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: cir-opt %s --verify-roundtrip | FileCheck %s

!s32i = !cir.int<s, 32>

module {

cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i

cir.func @flatten_structure_with_try_call_op() {
%a = cir.const #cir.int<1> : !s32i
%b = cir.const #cir.int<2> : !s32i
%3 = cir.try_call @division(%a, %b) ^normal, ^unwind : (!s32i, !s32i) -> !s32i
^normal:
cir.br ^end
^unwind:
cir.br ^end
^end:
cir.return
}

// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i

// CHECK: cir.func @flatten_structure_with_try_call_op() {
// CHECK-NEXT: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
// CHECK-NEXT: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
// CHECK-NEXT: %[[CALL:.*]] = cir.try_call @division(%[[CONST_1]], %[[CONST_2]]) ^[[NORMAL:.*]], ^[[UNWIND:.*]] : (!s32i, !s32i) -> !s32i
// CHECK-NEXT: ^[[NORMAL]]:
// CHECK-NEXT: cir.br ^[[END:.*]]
// CHECK-NEXT: ^[[UNWIND]]:
// CHECK-NEXT: cir.br ^[[END:.*]]
// CHECK-NEXT: ^[[END]]:
// CHECK-NEXT: cir.return
// CHECK-NEXT: }

}