diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td index e91537186df59..34df9af7fc06d 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td @@ -44,6 +44,7 @@ def CIR_Dialect : Dialect { static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; } static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; } static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; } + static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; } void registerAttributes(); void registerTypes(); diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 777b49434f119..5f5fab6f12300 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2728,7 +2728,7 @@ def CIR_LLVMIntrinsicCallOp : CIR_Op<"call_llvm_intrinsic"> { } //===----------------------------------------------------------------------===// -// CallOp +// CallOp and TryCallOp //===----------------------------------------------------------------------===// def CIR_SideEffect : CIR_I32EnumAttr< @@ -2855,6 +2855,96 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> { ]; } +def CIR_TryCallOp : CIR_CallOpBase<"try_call",[ + Terminator +]> { + let summary = "try_call operation"; + let description = [{ + Similar to `cir.call` but requires two destination blocks, + one which is used if the call returns without throwing an + exception (the "normal" destination) and another which is used + if an exception is thrown (the "unwind" destination). + + This operation is used only after the CFG flatterning pass. + + Example: + + ```mlir + // Before CFG flattening + cir.try { + %call = cir.call @division(%a, %b) : () -> !s32i + cir.yield + } catch all { + cir.yield + } + + // After CFG flattening + %call = cir.try_call @division(%a, %b) ^normalDest, ^unwindDest + : (f32, f32) -> f32 + ^normalDest: + cir.br ^afterTryBlock + ^unwindDest: + %exception_ptr, %type_id = cir.eh.inflight_exception + cir.br ^catchHandlerBlock(%exception_ptr : !cir.ptr) + ^catchHandlerBlock: + ... + ``` + }]; + + let arguments = commonArgs; + let results = (outs Optional:$result); + let successors = (successor + AnySuccessor:$normalDest, + AnySuccessor:$unwindDest + ); + + let skipDefaultBuilders = 1; + let hasLLVMLowering = false; + + let builders = [ + OpBuilder<(ins "mlir::SymbolRefAttr":$callee, + "mlir::Type":$resType, + "mlir::Block *":$normalDest, + "mlir::Block *":$unwindDest, + CArg<"mlir::ValueRange", "{}">:$callOperands, + CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{ + $_state.addOperands(callOperands); + + if (callee) + $_state.addAttribute("callee", callee); + if (resType && !isa(resType)) + $_state.addTypes(resType); + + $_state.addAttribute("side_effect", + SideEffectAttr::get($_builder.getContext(), sideEffect)); + + // Handle branches + $_state.addSuccessors(normalDest); + $_state.addSuccessors(unwindDest); + }]>, + OpBuilder<(ins "mlir::Value":$ind_target, + "FuncType":$fn_type, + "mlir::Block *":$normalDest, + "mlir::Block *":$unwindDest, + CArg<"mlir::ValueRange", "{}">:$callOperands, + CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{ + ::llvm::SmallVector finalCallOperands({ind_target}); + finalCallOperands.append(callOperands.begin(), callOperands.end()); + $_state.addOperands(finalCallOperands); + + if (!fn_type.hasVoidReturn()) + $_state.addTypes(fn_type.getReturnType()); + + $_state.addAttribute("side_effect", + SideEffectAttr::get($_builder.getContext(), sideEffect)); + + // Handle branches + $_state.addSuccessors(normalDest); + $_state.addSuccessors(unwindDest); + }]> + ]; +} + //===----------------------------------------------------------------------===// // AwaitOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index f1bacff7fc691..d505ca141d383 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -721,8 +721,28 @@ unsigned cir::CallOp::getNumArgOperands() { return this->getOperation()->getNumOperands(); } +static mlir::ParseResult +parseTryCallDestinations(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::Block *normalDestSuccessor; + if (parser.parseSuccessor(normalDestSuccessor)) + return mlir::failure(); + + if (parser.parseComma()) + return mlir::failure(); + + mlir::Block *unwindDestSuccessor; + if (parser.parseSuccessor(unwindDestSuccessor)) + return mlir::failure(); + + result.addSuccessors(normalDestSuccessor); + result.addSuccessors(unwindDestSuccessor); + return mlir::success(); +} + static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, - mlir::OperationState &result) { + mlir::OperationState &result, + bool hasDestinationBlocks = false) { llvm::SmallVector ops; llvm::SMLoc opsLoc; mlir::FlatSymbolRefAttr calleeAttr; @@ -749,6 +769,11 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser, if (parser.parseRParen()) return mlir::failure(); + if (hasDestinationBlocks && + parseTryCallDestinations(parser, result).failed()) { + return ::mlir::failure(); + } + if (parser.parseOptionalKeyword("nothrow").succeeded()) result.addAttribute(CIRDialect::getNoThrowAttrName(), mlir::UnitAttr::get(parser.getContext())); @@ -788,7 +813,9 @@ static void printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym, mlir::Value indirectCallee, mlir::OpAsmPrinter &printer, bool isNothrow, - cir::SideEffect sideEffect) { + cir::SideEffect sideEffect, + mlir::Block *normalDest = nullptr, + mlir::Block *unwindDest = nullptr) { printer << ' '; auto callLikeOp = mlir::cast(op); @@ -802,8 +829,18 @@ static void printCallCommon(mlir::Operation *op, assert(indirectCallee); printer << indirectCallee; } + printer << "(" << ops << ")"; + if (normalDest) { + assert(unwindDest && "expected two successors"); + auto tryCall = cast(op); + printer << ' ' << tryCall.getNormalDest(); + printer << ","; + printer << ' '; + printer << tryCall.getUnwindDest(); + } + if (isNothrow) printer << " nothrow"; @@ -813,11 +850,11 @@ static void printCallCommon(mlir::Operation *op, printer << ")"; } - printer.printOptionalAttrDict(op->getAttrs(), - {CIRDialect::getCalleeAttrName(), - CIRDialect::getNoThrowAttrName(), - CIRDialect::getSideEffectAttrName()}); - + llvm::SmallVector<::llvm::StringRef> elidedAttrs = { + CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(), + CIRDialect::getSideEffectAttrName(), + CIRDialect::getOperandSegmentSizesAttrName()}; + printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); printer << " : "; printer.printFunctionalType(op->getOperands().getTypes(), op->getResultTypes()); @@ -898,6 +935,59 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return verifyCallCommInSymbolUses(*this, symbolTable); } +//===----------------------------------------------------------------------===// +// TryCallOp +//===----------------------------------------------------------------------===// + +mlir::OperandRange cir::TryCallOp::getArgOperands() { + if (isIndirect()) + return getArgs().drop_front(1); + return getArgs(); +} + +mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() { + mlir::MutableOperandRange args = getArgsMutable(); + if (isIndirect()) + return args.slice(1, args.size() - 1); + return args; +} + +mlir::Value cir::TryCallOp::getIndirectCall() { + assert(isIndirect()); + return getOperand(0); +} + +/// Return the operand at index 'i'. +Value cir::TryCallOp::getArgOperand(unsigned i) { + if (isIndirect()) + ++i; + return getOperand(i); +} + +/// Return the number of operands. +unsigned cir::TryCallOp::getNumArgOperands() { + if (isIndirect()) + return this->getOperation()->getNumOperands() - 1; + return this->getOperation()->getNumOperands(); +} + +LogicalResult +cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + return verifyCallCommInSymbolUses(*this, symbolTable); +} + +mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true); +} + +void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) { + mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr; + cir::SideEffect sideEffect = getSideEffect(); + printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(), + sideEffect, getNormalDest(), getUnwindDest()); +} + //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// diff --git a/clang/test/CIR/IR/try-call.cir b/clang/test/CIR/IR/try-call.cir new file mode 100644 index 0000000000000..39db43aee40c1 --- /dev/null +++ b/clang/test/CIR/IR/try-call.cir @@ -0,0 +1,35 @@ +// RUN: cir-opt %s --verify-roundtrip | FileCheck %s + +!s32i = !cir.int + +module { + +cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i + +cir.func @flatten_structure_with_try_call_op() { + %a = cir.const #cir.int<1> : !s32i + %b = cir.const #cir.int<2> : !s32i + %3 = cir.try_call @division(%a, %b) ^normal, ^unwind : (!s32i, !s32i) -> !s32i + ^normal: + cir.br ^end + ^unwind: + cir.br ^end + ^end: + cir.return +} + +// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i + +// CHECK: cir.func @flatten_structure_with_try_call_op() { +// CHECK-NEXT: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i +// CHECK-NEXT: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i +// CHECK-NEXT: %[[CALL:.*]] = cir.try_call @division(%[[CONST_1]], %[[CONST_2]]) ^[[NORMAL:.*]], ^[[UNWIND:.*]] : (!s32i, !s32i) -> !s32i +// CHECK-NEXT: ^[[NORMAL]]: +// CHECK-NEXT: cir.br ^[[END:.*]] +// CHECK-NEXT: ^[[UNWIND]]: +// CHECK-NEXT: cir.br ^[[END:.*]] +// CHECK-NEXT: ^[[END]]: +// CHECK-NEXT: cir.return +// CHECK-NEXT: } + +}