Skip to content

Commit 912d307

Browse files
committed
[CIR] Upstream TryCallOp
1 parent 5142707 commit 912d307

File tree

5 files changed

+335
-12
lines changed

5 files changed

+335
-12
lines changed

clang/include/clang/CIR/Dialect/IR/CIRDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def CIR_Dialect : Dialect {
4444
static llvm::StringRef getModuleLevelAsmAttrName() { return "cir.module_asm"; }
4545
static llvm::StringRef getGlobalCtorsAttrName() { return "cir.global_ctors"; }
4646
static llvm::StringRef getGlobalDtorsAttrName() { return "cir.global_dtors"; }
47+
static llvm::StringRef getOperandSegmentSizesAttrName() { return "operandSegmentSizes"; }
4748

4849
void registerAttributes();
4950
void registerTypes();

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2580,7 +2580,7 @@ def CIR_FuncOp : CIR_Op<"func", [
25802580
}
25812581

25822582
//===----------------------------------------------------------------------===//
2583-
// CallOp
2583+
// CallOp and TryCallOp
25842584
//===----------------------------------------------------------------------===//
25852585

25862586
def CIR_SideEffect : CIR_I32EnumAttr<
@@ -2707,6 +2707,98 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
27072707
];
27082708
}
27092709

2710+
def CIR_TryCallOp : CIR_CallOpBase<"try_call",[
2711+
DeclareOpInterfaceMethods<BranchOpInterface>,
2712+
Terminator, AttrSizedOperandSegments
2713+
]> {
2714+
let summary = "try_call operation";
2715+
2716+
let description = [{
2717+
Mostly similar to cir.call but requires two destination
2718+
branches, one for handling exceptions in case its thrown and
2719+
the other one to follow on regular control-flow.
2720+
2721+
Example:
2722+
2723+
```mlir
2724+
// Direct call
2725+
%result = cir.try_call @division(%a, %b) ^continue, ^landing_pad
2726+
: (f32, f32) -> f32
2727+
```
2728+
}];
2729+
2730+
let arguments = !con((ins
2731+
Variadic<CIR_AnyType>:$contOperands,
2732+
Variadic<CIR_AnyType>:$landingPadOperands
2733+
), commonArgs);
2734+
2735+
let results = (outs Optional<CIR_AnyType>:$result);
2736+
let successors = (successor AnySuccessor:$cont, AnySuccessor:$landing_pad);
2737+
2738+
let skipDefaultBuilders = 1;
2739+
2740+
let builders = [
2741+
OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType,
2742+
"mlir::Block *":$cont, "mlir::Block *":$landing_pad,
2743+
CArg<"mlir::ValueRange", "{}">:$operands,
2744+
CArg<"mlir::ValueRange", "{}">:$contOperands,
2745+
CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
2746+
CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
2747+
$_state.addOperands(operands);
2748+
if (callee)
2749+
$_state.addAttribute("callee", callee);
2750+
if (resType && !isa<VoidType>(resType))
2751+
$_state.addTypes(resType);
2752+
2753+
$_state.addAttribute("side_effect",
2754+
SideEffectAttr::get($_builder.getContext(), sideEffect));
2755+
2756+
// Handle branches
2757+
$_state.addOperands(contOperands);
2758+
$_state.addOperands(landingPadOperands);
2759+
// The TryCall ODS layout is: cont, landing_pad, operands.
2760+
llvm::copy(::llvm::ArrayRef<int32_t>({
2761+
static_cast<int32_t>(contOperands.size()),
2762+
static_cast<int32_t>(landingPadOperands.size()),
2763+
static_cast<int32_t>(operands.size())
2764+
}),
2765+
odsState.getOrAddProperties<Properties>().operandSegmentSizes.begin());
2766+
$_state.addSuccessors(cont);
2767+
$_state.addSuccessors(landing_pad);
2768+
}]>,
2769+
OpBuilder<(ins "mlir::Value":$ind_target,
2770+
"FuncType":$fn_type,
2771+
"mlir::Block *":$cont, "mlir::Block *":$landing_pad,
2772+
CArg<"mlir::ValueRange", "{}">:$operands,
2773+
CArg<"mlir::ValueRange", "{}">:$contOperands,
2774+
CArg<"mlir::ValueRange", "{}">:$landingPadOperands,
2775+
CArg<"SideEffect", "SideEffect::All">:$sideEffect), [{
2776+
::llvm::SmallVector<mlir::Value, 4> finalCallOperands({ind_target});
2777+
finalCallOperands.append(operands.begin(), operands.end());
2778+
$_state.addOperands(finalCallOperands);
2779+
2780+
if (!fn_type.hasVoidReturn())
2781+
$_state.addTypes(fn_type.getReturnType());
2782+
2783+
$_state.addAttribute("side_effect",
2784+
SideEffectAttr::get($_builder.getContext(), sideEffect));
2785+
2786+
// Handle branches
2787+
$_state.addOperands(contOperands);
2788+
$_state.addOperands(landingPadOperands);
2789+
// The TryCall ODS layout is: cont, landing_pad, operands.
2790+
llvm::copy(::llvm::ArrayRef<int32_t>({
2791+
static_cast<int32_t>(contOperands.size()),
2792+
static_cast<int32_t>(landingPadOperands.size()),
2793+
static_cast<int32_t>(finalCallOperands.size())
2794+
}),
2795+
odsState.getOrAddProperties<Properties>().operandSegmentSizes.begin());
2796+
$_state.addSuccessors(cont);
2797+
$_state.addSuccessors(landing_pad);
2798+
}]>
2799+
];
2800+
}
2801+
27102802
//===----------------------------------------------------------------------===//
27112803
// CopyOp
27122804
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 191 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -701,13 +701,78 @@ unsigned cir::CallOp::getNumArgOperands() {
701701
return this->getOperation()->getNumOperands();
702702
}
703703

