-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[CIR] Upstream TryCallOp #165303
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[CIR] Upstream TryCallOp #165303
Conversation
|
@llvm/pr-subscribers-clang Author: Amr Hesham (AmrDeveloper) Changes: Upstream the TryCallOp as a prerequisite for the try/catch work. Issue #154992. Full diff: https://github.com/llvm/llvm-project/pull/165303.diff (5 files affected):
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index e91537186df59..34df9af7fc06d 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -44,6 +44,7 @@ def CIR_Dialect : Dialect {
static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
+ // Name of the operand segment sizes attribute; printCallCommon lists it among the attributes it elides when printing.
+ static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; }
void registerAttributes();
void registerTypes();
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 2b361ed0982c6..8f3e25b3c9737 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2580,7 +2580,7 @@ def CIR_FuncOp : CIR_Op<"func", [
}
//===----------------------------------------------------------------------===//
-// CallOp
+// CallOp and TryCallOp
//===----------------------------------------------------------------------===//
def CIR_SideEffect : CIR_I32EnumAttr<
@@ -2707,6 +2707,98 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
];
}
+def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
+  DeclareOpInterfaceMethods<BranchOpInterface>,
+  Terminator, AttrSizedOperandSegments
+]> {
+  let summary = "try_call operation";
+
+  let description = [{
+    Mostly similar to `cir.call`, but requires two destination branches:
+    one that handles the exception if one is thrown, and another that
+    continues the regular control flow.
+
+    Example:
+
+    ```mlir
+    // Direct call
+    %result = cir.try_call @division(%a, %b) ^continue, ^landing_pad
+      : (f32, f32) -> f32
+    ```
+  }];
+
+  // Operand segments, in declaration order: cont operands, landing pad
+  // operands, then the call operands contributed by `commonArgs`.
+  let arguments = !con((ins
+    Variadic<CIR_AnyType>:$contOperands,
+    Variadic<CIR_AnyType>:$landingPadOperands
+  ), commonArgs);
+
+  let results = (outs Optional<CIR_AnyType>:$result);
+  let successors = (successor AnySuccessor:$cont, AnySuccessor:$landing_pad);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    // Direct call: `callee` names the called symbol.
+    OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
+               "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+               CArg<"mlir::ValueRange", "{}">:$operands,
+               CArg<"mlir::ValueRange", "{}">:$contOperands,
+               CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      // Append the operands in the same order as the segment sizes recorded
+      // below (cont, landing_pad, call operands); otherwise the segment
+      // accessors would slice the wrong values as soon as a successor
+      // operand list is non-empty.
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      $_state.addOperands(operands);
+
+      if (callee)
+        $_state.addAttribute("callee", callee);
+      if (resType && !isa<VoidType>(resType))
+        $_state.addTypes(resType);
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // The TryCall ODS layout is: cont, landing_pad, operands.
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(operands.size())
+        }),
+        $_state.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>,
+    // Indirect call: `ind_target` is the callee value and is stored as the
+    // first call operand, ahead of `operands`.
+    OpBuilder<(ins "mlir::Value":$ind_target,
+               "FuncType":$fn_type,
+               "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+               CArg<"mlir::ValueRange", "{}">:$operands,
+               CArg<"mlir::ValueRange", "{}">:$contOperands,
+               CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      ::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
+      finalCallOperands.append(operands.begin(), operands.end());
+
+      // Append the operands in the same order as the segment sizes recorded
+      // below (cont, landing_pad, call operands).
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      $_state.addOperands(finalCallOperands);
+
+      if (!fn_type.hasVoidReturn())
+        $_state.addTypes(fn_type.getReturnType());
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // The TryCall ODS layout is: cont, landing_pad, operands.
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(finalCallOperands.size())
+        }),
+        $_state.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>
+  ];
+}
+
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 2d2ef422bfaef..11074af3ef127 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -701,13 +701,78 @@ unsigned cir::CallOp::getNumArgOperands() {
return this->getOperation()->getNumOperands();
}
+/// Parses the destination part of a `cir.try_call`:
+///
+///   ^cont [`(` operands `:` types `)`] `,` ^landing_pad [`(` operands `:`
+///   types `)`] [attr-dict]
+///
+/// Unresolved operands, their types and their source locations are handed
+/// back through the output parameters. Both successors are appended to
+/// `result` only after the whole clause has parsed successfully.
+static mlir::ParseResult
+parseTryCallBranches(mlir::OpAsmParser &parser, mlir::OperationState &result,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &continueOperands,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &landingPadOperands,
+                     llvm::SmallVectorImpl<mlir::Type> &continueTypes,
+                     llvm::SmallVectorImpl<mlir::Type> &landingPadTypes,
+                     llvm::SMLoc &continueOperandsLoc,
+                     llvm::SMLoc &landingPadOperandsLoc) {
+  // Parses one `^block` optionally followed by `(operands : types)`.
+  auto parseDestination =
+      [&parser](mlir::Block *&dest,
+                llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                    &destOperands,
+                llvm::SmallVectorImpl<mlir::Type> &destTypes,
+                llvm::SMLoc &destOperandsLoc) -> mlir::ParseResult {
+    if (parser.parseSuccessor(dest))
+      return mlir::failure();
+
+    // No parenthesis means this destination takes no operands.
+    if (mlir::failed(parser.parseOptionalLParen()))
+      return mlir::success();
+
+    destOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(destOperands) || parser.parseColon() ||
+        parser.parseTypeList(destTypes) || parser.parseRParen())
+      return mlir::failure();
+    return mlir::success();
+  };
+
+  mlir::Block *contDest = nullptr;
+  mlir::Block *landingPadDest = nullptr;
+
+  if (parseDestination(contDest, continueOperands, continueTypes,
+                       continueOperandsLoc) ||
+      parser.parseComma() ||
+      parseDestination(landingPadDest, landingPadOperands, landingPadTypes,
+                       landingPadOperandsLoc) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return mlir::failure();
+
+  result.addSuccessors(contDest);
+  result.addSuccessors(landingPadDest);
+  return mlir::success();
+}
+
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
- mlir::OperationState &result) {
+ mlir::OperationState &result,
+ bool hasDestinationBlocks = false) {
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
llvm::SMLoc opsLoc;
mlir::FlatSymbolRefAttr calleeAttr;
llvm::ArrayRef<mlir::Type> allResultTypes;
+ // TryCall control flow related
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> continueOperands;
+ llvm::SMLoc continueOperandsLoc;
+ llvm::SmallVector<mlir::Type, 1> continueTypes;
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> landingPadOperands;
+ llvm::SMLoc landingPadOperandsLoc;
+ llvm::SmallVector<mlir::Type, 1> landingPadTypes;
+
// If we cannot parse a string callee, it means this is an indirect call.
if (!parser
.parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
@@ -729,6 +794,14 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
if (parser.parseRParen())
return mlir::failure();
+ if (hasDestinationBlocks &&
+ parseTryCallBranches(parser, result, continueOperands, landingPadOperands,
+ continueTypes, landingPadTypes, continueOperandsLoc,
+ landingPadOperandsLoc)
+ .failed()) {
+ return ::mlir::failure();
+ }
+
if (parser.parseOptionalKeyword("nothrow").succeeded())
result.addAttribute(CIRDialect::getNoThrowAttrName(),
mlir::UnitAttr::get(parser.getContext()));
@@ -761,6 +834,24 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
return mlir::failure();
+ if (hasDestinationBlocks) {
+ // The TryCall ODS layout is: cont, landing_pad, operands.
+ llvm::copy(::llvm::ArrayRef<int32_t>(
+ {static_cast<int32_t>(continueOperands.size()),
+ static_cast<int32_t>(landingPadOperands.size()),
+ static_cast<int32_t>(ops.size())}),
+ result.getOrAddProperties<cir::TryCallOp::Properties>()
+ .operandSegmentSizes.begin());
+
+ if (parser.resolveOperands(continueOperands, continueTypes,
+ continueOperandsLoc, result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(landingPadOperands, landingPadTypes,
+ landingPadOperandsLoc, result.operands))
+ return ::mlir::failure();
+ }
+
return mlir::success();
}
@@ -768,7 +859,9 @@ static void printCallCommon(mlir::Operation *op,
mlir::FlatSymbolRefAttr calleeSym,
mlir::Value indirectCallee,
mlir::OpAsmPrinter &printer, bool isNothrow,
- cir::SideEffect sideEffect) {
+ cir::SideEffect sideEffect,
+ mlir::Block *cont = nullptr,
+ mlir::Block *landingPad = nullptr) {
printer << ' ';
auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
@@ -782,8 +875,35 @@ static void printCallCommon(mlir::Operation *op,
assert(indirectCallee);
printer << indirectCallee;
}
+
printer << "(" << ops << ")";
+ if (cont) {
+ assert(landingPad && "expected two successors");
+ auto tryCall = dyn_cast<cir::TryCallOp>(op);
+ assert(tryCall && "regular calls do not branch");
+ printer << ' ' << tryCall.getCont();
+ if (!tryCall.getContOperands().empty()) {
+ printer << "(";
+ printer << tryCall.getContOperands();
+ printer << ' ' << ":";
+ printer << ' ';
+ printer << tryCall.getContOperands().getTypes();
+ printer << ")";
+ }
+ printer << ",";
+ printer << ' ';
+ printer << tryCall.getLandingPad();
+ if (!tryCall.getLandingPadOperands().empty()) {
+ printer << "(";
+ printer << tryCall.getLandingPadOperands();
+ printer << ' ' << ":";
+ printer << ' ';
+ printer << tryCall.getLandingPadOperands().getTypes();
+ printer << ")";
+ }
+ }
+
if (isNothrow)
printer << " nothrow";
@@ -793,10 +913,11 @@ static void printCallCommon(mlir::Operation *op,
printer << ")";
}
- printer.printOptionalAttrDict(op->getAttrs(),
- {CIRDialect::getCalleeAttrName(),
- CIRDialect::getNoThrowAttrName(),
- CIRDialect::getSideEffectAttrName()});
+ llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = {
+ CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
+ CIRDialect::getSideEffectAttrName(),
+ CIRDialect::getOperandSegmentSizesAttrName()};
+ printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
printer << " : ";
printer.printFunctionalType(op->getOperands().getTypes(),
@@ -878,6 +999,70 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyCallCommInSymbolUses(*this, symbolTable);
}
+//===----------------------------------------------------------------------===//
+// TryCallOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the `args` operand segment; for an indirect call the leading
+/// entry of that segment is dropped.
+mlir::OperandRange cir::TryCallOp::getArgOperands() {
+  mlir::OperandRange callArgs = getArgs();
+  return isIndirect() ? callArgs.drop_front(1) : callArgs;
+}
+
+/// Mutable counterpart of getArgOperands(): the `args` segment, without its
+/// leading entry when the call is indirect.
+mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
+  mlir::MutableOperandRange callArgs = getArgsMutable();
+  if (!isIndirect())
+    return callArgs;
+  return callArgs.slice(1, callArgs.size() - 1);
+}
+
+/// Returns the callee value of an indirect call, i.e. the first entry of the
+/// `args` operand segment. The lookup goes through getArgs() rather than a
+/// raw operand index so that successor operand segments are not mistaken
+/// for the callee. Must only be called when isIndirect() holds.
+mlir::Value cir::TryCallOp::getIndirectCall() {
+  assert(isIndirect());
+  return getArgs().front();
+}
+
+/// Return the call argument at index 'i'. Indexing goes through
+/// getArgOperands(), so the indirect callee (if any) and the successor
+/// operand segments are never counted.
+Value cir::TryCallOp::getArgOperand(unsigned i) {
+  return getArgOperands()[i];
+}
+
+/// Return the number of call arguments. Only the entries reported by
+/// getArgOperands() are counted; the total operand count of the operation
+/// would also include the operands forwarded to the two successors.
+unsigned cir::TryCallOp::getNumArgOperands() {
+  return getArgOperands().size();
+}
+
+/// Verifies the symbol uses of this operation by delegating to
+/// verifyCallCommInSymbolUses.
+LogicalResult
+cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ return verifyCallCommInSymbolUses(*this, symbolTable);
+}
+
+/// Parses a `cir.try_call` via parseCallCommon, with parsing of the two
+/// destination blocks enabled.
+mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
+}
+
+/// Prints this operation through printCallCommon, passing both successor
+/// blocks along. The indirect callee value stays null for a direct call.
+void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
+  mlir::Value indirectCallee;
+  if (isIndirect())
+    indirectCallee = getIndirectCall();
+
+  printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
+                  getSideEffect(), getCont(), getLandingPad());
+}
+
+/// Returns the operands forwarded to the successor at `index`: index 0 maps
+/// to the cont operands and index 1 to the landing pad operands. The call
+/// arguments are not successor operands and are never returned here.
+mlir::SuccessorOperands cir::TryCallOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return SuccessorOperands(index == 0 ? getContOperandsMutable()
+                                      : getLandingPadOperandsMutable());
+}
+
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5a6193fa8d840..12f3db01c77d8 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1385,7 +1385,9 @@ static mlir::LogicalResult
rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
- mlir::FlatSymbolRefAttr calleeAttr) {
+ mlir::FlatSymbolRefAttr calleeAttr,
+ mlir::Block *continueBlock = nullptr,
+ mlir::Block *landingPadBlock = nullptr) {
llvm::SmallVector<mlir::Type, 8> llvmResults;
mlir::ValueTypeRange<mlir::ResultRange> cirResults = op->getResultTypes();
auto call = cast<cir::CIRCallOpInterface>(op);
@@ -1414,7 +1416,7 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
llvmFnTy = converter->convertType<mlir::LLVM::LLVMFunctionType>(
fn.getFunctionType());
assert(llvmFnTy && "Failed to convert function type");
- } else if (auto alias = mlir::cast<mlir::LLVM::AliasOp>(callee)) {
+ } else if (auto alias = mlir::dyn_cast<mlir::LLVM::AliasOp>(callee)) {
// If the callee was an alias. In that case,
// we need to prepend the address of the alias to the operands. The
// way aliases work in the LLVM dialect is a little counter-intuitive.
@@ -1452,17 +1454,21 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
converter->convertType(calleeFuncTy));
}
- assert(!cir::MissingFeatures::opCallLandingPad());
- assert(!cir::MissingFeatures::opCallContinueBlock());
assert(!cir::MissingFeatures::opCallCallConv());
+ if (landingPadBlock) {
+ rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
+ op, llvmFnTy, calleeAttr, callOperands, continueBlock,
+ mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
+ return mlir::success();
+ }
+
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
op, llvmFnTy, calleeAttr, callOperands);
if (memoryEffects)
newOp.setMemoryEffectsAttr(memoryEffects);
newOp.setNoUnwind(noUnwind);
newOp.setWillReturn(willReturn);
-
return mlir::success();
}
@@ -1473,6 +1479,14 @@ mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite(
getTypeConverter(), op.getCalleeAttr());
}
+/// Lowers `cir.try_call` by forwarding to rewriteCallOrInvoke together with
+/// the op's cont and landing pad blocks.
+mlir::LogicalResult CIRToLLVMTryCallOpLowering::matchAndRewrite(
+    cir::TryCallOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  mlir::Block *contBlock = op.getCont();
+  mlir::Block *unwindBlock = op.getLandingPad();
+  return rewriteCallOrInvoke(op.getOperation(), adaptor.getOperands(), rewriter,
+                             getTypeConverter(), op.getCalleeAttr(), contBlock,
+                             unwindBlock);
+}
+
mlir::LogicalResult CIRToLLVMReturnAddrOpLowering::matchAndRewrite(
cir::ReturnAddrOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/test/CIR/IR/try-call.cir b/clang/test/CIR/IR/try-call.cir
new file mode 100644
index 0000000000000..6c23d3add15c8
--- /dev/null
+++ b/clang/test/CIR/IR/try-call.cir
@@ -0,0 +1,31 @@
+// RUN: cir-opt %s --verify-roundtrip | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i
+
+cir.func @flatten_structure_with_try_call_op() {
+  %a = cir.const #cir.int<1> : !s32i
+  %b = cir.const #cir.int<2> : !s32i
+  %3 = cir.try_call @division(%a, %b) ^continue, ^landing_pad : (!s32i, !s32i) -> !s32i
+  ^continue:
+    cir.br ^landing_pad
+  ^landing_pad:
+    cir.return
+}
+
+// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i
+
+// CHECK: cir.func @flatten_structure_with_try_call_op() {
+// CHECK-NEXT:   %[[CONST_0:.*]] = cir.const #cir.int<1> : !s32i
+// CHECK-NEXT:   %[[CONST_1:.*]] = cir.const #cir.int<2> : !s32i
+// CHECK-NEXT:   %[[CALL:.*]] = cir.try_call @division(%[[CONST_0]], %[[CONST_1]]) ^[[CONTINUE:.*]], ^[[LANDING_PAD:.*]] : (!s32i, !s32i) -> !s32i
+// CHECK-NEXT: ^[[CONTINUE]]:
+// CHECK-NEXT:   cir.br ^[[LANDING_PAD]]
+// CHECK-NEXT: ^[[LANDING_PAD]]:
+// CHECK-NEXT:   cir.return
+// CHECK-NEXT: }
+
+}
|
|
@llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) Changes: Upstream the TryCallOp as a prerequisite for the try/catch work. Issue #154992. Full diff: https://github.com/llvm/llvm-project/pull/165303.diff (5 files affected):
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index e91537186df59..34df9af7fc06d 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -44,6 +44,7 @@ def CIR_Dialect : Dialect {
static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
+ // Name of the operand segment sizes attribute; printCallCommon lists it among the attributes it elides when printing.
+ static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; }
void registerAttributes();
void registerTypes();
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 2b361ed0982c6..8f3e25b3c9737 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2580,7 +2580,7 @@ def CIR_FuncOp : CIR_Op<"func", [
}
//===----------------------------------------------------------------------===//
-// CallOp
+// CallOp and TryCallOp
//===----------------------------------------------------------------------===//
def CIR_SideEffect : CIR_I32EnumAttr<
@@ -2707,6 +2707,98 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
];
}
+def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
+  DeclareOpInterfaceMethods<BranchOpInterface>,
+  Terminator, AttrSizedOperandSegments
+]> {
+  let summary = "try_call operation";
+
+  let description = [{
+    Mostly similar to `cir.call`, but requires two destination branches:
+    one that handles the exception if one is thrown, and another that
+    continues the regular control flow.
+
+    Example:
+
+    ```mlir
+    // Direct call
+    %result = cir.try_call @division(%a, %b) ^continue, ^landing_pad
+      : (f32, f32) -> f32
+    ```
+  }];
+
+  // Operand segments, in declaration order: cont operands, landing pad
+  // operands, then the call operands contributed by `commonArgs`.
+  let arguments = !con((ins
+    Variadic<CIR_AnyType>:$contOperands,
+    Variadic<CIR_AnyType>:$landingPadOperands
+  ), commonArgs);
+
+  let results = (outs Optional<CIR_AnyType>:$result);
+  let successors = (successor AnySuccessor:$cont, AnySuccessor:$landing_pad);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    // Direct call: `callee` names the called symbol.
+    OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
+               "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+               CArg<"mlir::ValueRange", "{}">:$operands,
+               CArg<"mlir::ValueRange", "{}">:$contOperands,
+               CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      // Append the operands in the same order as the segment sizes recorded
+      // below (cont, landing_pad, call operands); otherwise the segment
+      // accessors would slice the wrong values as soon as a successor
+      // operand list is non-empty.
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      $_state.addOperands(operands);
+
+      if (callee)
+        $_state.addAttribute("callee", callee);
+      if (resType && !isa<VoidType>(resType))
+        $_state.addTypes(resType);
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // The TryCall ODS layout is: cont, landing_pad, operands.
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(operands.size())
+        }),
+        $_state.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>,
+    // Indirect call: `ind_target` is the callee value and is stored as the
+    // first call operand, ahead of `operands`.
+    OpBuilder<(ins "mlir::Value":$ind_target,
+               "FuncType":$fn_type,
+               "mlir::Block *":$cont, "mlir::Block *":$landing_pad,
+               CArg<"mlir::ValueRange", "{}">:$operands,
+               CArg<"mlir::ValueRange", "{}">:$contOperands,
+               CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
+               CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
+      ::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
+      finalCallOperands.append(operands.begin(), operands.end());
+
+      // Append the operands in the same order as the segment sizes recorded
+      // below (cont, landing_pad, call operands).
+      $_state.addOperands(contOperands);
+      $_state.addOperands(landingPadOperands);
+      $_state.addOperands(finalCallOperands);
+
+      if (!fn_type.hasVoidReturn())
+        $_state.addTypes(fn_type.getReturnType());
+
+      $_state.addAttribute("side_effect",
+        SideEffectAttr::get($_builder.getContext(), sideEffect));
+
+      // The TryCall ODS layout is: cont, landing_pad, operands.
+      llvm::copy(::llvm::ArrayRef<int32_t>({
+        static_cast<int32_t>(contOperands.size()),
+        static_cast<int32_t>(landingPadOperands.size()),
+        static_cast<int32_t>(finalCallOperands.size())
+        }),
+        $_state.getOrAddProperties<Properties>().operandSegmentSizes.begin());
+      $_state.addSuccessors(cont);
+      $_state.addSuccessors(landing_pad);
+    }]>
+  ];
+}
+
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 2d2ef422bfaef..11074af3ef127 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -701,13 +701,78 @@ unsigned cir::CallOp::getNumArgOperands() {
return this->getOperation()->getNumOperands();
}
+/// Parses the destination part of a `cir.try_call`:
+///
+///   ^cont [`(` operands `:` types `)`] `,` ^landing_pad [`(` operands `:`
+///   types `)`] [attr-dict]
+///
+/// Unresolved operands, their types and their source locations are handed
+/// back through the output parameters. Both successors are appended to
+/// `result` only after the whole clause has parsed successfully.
+static mlir::ParseResult
+parseTryCallBranches(mlir::OpAsmParser &parser, mlir::OperationState &result,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &continueOperands,
+                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                         &landingPadOperands,
+                     llvm::SmallVectorImpl<mlir::Type> &continueTypes,
+                     llvm::SmallVectorImpl<mlir::Type> &landingPadTypes,
+                     llvm::SMLoc &continueOperandsLoc,
+                     llvm::SMLoc &landingPadOperandsLoc) {
+  // Parses one `^block` optionally followed by `(operands : types)`.
+  auto parseDestination =
+      [&parser](mlir::Block *&dest,
+                llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
+                    &destOperands,
+                llvm::SmallVectorImpl<mlir::Type> &destTypes,
+                llvm::SMLoc &destOperandsLoc) -> mlir::ParseResult {
+    if (parser.parseSuccessor(dest))
+      return mlir::failure();
+
+    // No parenthesis means this destination takes no operands.
+    if (mlir::failed(parser.parseOptionalLParen()))
+      return mlir::success();
+
+    destOperandsLoc = parser.getCurrentLocation();
+    if (parser.parseOperandList(destOperands) || parser.parseColon() ||
+        parser.parseTypeList(destTypes) || parser.parseRParen())
+      return mlir::failure();
+    return mlir::success();
+  };
+
+  mlir::Block *contDest = nullptr;
+  mlir::Block *landingPadDest = nullptr;
+
+  if (parseDestination(contDest, continueOperands, continueTypes,
+                       continueOperandsLoc) ||
+      parser.parseComma() ||
+      parseDestination(landingPadDest, landingPadOperands, landingPadTypes,
+                       landingPadOperandsLoc) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return mlir::failure();
+
+  result.addSuccessors(contDest);
+  result.addSuccessors(landingPadDest);
+  return mlir::success();
+}
+
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
- mlir::OperationState &result) {
+ mlir::OperationState &result,
+ bool hasDestinationBlocks = false) {
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
llvm::SMLoc opsLoc;
mlir::FlatSymbolRefAttr calleeAttr;
llvm::ArrayRef<mlir::Type> allResultTypes;
+ // TryCall control flow related
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> continueOperands;
+ llvm::SMLoc continueOperandsLoc;
+ llvm::SmallVector<mlir::Type, 1> continueTypes;
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> landingPadOperands;
+ llvm::SMLoc landingPadOperandsLoc;
+ llvm::SmallVector<mlir::Type, 1> landingPadTypes;
+
// If we cannot parse a string callee, it means this is an indirect call.
if (!parser
.parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
@@ -729,6 +794,14 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
if (parser.parseRParen())
return mlir::failure();
+ if (hasDestinationBlocks &&
+ parseTryCallBranches(parser, result, continueOperands, landingPadOperands,
+ continueTypes, landingPadTypes, continueOperandsLoc,
+ landingPadOperandsLoc)
+ .failed()) {
+ return ::mlir::failure();
+ }
+
if (parser.parseOptionalKeyword("nothrow").succeeded())
result.addAttribute(CIRDialect::getNoThrowAttrName(),
mlir::UnitAttr::get(parser.getContext()));
@@ -761,6 +834,24 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
return mlir::failure();
+ if (hasDestinationBlocks) {
+ // The TryCall ODS layout is: cont, landing_pad, operands.
+ llvm::copy(::llvm::ArrayRef<int32_t>(
+ {static_cast<int32_t>(continueOperands.size()),
+ static_cast<int32_t>(landingPadOperands.size()),
+ static_cast<int32_t>(ops.size())}),
+ result.getOrAddProperties<cir::TryCallOp::Properties>()
+ .operandSegmentSizes.begin());
+
+ if (parser.resolveOperands(continueOperands, continueTypes,
+ continueOperandsLoc, result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(landingPadOperands, landingPadTypes,
+ landingPadOperandsLoc, result.operands))
+ return ::mlir::failure();
+ }
+
return mlir::success();
}
@@ -768,7 +859,9 @@ static void printCallCommon(mlir::Operation *op,
mlir::FlatSymbolRefAttr calleeSym,
mlir::Value indirectCallee,
mlir::OpAsmPrinter &printer, bool isNothrow,
- cir::SideEffect sideEffect) {
+ cir::SideEffect sideEffect,
+ mlir::Block *cont = nullptr,
+ mlir::Block *landingPad = nullptr) {
printer << ' ';
auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
@@ -782,8 +875,35 @@ static void printCallCommon(mlir::Operation *op,
assert(indirectCallee);
printer << indirectCallee;
}
+
printer << "(" << ops << ")";
+ if (cont) {
+ assert(landingPad && "expected two successors");
+ auto tryCall = dyn_cast<cir::TryCallOp>(op);
+ assert(tryCall && "regular calls do not branch");
+ printer << ' ' << tryCall.getCont();
+ if (!tryCall.getContOperands().empty()) {
+ printer << "(";
+ printer << tryCall.getContOperands();
+ printer << ' ' << ":";
+ printer << ' ';
+ printer << tryCall.getContOperands().getTypes();
+ printer << ")";
+ }
+ printer << ",";
+ printer << ' ';
+ printer << tryCall.getLandingPad();
+ if (!tryCall.getLandingPadOperands().empty()) {
+ printer << "(";
+ printer << tryCall.getLandingPadOperands();
+ printer << ' ' << ":";
+ printer << ' ';
+ printer << tryCall.getLandingPadOperands().getTypes();
+ printer << ")";
+ }
+ }
+
if (isNothrow)
printer << " nothrow";
@@ -793,10 +913,11 @@ static void printCallCommon(mlir::Operation *op,
printer << ")";
}
- printer.printOptionalAttrDict(op->getAttrs(),
- {CIRDialect::getCalleeAttrName(),
- CIRDialect::getNoThrowAttrName(),
- CIRDialect::getSideEffectAttrName()});
+ llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = {
+ CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
+ CIRDialect::getSideEffectAttrName(),
+ CIRDialect::getOperandSegmentSizesAttrName()};
+ printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
printer << " : ";
printer.printFunctionalType(op->getOperands().getTypes(),
@@ -878,6 +999,70 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyCallCommInSymbolUses(*this, symbolTable);
}
+//===----------------------------------------------------------------------===//
+// TryCallOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the `args` operand segment; for an indirect call the leading
+/// entry of that segment is dropped.
+mlir::OperandRange cir::TryCallOp::getArgOperands() {
+  mlir::OperandRange callArgs = getArgs();
+  return isIndirect() ? callArgs.drop_front(1) : callArgs;
+}
+
+/// Mutable counterpart of getArgOperands(): the `args` segment, without its
+/// leading entry when the call is indirect.
+mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
+  mlir::MutableOperandRange callArgs = getArgsMutable();
+  if (!isIndirect())
+    return callArgs;
+  return callArgs.slice(1, callArgs.size() - 1);
+}
+
+/// Returns the callee value of an indirect call, i.e. the first entry of the
+/// `args` operand segment. The lookup goes through getArgs() rather than a
+/// raw operand index so that successor operand segments are not mistaken
+/// for the callee. Must only be called when isIndirect() holds.
+mlir::Value cir::TryCallOp::getIndirectCall() {
+  assert(isIndirect());
+  return getArgs().front();
+}
+
+/// Return the call argument at index 'i'. Indexing goes through
+/// getArgOperands(), so the indirect callee (if any) and the successor
+/// operand segments are never counted.
+Value cir::TryCallOp::getArgOperand(unsigned i) {
+  return getArgOperands()[i];
+}
+
+/// Return the number of call arguments. Only the entries reported by
+/// getArgOperands() are counted; the total operand count of the operation
+/// would also include the operands forwarded to the two successors.
+unsigned cir::TryCallOp::getNumArgOperands() {
+  return getArgOperands().size();
+}
+
+/// Verifies the symbol uses of this operation by delegating to
+/// verifyCallCommInSymbolUses.
+LogicalResult
+cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ return verifyCallCommInSymbolUses(*this, symbolTable);
+}
+
+/// Parses a `cir.try_call` via parseCallCommon, with parsing of the two
+/// destination blocks enabled.
+mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
+}
+
+/// Prints this operation through printCallCommon, passing both successor
+/// blocks along. The indirect callee value stays null for a direct call.
+void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
+  mlir::Value indirectCallee;
+  if (isIndirect())
+    indirectCallee = getIndirectCall();
+
+  printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
+                  getSideEffect(), getCont(), getLandingPad());
+}
+
+/// Returns the operands forwarded to the successor at `index`: index 0 maps
+/// to the cont operands and index 1 to the landing pad operands. The call
+/// arguments are not successor operands and are never returned here.
+mlir::SuccessorOperands cir::TryCallOp::getSuccessorOperands(unsigned index) {
+  assert(index < getNumSuccessors() && "invalid successor index");
+  return SuccessorOperands(index == 0 ? getContOperandsMutable()
+                                      : getLandingPadOperandsMutable());
+}
+
//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 5a6193fa8d840..12f3db01c77d8 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1385,7 +1385,9 @@ static mlir::LogicalResult
rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
- mlir::FlatSymbolRefAttr calleeAttr) {
+ mlir::FlatSymbolRefAttr calleeAttr,
+ mlir::Block *continueBlock = nullptr,
+ mlir::Block *landingPadBlock = nullptr) {
llvm::SmallVector<mlir::Type, 8> llvmResults;
mlir::ValueTypeRange<mlir::ResultRange> cirResults = op->getResultTypes();
auto call = cast<cir::CIRCallOpInterface>(op);
@@ -1414,7 +1416,7 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
llvmFnTy = converter->convertType<mlir::LLVM::LLVMFunctionType>(
fn.getFunctionType());
assert(llvmFnTy && "Failed to convert function type");
- } else if (auto alias = mlir::cast<mlir::LLVM::AliasOp>(callee)) {
+ } else if (auto alias = mlir::dyn_cast<mlir::LLVM::AliasOp>(callee)) {
// If the callee was an alias. In that case,
// we need to prepend the address of the alias to the operands. The
// way aliases work in the LLVM dialect is a little counter-intuitive.
@@ -1452,17 +1454,21 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
converter->convertType(calleeFuncTy));
}
- assert(!cir::MissingFeatures::opCallLandingPad());
- assert(!cir::MissingFeatures::opCallContinueBlock());
assert(!cir::MissingFeatures::opCallCallConv());
+ if (landingPadBlock) {
+ rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
+ op, llvmFnTy, calleeAttr, callOperands, continueBlock,
+ mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
+ return mlir::success();
+ }
+
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
op, llvmFnTy, calleeAttr, callOperands);
if (memoryEffects)
newOp.setMemoryEffectsAttr(memoryEffects);
newOp.setNoUnwind(noUnwind);
newOp.setWillReturn(willReturn);
-
return mlir::success();
}
@@ -1473,6 +1479,14 @@ mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite(
getTypeConverter(), op.getCalleeAttr());
}
+mlir::LogicalResult CIRToLLVMTryCallOpLowering::matchAndRewrite(
+ cir::TryCallOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ return rewriteCallOrInvoke(op.getOperation(), adaptor.getOperands(), rewriter,
+ getTypeConverter(), op.getCalleeAttr(),
+ op.getCont(), op.getLandingPad());
+}
+
mlir::LogicalResult CIRToLLVMReturnAddrOpLowering::matchAndRewrite(
cir::ReturnAddrOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/test/CIR/IR/try-call.cir b/clang/test/CIR/IR/try-call.cir
new file mode 100644
index 0000000000000..6c23d3add15c8
--- /dev/null
+++ b/clang/test/CIR/IR/try-call.cir
@@ -0,0 +1,31 @@
+// RUN: cir-opt %s --verify-roundtrip | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+
+cir.func private @division(%a: !s32i, %b: !s32i) -> !s32i
+
+cir.func @flatten_structure_with_try_call_op() {
+ %a = cir.const #cir.int<1> : !s32i
+ %b = cir.const #cir.int<2> : !s32i
+ %3 = cir.try_call @division(%a, %b) ^continue, ^landing_pad : (!s32i, !s32i) -> !s32i
+ ^continue:
+ cir.br ^landing_pad
+ ^landing_pad:
+ cir.return
+}
+
+// CHECK: cir.func private @division(!s32i, !s32i) -> !s32i
+
+// CHECK: cir.func @flatten_structure_with_try_call_op() {
+// CHECK-NEXT: %[[CONST_0:.*]] = cir.const #cir.int<1> : !s32i
+// CHECK-NEXT: %[[CONST_1:.*]] = cir.const #cir.int<2> : !s32i
+// CHECK-NEXT: %[[CALL:.*]] = cir.try_call @division(%0, %1) ^[[CONTINUE:.*]], ^[[LANDING_PAD:.*]] : (!s32i, !s32i) -> !s32i
+// CHECK-NEXT: ^[[CONTINUE]]:
+// CHECK-NEXT: cir.br ^[[LANDING_PAD]]
+// CHECK-NEXT: ^[[LANDING_PAD]]:
+// CHECK-NEXT: cir.return
+// CHECK-NEXT: }
+
+}
|
| Variadic<CIR_AnyType>:$contOperands, | ||
| Variadic<CIR_AnyType>:$landingPadOperands |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are these arguments? Are they operands that are passed to the continue and landing pad blocks or are they the continue and landing pad blocks themselves? If the latter, I don't understand why they are variadic and any type. If the former, are they even used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are operands that are passed to the normal and unwind destination blocks. I am trying to think of cases in which the unwind destination would require operands, but for now I will remove them to simplify the parser and the op; if we need them, we can add them back. I will spend some time updating the parser, verifier, and op, and then we can have another round of review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are arguments for the destinations, similar to what we have in BrCond when we pass the exception on to another block. I can't see any case in TryCallOp that needs operands for destinations, but to conform to the BranchOpInterface we need to keep them so that we do not lose some verification for now 🤔.
I was thinking of keeping them in the op but ignore them from builder, parser, and printer until we need them or if we realize we will not need them in the future, we can copy the verify to TryCallOp itself and remove those operands 🤔, what do you think?
@andykaylor @bcardosolopes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking of keeping them in the op but ignore them from builder, parser, and printer until we need them or if we realize we will not need them in the future, we can copy the verify to TryCallOp itself and remove those operands 🤔, what do you think? @andykaylor @bcardosolopes
In such a case I would remove them entirely. Please do not add any operands without printer/parser + verifier and tests.
| ), commonArgs); | ||
|
|
||
| let results = (outs Optional<CIR_AnyType>:$result); | ||
| let successors = (successor AnySuccessor:$cont, AnySuccessor:$landing_pad); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The destinations are referred to as the normal destination and the unwind destination in the LLVM invoke instruction. I think it would be helpful to use the same terminology here.
| assert(!cir::MissingFeatures::opCallCallConv()); | ||
|
|
||
| if (landingPadBlock) { | ||
| rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you assert that the continue block is not null here?
| const mlir::TypeConverter *converter, | ||
| mlir::FlatSymbolRefAttr calleeAttr) { | ||
| mlir::FlatSymbolRefAttr calleeAttr, | ||
| mlir::Block *continueBlock = nullptr, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The lowering isn't tested in this PR. Maybe it should be omitted until it is needed. Otherwise, it needs a test.
912d307 to
d03d7ea
Compare
🐧 Linux x64 Test Results
Failed Tests(click on a test name to see its output) ClangClang.CIR/IR/try-call.cirClang.CIR/IR/try-call.cirIf these failures are unrelated to your changes (for example tests are broken or flaky at HEAD), please open an issue at https://github.com/llvm/llvm-project/issues and add the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note for the future @andykaylor @bcardosolopes: I believe we should keep two versions apart: cir.scf.try_call with regions (structured control flow) and cir.try_call with blocks as destinations. There might be a similar issue for other operations; we should try to separate the structured and unstructured paradigms.
@xlauko In general, I agree with what you're saying, though in this case there is no direct structured equivalent. Before flattening the |
d03d7ea to
a70a81f
Compare
| CIRDialect::getNoThrowAttrName(), | ||
| CIRDialect::getSideEffectAttrName()}); | ||
|
|
||
| llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = { | |
| llvm::SmallVector<::llvm::StringRef> elidedAttrs = { |
Having an initializer makes this unnecessary.
Upstream TryCall Op as a prerequisite for Try Catch work
Issue #154992