Skip to content

Commit 2e655c2

Browse files
authored
[CIR] Upstream TryCallOp (#165303)
Upstream TryCall Op as a prerequisite for Try Catch work Issue #154992
1 parent ad605bd commit 2e655c2

File tree

4 files changed

+224
-8
lines changed

4 files changed

+224
-8
lines changed

clang/include/clang/CIR/Dialect/IR/CIRDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def CIR_Dialect : Dialect {
4444
static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
4545
static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
4646
static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
47+
static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; }
4748

4849
void registerAttributes();
4950
void registerTypes();

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2728,7 +2728,7 @@ def CIR_LLVMIntrinsicCallOp : CIR_Op<"call_llvm_intrinsic"> {
27282728
}
27292729

27302730
//===----------------------------------------------------------------------===//
2731-
// CallOp
2731+
// CallOp and TryCallOp
27322732
//===----------------------------------------------------------------------===//
27332733

27342734
def CIR_SideEffect : CIR_I32EnumAttr<
@@ -2855,6 +2855,96 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
28552855
];
28562856
}
28572857

2858+
def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
2859+
Terminator
2860+
]> {
2861+
let summary = "try_call operation";
2862+
let description = [{
2863+
Similar to `cir.call` but requires two destination blocks,
2864+
one which is used if the call returns without throwing an
2865+
exception (the "normal" destination) and another which is used
2866+
if an exception is thrown (the "unwind" destination).
2867+
2868+
This operation is used only after the CFG flatterning pass.
2869+
2870+
Example:
2871+
2872+
```mlir
2873+
// Before CFG flattening
2874+
cir.try {
2875+
%call = cir.call @division(%a, %b) : () -> !s32i
2876+
cir.yield
2877+
} catch all {
2878+
cir.yield
2879+
}
2880+
2881+
// After CFG flattening
2882+
%call = cir.try_call @division(%a, %b) ^normalDest, ^unwindDest
2883+
: (f32, f32) -> f32
2884+
^normalDest:
2885+
cir.br ^afterTryBlock
2886+
^unwindDest:
2887+
%exception_ptr, %type_id = cir.eh.inflight_exception
2888+
cir.br ^catchHandlerBlock(%exception_ptr : !cir.ptr<!void>)
2889+
^catchHandlerBlock:
2890+
...
2891+
```
2892+
}];
2893+
2894+
let arguments = commonArgs;
2895+
let results = (outs Optional<CIR_AnyType>:$result);
2896+
let successors = (successor
2897+
AnySuccessor:$normalDest,
2898+
AnySuccessor:$unwindDest
2899+
);
2900+
2901+
let skipDefaultBuilders = 1;
2902+
let hasLLVMLowering = false;
2903+
2904+
let builders = [
2905+
OpBuilder<(ins "mlir::SymbolRefAttr":$callee,
2906+
"mlir::Type":$resType,
2907+
"mlir::Block *":$normalDest,
2908+
"mlir::Block *":$unwindDest,
2909+
CArg<"mlir::ValueRange", "{}">:$callOperands,
2910+
CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
2911+
$_state.addOperands(callOperands);
2912+
2913+
if (callee)
2914+
$_state.addAttribute("callee", callee);
2915+
if (resType && !isa<VoidType>(resType))
2916+
$_state.addTypes(resType);
2917+
2918+
$_state.addAttribute("side_effect",
2919+
SideEffectAttr::get($_builder.getContext(), sideEffect));
2920+
2921+
// Handle branches
2922+
$_state.addSuccessors(normalDest);
2923+
$_state.addSuccessors(unwindDest);
2924+
}]>,
2925+
OpBuilder<(ins "mlir::Value":$ind_target,
2926+
"FuncType":$fn_type,
2927+
"mlir::Block *":$normalDest,
2928+
"mlir::Block *":$unwindDest,
2929+
CArg<"mlir::ValueRange", "{}">:$callOperands,
2930+
CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
2931+
::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
2932+
finalCallOperands.append(callOperands.begin(), callOperands.end());
2933+
$_state.addOperands(finalCallOperands);
2934+
2935+
if (!fn_type.hasVoidReturn())
2936+
$_state.addTypes(fn_type.getReturnType());
2937+
2938+
$_state.addAttribute("side_effect",
2939+
SideEffectAttr::get($_builder.getContext(), sideEffect));
2940+
2941+
// Handle branches
2942+
$_state.addSuccessors(normalDest);
2943+
$_state.addSuccessors(unwindDest);
2944+
}]>
2945+
];
2946+
}
2947+
28582948
//===----------------------------------------------------------------------===//
28592949
// AwaitOp
28602950
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -721,8 +721,28 @@ unsigned cir::CallOp::getNumArgOperands() {
721721
return this->getOperation()->getNumOperands();
722722
}
723723