704+
static mlir::ParseResult
705+
parseTryCallBranches(mlir::OpAsmParser &parser, mlir::OperationState &result,
706+
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
707+
&continueOperands,
708+
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand>
709+
&landingPadOperands,
710+
llvm::SmallVectorImpl<mlir::Type> &continueTypes,
711+
llvm::SmallVectorImpl<mlir::Type> &landingPadTypes,
712+
llvm::SMLoc &continueOperandsLoc,
713+
llvm::SMLoc &landingPadOperandsLoc) {
714+
mlir::Block *continueSuccessor = nullptr;
715+
mlir::Block *landingPadSuccessor = nullptr;
716+
717+
if (parser.parseSuccessor(continueSuccessor))
718+
return mlir::failure();
719+
720+
if (mlir::succeeded(parser.parseOptionalLParen())) {
721+
continueOperandsLoc = parser.getCurrentLocation();
722+
if (parser.parseOperandList(continueOperands))
723+
return mlir::failure();
724+
if (parser.parseColon())
725+
return mlir::failure();
726+
727+
if (parser.parseTypeList(continueTypes))
728+
return mlir::failure();
729+
if (parser.parseRParen())
730+
return mlir::failure();
731+
}
732+
733+
if (parser.parseComma())
734+
return mlir::failure();
735+
736+
if (parser.parseSuccessor(landingPadSuccessor))
737+
return mlir::failure();
738+
739+
if (mlir::succeeded(parser.parseOptionalLParen())) {
740+
landingPadOperandsLoc = parser.getCurrentLocation();
741+
if (parser.parseOperandList(landingPadOperands))
742+
return mlir::failure();
743+
if (parser.parseColon())
744+
return mlir::failure();
745+
746+
if (parser.parseTypeList(landingPadTypes))
747+
return mlir::failure();
748+
if (parser.parseRParen())
749+
return mlir::failure();
750+
}
751+
752+
if (parser.parseOptionalAttrDict(result.attributes))
753+
return mlir::failure();
754+
755+
result.addSuccessors(continueSuccessor);
756+
result.addSuccessors(landingPadSuccessor);
757+
return mlir::success();
758+
}
759+
704760
static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
705-
mlir::OperationState &result) {
761+
mlir::OperationState &result,
762+
bool hasDestinationBlocks = false) {
706763
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> ops;
707764
llvm::SMLoc opsLoc;
708765
mlir::FlatSymbolRefAttr calleeAttr;
709766
llvm::ArrayRef<mlir::Type> allResultTypes;
710767

768+
// TryCall control flow related
769+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> continueOperands;
770+
llvm::SMLoc continueOperandsLoc;
771+
llvm::SmallVector<mlir::Type, 1> continueTypes;
772+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> landingPadOperands;
773+
llvm::SMLoc landingPadOperandsLoc;
774+
llvm::SmallVector<mlir::Type, 1> landingPadTypes;
775+
711776
// If we cannot parse a string callee, it means this is an indirect call.
712777
if (!parser
713778
.parseOptionalAttribute(calleeAttr, CIRDialect::getCalleeAttrName(),
@@ -729,6 +794,14 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
729794
if (parser.parseRParen())
730795
return mlir::failure();
731796

797+
if (hasDestinationBlocks &&
798+
parseTryCallBranches(parser, result, continueOperands, landingPadOperands,
799+
continueTypes, landingPadTypes, continueOperandsLoc,
800+
landingPadOperandsLoc)
801+
.failed()) {
802+
return ::mlir::failure();
803+
}
804+
732805
if (parser.parseOptionalKeyword("nothrow").succeeded())
733806
result.addAttribute(CIRDialect::getNoThrowAttrName(),
734807
mlir::UnitAttr::get(parser.getContext()));
@@ -761,14 +834,34 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
761834
if (parser.resolveOperands(ops, opsFnTy.getInputs(), opsLoc, result.operands))
762835
return mlir::failure();
763836

837+
if (hasDestinationBlocks) {
838+
// The TryCall ODS layout is: cont, landing_pad, operands.
839+
llvm::copy(::llvm::ArrayRef<int32_t>(
840+
{static_cast<int32_t>(continueOperands.size()),
841+
static_cast<int32_t>(landingPadOperands.size()),
842+
static_cast<int32_t>(ops.size())}),
843+
result.getOrAddProperties<cir::TryCallOp::Properties>()
844+
.operandSegmentSizes.begin());
845+
846+
if (parser.resolveOperands(continueOperands, continueTypes,
847+
continueOperandsLoc, result.operands))
848+
return ::mlir::failure();
849+
850+
if (parser.resolveOperands(landingPadOperands, landingPadTypes,
851+
landingPadOperandsLoc, result.operands))
852+
return ::mlir::failure();
853+
}
854+
764855
return mlir::success();
765856
}
766857

767858
static void printCallCommon(mlir::Operation *op,
768859
mlir::FlatSymbolRefAttr calleeSym,
769860
mlir::Value indirectCallee,
770861
mlir::OpAsmPrinter &printer, bool isNothrow,
771-
cir::SideEffect sideEffect) {
862+
cir::SideEffect sideEffect,
863+
mlir::Block *cont = nullptr,
864+
mlir::Block *landingPad = nullptr) {
772865
printer << ' ';
773866

774867
auto callLikeOp = mlir::cast<cir::CIRCallOpInterface>(op);
@@ -782,8 +875,35 @@ static void printCallCommon(mlir::Operation *op,
782875
assert(indirectCallee);
783876
printer << indirectCallee;
784877
}
878+
785879
printer << "(" << ops << ")";
786880

881+
if (cont) {
882+
assert(landingPad && "expected two successors");
883+
auto tryCall = dyn_cast<cir::TryCallOp>(op);
884+
assert(tryCall && "regular calls do not branch");
885+
printer << ' ' << tryCall.getCont();
886+
if (!tryCall.getContOperands().empty()) {
887+
printer << "(";
888+
printer << tryCall.getContOperands();
889+
printer << ' ' << ":";
890+
printer << ' ';
891+
printer << tryCall.getContOperands().getTypes();
892+
printer << ")";
893+
}
894+
printer << ",";
895+
printer << ' ';
896+
printer << tryCall.getLandingPad();
897+
if (!tryCall.getLandingPadOperands().empty()) {
898+
printer << "(";
899+
printer << tryCall.getLandingPadOperands();
900+
printer << ' ' << ":";
901+
printer << ' ';
902+
printer << tryCall.getLandingPadOperands().getTypes();
903+
printer << ")";
904+
}
905+
}
906+
787907
if (isNothrow)
788908
printer << " nothrow";
789909

@@ -793,10 +913,11 @@ static void printCallCommon(mlir::Operation *op,
793913
printer << ")";
794914
}
795915

796-
printer.printOptionalAttrDict(op->getAttrs(),
797-
{CIRDialect::getCalleeAttrName(),
798-
CIRDialect::getNoThrowAttrName(),
799-
CIRDialect::getSideEffectAttrName()});
916+
llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs = {
917+
CIRDialect::getCalleeAttrName(), CIRDialect::getNoThrowAttrName(),
918+
CIRDialect::getSideEffectAttrName(),
919+
CIRDialect::getOperandSegmentSizesAttrName()};
920+
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
800921

801922
printer << " : ";
802923
printer.printFunctionalType(op->getOperands().getTypes(),
@@ -878,6 +999,70 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
878999
return verifyCallCommInSymbolUses(*this, symbolTable);
8791000
}
8801001

1002+
//===----------------------------------------------------------------------===//
1003+
// TryCallOp
1004+
//===----------------------------------------------------------------------===//
1005+
1006+
mlir::OperandRange cir::TryCallOp::getArgOperands() {
1007+
if (isIndirect())
1008+
return getArgs().drop_front(1);
1009+
return getArgs();
1010+
}
1011+
1012+
mlir::MutableOperandRange cir::TryCallOp::getArgOperandsMutable() {
1013+
mlir::MutableOperandRange args = getArgsMutable();
1014+
if (isIndirect())
1015+
return args.slice(1, args.size() - 1);
1016+
return args;
1017+
}
1018+
1019+
mlir::Value cir::TryCallOp::getIndirectCall() {
1020+
assert(isIndirect());
1021+
return getOperand(0);
1022+
}
1023+
1024+
/// Return the operand at index 'i'.
1025+
Value cir::TryCallOp::getArgOperand(unsigned i) {
1026+
if (isIndirect())
1027+
++i;
1028+
return getOperand(i);
1029+
}
1030+
1031+
/// Return the number of operands.
1032+
unsigned cir::TryCallOp::getNumArgOperands() {
1033+
if (isIndirect())
1034+
return this->getOperation()->getNumOperands() - 1;
1035+
return this->getOperation()->getNumOperands();
1036+
}
1037+
1038+
LogicalResult
1039+
cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1040+
return verifyCallCommInSymbolUses(*this, symbolTable);
1041+
}
1042+
1043+
mlir::ParseResult cir::TryCallOp::parse(mlir::OpAsmParser &parser,
1044+
mlir::OperationState &result) {
1045+
return parseCallCommon(parser, result, /*hasDestinationBlocks=*/true);
1046+
}
1047+
1048+
void cir::TryCallOp::print(::mlir::OpAsmPrinter &p) {
1049+
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
1050+
cir::SideEffect sideEffect = getSideEffect();
1051+
printCallCommon(*this, getCalleeAttr(), indirectCallee, p, getNothrow(),
1052+
sideEffect, getCont(), getLandingPad());
1053+
}
1054+
1055+
mlir::SuccessorOperands cir::TryCallOp::getSuccessorOperands(unsigned index) {
1056+
assert(index < getNumSuccessors() && "invalid successor index");
1057+
if (index == 0)
1058+
return SuccessorOperands(getContOperandsMutable());
1059+
if (index == 1)
1060+
return SuccessorOperands(getLandingPadOperandsMutable());
1061+
1062+
// index == 2
1063+
return SuccessorOperands(getArgOperandsMutable());
1064+
}
1065+
8811066
//===----------------------------------------------------------------------===//
8821067
// ReturnOp
8831068
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)