Skip to content

Commit c6e63ba

Browse files
committed
[CIR] Upstream TryCallOp
1 parent 3564791 commit c6e63ba

File tree

5 files changed

+335
-12
lines changed

5 files changed

+335
-12
lines changed

clang/include/clang/CIR/Dialect/IR/CIRDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def CIR_Dialect : Dialect {
4444
static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
4545
static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
4646
static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
47+
static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; }
4748

4849
void registerAttributes();
4950
void registerTypes();

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2728,7 +2728,7 @@ def CIR_LLVMIntrinsicCallOp : CIR_Op<"call_llvm_intrinsic"> {
27282728
}
27292729

27302730
//===----------------------------------------------------------------------===//
2731-
// CallOp
2731+
// CallOp and TryCallOp
27322732
//===----------------------------------------------------------------------===//
27332733

27342734
def CIR_SideEffect : CIR_I32EnumAttr<
@@ -2855,6 +2855,98 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
28552855
];
28562856
}
28572857

2858+
def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
2859+
DeclareOpInterfaceMethods<BranchOpInterface>,
2860+
Terminator, AttrSizedOperandSegments
2861+
]> {
2862+
let summary = "try_call operation";
2863+
2864+
let description = [{
2865+
Mostly similar to cir.call but requires two destination
2866+
branches, one for handling exceptions in case its thrown and
2867+
the other one to follow on regular control-flow.
2868+
2869+
Example:
2870+
2871+
```mlir
2872+
// Direct call
2873+
%result = cir.try_call @division(%a, %b) ^continue, ^landing_pad
2874+
: (f32, f32) -> f32
2875+
```
2876+
}];
2877+
2878+
let arguments = !con((ins
2879+
Variadic<CIR_AnyType>:$contOperands,
2880+
Variadic<CIR_AnyType>:$landingPadOperands
2881+
), commonArgs);
2882+
2883+
let results = (outs Optional<CIR_AnyType>:$result);
2884+
let successors = (successor AnySuccessor:$cont, AnySuccessor:$landing_pad);
2885+
2886+
let skipDefaultBuilders = 1;
2887+
2888+
let builders = [
2889+
OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
2890+
"mlir::Block *":$cont, "mlir::Block *":$landing_pad,
2891+
CArg<"mlir::ValueRange", "{}">:$operands,
2892+
CArg<"mlir::ValueRange", "{}">:$contOperands,
2893+
CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
2894+
CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
2895+
$_state.addOperands(operands);
2896+
if (callee)
2897+
$_state.addAttribute("callee", callee);
2898+
if (resType && !isa<VoidType>(resType))
2899+
$_state.addTypes(resType);
2900+
2901+
$_state.addAttribute("side_effect",
2902+
SideEffectAttr::get($_builder.getContext(), sideEffect));
2903+
2904+
// Handle branches
2905+
$_state.addOperands(contOperands);
2906+
$_state.addOperands(landingPadOperands);
2907+
// The TryCall ODS layout is: cont, landing_pad, operands.
2908+
llvm::copy(::llvm::ArrayRef<int32_t>({
2909+
static_cast<int32_t>(contOperands.size()),
2910+
static_cast<int32_t>(landingPadOperands.size()),
2911+
static_cast<int32_t>(operands.size())
2912+
}),
2913+
odsState.getOrAddProperties<Properties>().operandSegmentSizes.begin());
2914+
$_state.addSuccessors(cont);
2915+
$_state.addSuccessors(landing_pad);
2916+
}]>,
2917+
OpBuilder<(ins "mlir::Value":$ind_target,
2918+
"FuncType":$fn_type,
2919+
"mlir::Block *":$cont, "mlir::Block *":$landing_pad,
2920+
CArg<"mlir::ValueRange", "{}">:$operands,
2921+
CArg<"mlir::ValueRange", "{}">:$contOperands,
2922+
CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
2923+
CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
2924+
::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
2925+
finalCallOperands.append(operands.begin(), operands.end());
2926+
$_state.addOperands(finalCallOperands);
2927+
2928+
if (!fn_type.hasVoidReturn())
2929+
$_state.addTypes(fn_type.getReturnType());
2930+
2931+
$_state.addAttribute("side_effect",
2932+
SideEffectAttr::get($_builder.getContext(), sideEffect));
2933+
2934+
// Handle branches
2935+
$_state.addOperands(contOperands);
2936+
$_state.addOperands(landingPadOperands);
2937+
// The TryCall ODS layout is: cont, landing_pad, operands.
2938+
llvm::copy(::llvm::ArrayRef<int32_t>({
2939+
static_cast<int32_t>(contOperands.size()),
2940+
static_cast<int32_t>(landingPadOperands.size()),
2941+
static_cast<int32_t>(finalCallOperands.size())
2942+
}),
2943+
odsState.getOrAddProperties<Properties>().operandSegmentSizes.begin());
2944+
$_state.addSuccessors(cont);
2945+
$_state.addSuccessors(landing_pad);
2946+
}]>
2947+
];
2948+
}
2949+
28582950
//===----------------------------------------------------------------------===//
28592951
// AwaitOp
28602952
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 191 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -715,13 +715,78 @@ unsigned cir::CallOp::getNumArgOperands() {
715715
return this->getOperation()->getNumOperands();
716716
}
717717