724+
static mlir::ParseResult
725+
parseTryCallDestinations(mlir::OpAsmParser &parser,
726+
mlir::OperationState &result) {
727+
mlir::Block *normalDestSuccessor;
728+
if (parser.parseSuccessor(normalDestSuccessor))
729+
return mlir::failure();
730+
731+
if (parser.parseComma())
732+
return mlir::failure();
733+
734+
mlir::Block *unwindDestSuccessor;
735+
if (parser.parseSuccessor(unwindDestSuccessor))
736+
return mlir::failure();
737+
738+
result.addSuccessors(normalDestSuccessor);
739+
result.addSuccessors(unwindDestSuccessor);
740+
return mlir::success();
741+
}
742+
724743
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
725-
mlir::OperationState &result) {
744+
mlir::OperationState &result,
745+
bool hasDestinationBlocks = false) {
726746
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
727747
llvm::SMLoc opsLoc;
728748
mlir::FlatSymbolRefAttr calleeAttr;
@@ -749,6 +769,11 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
749769
if (parser.parseRParen())
750770
return mlir::failure();
751771

772+
if (hasDestinationBlocks &&
773+
parseTryCallDestinations(parser, result).failed()) {
774+
return ::mlir::failure();
775+
}
776+
752777
if (parser.parseOptionalKeyword("nothrow").succeeded())
753778
result.addAttribute(CIRDialect::getNoThrowAttrName(),
754779
mlir::UnitAttr::get(parser.getContext()));
@@ -788,7 +813,9 @@ static void printCallCommon(mlir::Operation *op,
788813
mlir::FlatSymbolRefAttr calleeSym,
789814
mlir::Value indirectCallee,
790815
mlir::OpAsmPrinter &printer, bool isNothrow,
791-
cir::SideEffect sideEffect) {
816+
cir::SideEffect sideEffect,
817+
mlir::Block *normalDest = nullptr,
818+
mlir::Block *unwindDest = nullptr) {
792819
printer << ' ';
793820

794821
auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
@@ -802,8 +829,18 @@ static void printCallCommon(mlir::Operation *op,
802829
assert(indirectCallee);
803830
printer << indirectCallee;
804831
}
832+
805833
printer << "(" << ops << ")";
806834

835+
if (normalDest) {
836+
assert(unwindDest && "expected two successors");
837+
auto tryCall = cast<cir::TryCallOp>(op);
838+
printer << ' ' << tryCall.getNormalDest();
839+
printer << ",";
840+
printer << ' ';
841+
printer << tryCall.getUnwindDest();
842+
}
843+
807844
if (isNothrow)
808845
printer << " nothrow";
809846

@@ -813,11 +850,11 @@ static void printCallCommon(mlir::Operation *op,
813850
printer << ")";
814851
}
815852

816-
printer.printOptionalAttrDict(op->getAttrs(),
817-
{CIRDialect::getCalleeAttrName(),
818-
CIRDialect::getNoThrowAttrName(),
819-
CIRDialect::getSideEffectAttrName()});
820-
853+
llvm::SmallVector<::llvm::StringRef> elidedAttrs = {
854+
CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
855+
CIRDialect::getSideEffectAttrName(),
856+
CIRDialect::getOperandSegmentSizesAttrName()};
857+
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
821858
printer << " : ";
822859
printer.printFunctionalType(op->getOperands().getTypes(),
823860
op->getResultTypes());
@@ -898,6 +935,59 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
898935
return verifyCallCommInSymbolUses(*this, symbolTable);
899936
}
900937

938+
//===----------------------------------------------------------------------===//
939+
// TryCallOp
940+
//===----------------------------------------------------------------------===//
941+
942+
mlir::OperandRange cir::TryCallOp::getArgOperands() {
943+
if (isIndirect())
944+
return getArgs().drop_front(1);
945+
return getArgs();
946+
}
947+
948+
mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
949+
mlir::MutableOperandRange args = getArgsMutable();
950+
if (isIndirect())
951+
return args.slice(1, args.size() - 1);
952+
return args;
953+
}
954+
955+
mlir::Value cir::TryCallOp::getIndirectCall() {
956+
assert(isIndirect());
957+
return getOperand(0);
958+
}
959+
960+
/// Return the operand at index 'i'.
961+
Value cir::TryCallOp::getArgOperand(unsigned i) {
962+
if (isIndirect())
963+
++i;
964+
return getOperand(i);
965+
}
966+
967+
/// Return the number of operands.
968+
unsigned cir::TryCallOp::getNumArgOperands() {
969+
if (isIndirect())
970+
return this->getOperation()->getNumOperands() - 1;
971+
return this->getOperation()->getNumOperands();
972+
}
973+
974+
LogicalResult
975+
cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
976+
return verifyCallCommInSymbolUses(*this, symbolTable);
977+
}
978+
979+
mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
980+
mlir::OperationState &result) {
981+
return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
982+
}
983+
984+
void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
985+
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
986+
cir::SideEffect sideEffect = getSideEffect();
987+
printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
988+
sideEffect, getNormalDest(), getUnwindDest());
989+
}
990+
901991
//===----------------------------------------------------------------------===//
902992
// ReturnOp
903993
//===----------------------------------------------------------------------===//

clang/test/CIR/IR/try-call.cir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: cir-opt %s --verify-roundtrip | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
7+
cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i
8+
9+
cir.func @flatten_structure_with_try_call_op() {
10+
%a = cir.const #cir.int<1> : !s32i
11+
%b = cir.const #cir.int<2> : !s32i
12+
%3 = cir.try_call @division(%a, %b) ^normal, ^unwind : (!s32i, !s32i) -> !s32i
13+
^normal:
14+
cir.br ^end
15+
^unwind:
16+
cir.br ^end
17+
^end:
18+
cir.return
19+
}
20+
21+
// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i
22+
23+
// CHECK: cir.func @flatten_structure_with_try_call_op() {
24+
// CHECK-NEXT: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
25+
// CHECK-NEXT: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
26+
// CHECK-NEXT: %[[CALL:.*]] = cir.try_call @division(%[[CONST_1]], %[[CONST_2]]) ^[[NORMAL:.*]], ^[[UNWIND:.*]] : (!s32i, !s32i) -> !s32i
27+
// CHECK-NEXT: ^[[NORMAL]]:
28+
// CHECK-NEXT: cir.br ^[[END:.*]]
29+
// CHECK-NEXT: ^[[UNWIND]]:
30+
// CHECK-NEXT: cir.br ^[[END:.*]]
31+
// CHECK-NEXT: ^[[END]]:
32+
// CHECK-NEXT: cir.return
33+
// CHECK-NEXT: }
34+
35+
}

0 commit comments

Comments
 (0)