718+
static mlir::ParseResult
719+
parseTryCallBranches(mlir::OpAsmParser &parser, mlir::OperationState &result,
720+
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
721+
&continueOperands,
722+
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
723+
&landingPadOperands,
724+
llvm::SmallVectorImpl<mlir::Type> &continueTypes,
725+
llvm::SmallVectorImpl<mlir::Type> &landingPadTypes,
726+
llvm::SMLoc &continueOperandsLoc,
727+
llvm::SMLoc &landingPadOperandsLoc) {
728+
mlir::Block *continueSuccessor = nullptr;
729+
mlir::Block *landingPadSuccessor = nullptr;
730+
731+
if (parser.parseSuccessor(continueSuccessor))
732+
return mlir::failure();
733+
734+
if (mlir::succeeded(parser.parseOptionalLParen())) {
735+
continueOperandsLoc = parser.getCurrentLocation();
736+
if (parser.parseOperandList(continueOperands))
737+
return mlir::failure();
738+
if (parser.parseColon())
739+
return mlir::failure();
740+
741+
if (parser.parseTypeList(continueTypes))
742+
return mlir::failure();
743+
if (parser.parseRParen())
744+
return mlir::failure();
745+
}
746+
747+
if (parser.parseComma())
748+
return mlir::failure();
749+
750+
if (parser.parseSuccessor(landingPadSuccessor))
751+
return mlir::failure();
752+
753+
if (mlir::succeeded(parser.parseOptionalLParen())) {
754+
landingPadOperandsLoc = parser.getCurrentLocation();
755+
if (parser.parseOperandList(landingPadOperands))
756+
return mlir::failure();
757+
if (parser.parseColon())
758+
return mlir::failure();
759+
760+
if (parser.parseTypeList(landingPadTypes))
761+
return mlir::failure();
762+
if (parser.parseRParen())
763+
return mlir::failure();
764+
}
765+
766+
if (parser.parseOptionalAttrDict(result.attributes))
767+
return mlir::failure();
768+
769+
result.addSuccessors(continueSuccessor);
770+
result.addSuccessors(landingPadSuccessor);
771+
return mlir::success();
772+
}
773+
718774
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
719-
mlir::OperationState &result) {
775+
mlir::OperationState &result,
776+
bool hasDestinationBlocks = false) {
720777
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
721778
llvm::SMLoc opsLoc;
722779
mlir::FlatSymbolRefAttr calleeAttr;
723780
llvm::ArrayRef<mlir::Type> allResultTypes;
724781

782+
// TryCall control flow related
783+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> continueOperands;
784+
llvm::SMLoc continueOperandsLoc;
785+
llvm::SmallVector<mlir::Type, 1> continueTypes;
786+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> landingPadOperands;
787+
llvm::SMLoc landingPadOperandsLoc;
788+
llvm::SmallVector<mlir::Type, 1> landingPadTypes;
789+
725790
// If we cannot parse a string callee, it means this is an indirect call.
726791
if (!parser
727792
.parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
@@ -743,6 +808,14 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
743808
if (parser.parseRParen())
744809
return mlir::failure();
745810

811+
if (hasDestinationBlocks &&
812+
parseTryCallBranches(parser, result, continueOperands, landingPadOperands,
813+
continueTypes, landingPadTypes, continueOperandsLoc,
814+
landingPadOperandsLoc)
815+
.failed()) {
816+
return ::mlir::failure();
817+
}
818+
746819
if (parser.parseOptionalKeyword("nothrow").succeeded())
747820
result.addAttribute(CIRDialect::getNoThrowAttrName(),
748821
mlir::UnitAttr::get(parser.getContext()));
@@ -775,14 +848,34 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
775848
if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
776849
return mlir::failure();
777850

851+
if (hasDestinationBlocks) {
852+
// The TryCall ODS layout is: cont, landing_pad, operands.
853+
llvm::copy(::llvm::ArrayRef<int32_t>(
854+
{static_cast<int32_t>(continueOperands.size()),
855+
static_cast<int32_t>(landingPadOperands.size()),
856+
static_cast<int32_t>(ops.size())}),
857+
result.getOrAddProperties<cir::TryCallOp::Properties>()
858+
.operandSegmentSizes.begin());
859+
860+
if (parser.resolveOperands(continueOperands, continueTypes,
861+
continueOperandsLoc, result.operands))
862+
return ::mlir::failure();
863+
864+
if (parser.resolveOperands(landingPadOperands, landingPadTypes,
865+
landingPadOperandsLoc, result.operands))
866+
return ::mlir::failure();
867+
}
868+
778869
return mlir::success();
779870
}
780871

781872
static void printCallCommon(mlir::Operation *op,
782873
mlir::FlatSymbolRefAttr calleeSym,
783874
mlir::Value indirectCallee,
784875
mlir::OpAsmPrinter &printer, bool isNothrow,
785-
cir::SideEffect sideEffect) {
876+
cir::SideEffect sideEffect,
877+
mlir::Block *cont = nullptr,
878+
mlir::Block *landingPad = nullptr) {
786879
printer << ' ';
787880

788881
auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
@@ -796,8 +889,35 @@ static void printCallCommon(mlir::Operation *op,
796889
assert(indirectCallee);
797890
printer << indirectCallee;
798891
}
892+
799893
printer << "(" << ops << ")";
800894

895+
if (cont) {
896+
assert(landingPad && "expected two successors");
897+
auto tryCall = dyn_cast<cir::TryCallOp>(op);
898+
assert(tryCall && "regular calls do not branch");
899+
printer << ' ' << tryCall.getCont();
900+
if (!tryCall.getContOperands().empty()) {
901+
printer << "(";
902+
printer << tryCall.getContOperands();
903+
printer << ' ' << ":";
904+
printer << ' ';
905+
printer << tryCall.getContOperands().getTypes();
906+
printer << ")";
907+
}
908+
printer << ",";
909+
printer << ' ';
910+
printer << tryCall.getLandingPad();
911+
if (!tryCall.getLandingPadOperands().empty()) {
912+
printer << "(";
913+
printer << tryCall.getLandingPadOperands();
914+
printer << ' ' << ":";
915+
printer << ' ';
916+
printer << tryCall.getLandingPadOperands().getTypes();
917+
printer << ")";
918+
}
919+
}
920+
801921
if (isNothrow)
802922
printer << " nothrow";
803923

@@ -807,10 +927,11 @@ static void printCallCommon(mlir::Operation *op,
807927
printer << ")";
808928
}
809929

810-
printer.printOptionalAttrDict(op->getAttrs(),
811-
{CIRDialect::getCalleeAttrName(),
812-
CIRDialect::getNoThrowAttrName(),
813-
CIRDialect::getSideEffectAttrName()});
930+
llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = {
931+
CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
932+
CIRDialect::getSideEffectAttrName(),
933+
CIRDialect::getOperandSegmentSizesAttrName()};
934+
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
814935

815936
printer << " : ";
816937
printer.printFunctionalType(op->getOperands().getTypes(),
@@ -892,6 +1013,70 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
8921013
return verifyCallCommInSymbolUses(*this, symbolTable);
8931014
}
8941015

1016+
//===----------------------------------------------------------------------===//
1017+
// TryCallOp
1018+
//===----------------------------------------------------------------------===//
1019+
1020+
mlir::OperandRange cir::TryCallOp::getArgOperands() {
1021+
if (isIndirect())
1022+
return getArgs().drop_front(1);
1023+
return getArgs();
1024+
}
1025+
1026+
mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
1027+
mlir::MutableOperandRange args = getArgsMutable();
1028+
if (isIndirect())
1029+
return args.slice(1, args.size() - 1);
1030+
return args;
1031+
}
1032+
1033+
mlir::Value cir::TryCallOp::getIndirectCall() {
1034+
assert(isIndirect());
1035+
return getOperand(0);
1036+
}
1037+
1038+
/// Return the operand at index 'i'.
1039+
Value cir::TryCallOp::getArgOperand(unsigned i) {
1040+
if (isIndirect())
1041+
++i;
1042+
return getOperand(i);
1043+
}
1044+
1045+
/// Return the number of operands.
1046+
unsigned cir::TryCallOp::getNumArgOperands() {
1047+
if (isIndirect())
1048+
return this->getOperation()->getNumOperands() - 1;
1049+
return this->getOperation()->getNumOperands();
1050+
}
1051+
1052+
LogicalResult
1053+
cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1054+
return verifyCallCommInSymbolUses(*this, symbolTable);
1055+
}
1056+
1057+
mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
1058+
mlir::OperationState &result) {
1059+
return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
1060+
}
1061+
1062+
void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
1063+
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1064+
cir::SideEffect sideEffect = getSideEffect();
1065+
printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1066+
sideEffect, getCont(), getLandingPad());
1067+
}
1068+
1069+
mlir::SuccessorOperands cir::TryCallOp::getSuccessorOperands(unsigned index) {
1070+
assert(index < getNumSuccessors() && "invalid successor index");
1071+
if (index == 0)
1072+
return SuccessorOperands(getContOperandsMutable());
1073+
if (index == 1)
1074+
return SuccessorOperands(getLandingPadOperandsMutable());
1075+
1076+
// index == 2
1077+
return SuccessorOperands(getArgOperandsMutable());
1078+
}
1079+
8951080
//===----------------------------------------------------------------------===//
8961081
// ReturnOp
8971082
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)