@@ -948,6 +948,11 @@ static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
948948 return results;
949949}
950950
951+ // / Gets the variadic callee type for a LLVMFunctionType.
952+ static TypeAttr getCallOpVarCalleeType (LLVMFunctionType calleeType) {
953+ return calleeType.isVarArg () ? TypeAttr::get (calleeType) : nullptr ;
954+ }
955+
951956// / Constructs a LLVMFunctionType from MLIR `results` and `args`.
952957static LLVMFunctionType getLLVMFuncType (MLIRContext *context, TypeRange results,
953958 ValueRange args) {
@@ -974,8 +979,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
974979 FlatSymbolRefAttr callee, ValueRange args) {
975980 assert (callee && " expected non-null callee in direct call builder" );
976981 build (builder, state, results,
977- TypeAttr::get ( getLLVMFuncType (builder. getContext (), results , args)) ,
978- callee, args, /* fastmathFlags= */ nullptr , /* branch_weights=*/ nullptr ,
982+ /* var_callee_type= */ nullptr , callee , args, /* fastmathFlags= */ nullptr ,
983+ /* branch_weights=*/ nullptr ,
979984 /* CConv=*/ nullptr , /* TailCallKind=*/ nullptr ,
980985 /* access_groups=*/ nullptr , /* alias_scopes=*/ nullptr ,
981986 /* noalias_scopes=*/ nullptr , /* tbaa=*/ nullptr );
@@ -997,7 +1002,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
9971002 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
9981003 ValueRange args) {
9991004 build (builder, state, getCallOpResultTypes (calleeType),
1000- TypeAttr::get (calleeType), callee, args, /* fastmathFlags=*/ nullptr ,
1005+ getCallOpVarCalleeType (calleeType), callee, args,
1006+ /* fastmathFlags=*/ nullptr ,
10011007 /* branch_weights=*/ nullptr , /* CConv=*/ nullptr ,
10021008 /* TailCallKind=*/ nullptr , /* access_groups=*/ nullptr ,
10031009 /* alias_scopes=*/ nullptr , /* noalias_scopes=*/ nullptr , /* tbaa=*/ nullptr );
@@ -1006,7 +1012,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
10061012void CallOp::build (OpBuilder &builder, OperationState &state,
10071013 LLVMFunctionType calleeType, ValueRange args) {
10081014 build (builder, state, getCallOpResultTypes (calleeType),
1009- TypeAttr::get (calleeType), /* callee=*/ nullptr , args,
1015+ getCallOpVarCalleeType (calleeType),
1016+ /* callee=*/ nullptr , args,
10101017 /* fastmathFlags=*/ nullptr , /* branch_weights=*/ nullptr ,
10111018 /* CConv=*/ nullptr , /* TailCallKind=*/ nullptr ,
10121019 /* access_groups=*/ nullptr , /* alias_scopes=*/ nullptr ,
@@ -1017,7 +1024,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
10171024 ValueRange args) {
10181025 auto calleeType = func.getFunctionType ();
10191026 build (builder, state, getCallOpResultTypes (calleeType),
1020- TypeAttr::get (calleeType), SymbolRefAttr::get (func), args,
1027+ getCallOpVarCalleeType (calleeType), SymbolRefAttr::get (func), args,
10211028 /* fastmathFlags=*/ nullptr , /* branch_weights=*/ nullptr ,
10221029 /* CConv=*/ nullptr , /* TailCallKind=*/ nullptr ,
10231030 /* access_groups=*/ nullptr , /* alias_scopes=*/ nullptr ,
@@ -1076,9 +1083,49 @@ static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
10761083 return success ();
10771084}
10781085
1086+ // / Verify that the parameter and return types of the variadic callee type match
1087+ // / the `callOp` argument and result types.
1088+ template <typename OpTy>
1089+ LogicalResult verifyCallOpVarCalleeType (OpTy callOp) {
1090+ std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType ();
1091+ if (!varCalleeType)
1092+ return success ();
1093+
1094+ // Verify the variadic callee type is a variadic function type.
1095+ if (!varCalleeType->isVarArg ())
1096+ return callOp.emitOpError (
1097+ " expected var_callee_type to be a variadic function type" );
1098+
1099+ // Verify the variadic callee type has at most as many parameters as the call
1100+ // has argument operands.
1101+ if (varCalleeType->getNumParams () > callOp.getArgOperands ().size ())
1102+ return callOp.emitOpError (" expected var_callee_type to have at most " )
1103+ << callOp.getArgOperands ().size () << " parameters" ;
1104+
1105+ // Verify the variadic callee type matches the call argument types.
1106+ for (auto [paramType, operand] :
1107+ llvm::zip (varCalleeType->getParams (), callOp.getArgOperands ()))
1108+ if (paramType != operand.getType ())
1109+ return callOp.emitOpError ()
1110+ << " var_callee_type parameter type mismatch: " << paramType
1111+ << " != " << operand.getType ();
1112+
1113+ // Verify the variadic callee type matches the call result type.
1114+ if (!callOp.getNumResults ()) {
1115+ if (!isa<LLVMVoidType>(varCalleeType->getReturnType ()))
1116+ return callOp.emitOpError (" expected var_callee_type to return void" );
1117+ } else {
1118+ if (callOp.getResult ().getType () != varCalleeType->getReturnType ())
1119+ return callOp.emitOpError (" var_callee_type return type mismatch: " )
1120+ << varCalleeType->getReturnType ()
1121+ << " != " << callOp.getResult ().getType ();
1122+ }
1123+ return success ();
1124+ }
1125+
10791126LogicalResult CallOp::verifySymbolUses (SymbolTableCollection &symbolTable) {
1080- if (getNumResults () > 1 )
1081- return emitOpError ( " must have 0 or 1 result " );
1127+ if (failed ( verifyCallOpVarCalleeType (* this )) )
1128+ return failure ( );
10821129
10831130 // Type for the callee, we'll get it differently depending if it is a direct
10841131 // or indirect call.
@@ -1120,8 +1167,8 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
11201167 if (!funcType)
11211168 return emitOpError (" callee does not have a functional type: " ) << fnType;
11221169
1123- if (funcType.isVarArg () && !getCalleeType ())
1124- return emitOpError () << " missing callee type attribute for vararg call" ;
1170+ if (funcType.isVarArg () && !getVarCalleeType ())
1171+ return emitOpError () << " missing var_callee_type attribute for vararg call" ;
11251172
11261173 // Verify that the operand and result types match the callee.
11271174
@@ -1168,14 +1215,6 @@ void CallOp::print(OpAsmPrinter &p) {
11681215 auto callee = getCallee ();
11691216 bool isDirect = callee.has_value ();
11701217
1171- LLVMFunctionType calleeType;
1172- bool isVarArg = false ;
1173-
1174- if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType ()) {
1175- calleeType = *optionalCalleeType;
1176- isVarArg = calleeType.isVarArg ();
1177- }
1178-
11791218 p << ' ' ;
11801219
11811220 // Print calling convention.
@@ -1195,12 +1234,13 @@ void CallOp::print(OpAsmPrinter &p) {
11951234 auto args = getOperands ().drop_front (isDirect ? 0 : 1 );
11961235 p << ' (' << args << ' )' ;
11971236
1198- if (isVarArg)
1199- p << " vararg(" << calleeType << " )" ;
1237+ // Print the variadic callee type if the call is variadic.
1238+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1239+ p << " vararg(" << *varCalleeType << " )" ;
12001240
12011241 p.printOptionalAttrDict (processFMFAttr ((*this )->getAttrs ()),
1202- {getCConvAttrName (), " callee " , " callee_type " ,
1203- getTailCallKindAttrName ()});
1242+ {getCalleeAttrName (), getTailCallKindAttrName () ,
1243+ getVarCalleeTypeAttrName (), getCConvAttrName ()});
12041244
12051245 p << " : " ;
12061246 if (!isDirect)
@@ -1270,11 +1310,11 @@ static ParseResult parseOptionalCallFuncPtr(
12701310
12711311// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
12721312// `(` ssa-use-list `)`
1273- // ( `vararg(` var-arg-func -type `)` )?
1313+ // ( `vararg(` var-callee -type `)` )?
12741314// attribute-dict? `:` (type `,`)? function-type
12751315ParseResult CallOp::parse (OpAsmParser &parser, OperationState &result) {
12761316 SymbolRefAttr funcAttr;
1277- TypeAttr calleeType ;
1317+ TypeAttr varCalleeType ;
12781318 SmallVector<OpAsmParser::UnresolvedOperand> operands;
12791319
12801320 // Default to C Calling Convention if no keyword is provided.
@@ -1305,8 +1345,12 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
13051345
13061346 bool isVarArg = parser.parseOptionalKeyword (" vararg" ).succeeded ();
13071347 if (isVarArg) {
1348+ StringAttr varCalleeTypeAttrName =
1349+ CallOp::getVarCalleeTypeAttrName (result.name );
13081350 if (parser.parseLParen ().failed () ||
1309- parser.parseAttribute (calleeType, " callee_type" , result.attributes )
1351+ parser
1352+ .parseAttribute (varCalleeType, varCalleeTypeAttrName,
1353+ result.attributes )
13101354 .failed () ||
13111355 parser.parseRParen ().failed ())
13121356 return failure ();
@@ -1320,8 +1364,8 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
13201364}
13211365
13221366LLVMFunctionType CallOp::getCalleeFunctionType () {
1323- if (getCalleeType ())
1324- return *getCalleeType () ;
1367+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1368+ return *varCalleeType ;
13251369 return getLLVMFuncType (getContext (), getResultTypes (), getArgOperands ());
13261370}
13271371
@@ -1334,26 +1378,26 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
13341378 Block *unwind, ValueRange unwindOps) {
13351379 auto calleeType = func.getFunctionType ();
13361380 build (builder, state, getCallOpResultTypes (calleeType),
1337- TypeAttr::get (calleeType), SymbolRefAttr::get (func), ops, normalOps ,
1338- unwindOps, nullptr , nullptr , normal, unwind);
1381+ getCallOpVarCalleeType (calleeType), SymbolRefAttr::get (func), ops,
1382+ normalOps, unwindOps, nullptr , nullptr , normal, unwind);
13391383}
13401384
13411385void InvokeOp::build (OpBuilder &builder, OperationState &state, TypeRange tys,
13421386 FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
13431387 ValueRange normalOps, Block *unwind,
13441388 ValueRange unwindOps) {
13451389 build (builder, state, tys,
1346- TypeAttr::get ( getLLVMFuncType (builder. getContext (), tys , ops)), callee ,
1347- ops, normalOps, unwindOps, nullptr , nullptr , normal, unwind);
1390+ /* var_callee_type= */ nullptr , callee , ops, normalOps, unwindOps, nullptr ,
1391+ nullptr , normal, unwind);
13481392}
13491393
13501394void InvokeOp::build (OpBuilder &builder, OperationState &state,
13511395 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
13521396 ValueRange ops, Block *normal, ValueRange normalOps,
13531397 Block *unwind, ValueRange unwindOps) {
13541398 build (builder, state, getCallOpResultTypes (calleeType),
1355- TypeAttr::get (calleeType), callee, ops, normalOps, unwindOps, nullptr ,
1356- nullptr , normal, unwind);
1399+ getCallOpVarCalleeType (calleeType), callee, ops, normalOps, unwindOps,
1400+ nullptr , nullptr , normal, unwind);
13571401}
13581402
13591403SuccessorOperands InvokeOp::getSuccessorOperands (unsigned index) {
@@ -1390,8 +1434,8 @@ MutableOperandRange InvokeOp::getArgOperandsMutable() {
13901434}
13911435
13921436LogicalResult InvokeOp::verify () {
1393- if (getNumResults () > 1 )
1394- return emitOpError ( " must have 0 or 1 result " );
1437+ if (failed ( verifyCallOpVarCalleeType (* this )) )
1438+ return failure ( );
13951439
13961440 Block *unwindDest = getUnwindDest ();
13971441 if (unwindDest->empty ())
@@ -1409,14 +1453,6 @@ void InvokeOp::print(OpAsmPrinter &p) {
14091453 auto callee = getCallee ();
14101454 bool isDirect = callee.has_value ();
14111455
1412- LLVMFunctionType calleeType;
1413- bool isVarArg = false ;
1414-
1415- if (std::optional<LLVMFunctionType> optionalCalleeType = getCalleeType ()) {
1416- calleeType = *optionalCalleeType;
1417- isVarArg = calleeType.isVarArg ();
1418- }
1419-
14201456 p << ' ' ;
14211457
14221458 // Print calling convention.
@@ -1435,12 +1471,13 @@ void InvokeOp::print(OpAsmPrinter &p) {
14351471 p << " unwind " ;
14361472 p.printSuccessorAndUseList (getUnwindDest (), getUnwindDestOperands ());
14371473
1438- if (isVarArg)
1439- p << " vararg(" << calleeType << " )" ;
1474+ // Print the variadic callee type if the invoke is variadic.
1475+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1476+ p << " vararg(" << *varCalleeType << " )" ;
14401477
14411478 p.printOptionalAttrDict ((*this )->getAttrs (),
1442- {InvokeOp::getOperandSegmentSizeAttr (), " callee " ,
1443- " callee_type " , InvokeOp::getCConvAttrName ()});
1479+ {getCalleeAttrName (), getOperandSegmentSizeAttr () ,
1480+ getCConvAttrName (), getVarCalleeTypeAttrName ()});
14441481
14451482 p << " : " ;
14461483 if (!isDirect)
@@ -1453,12 +1490,12 @@ void InvokeOp::print(OpAsmPrinter &p) {
14531490// `(` ssa-use-list `)`
14541491// `to` bb-id (`[` ssa-use-and-type-list `]`)?
14551492// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1456- // ( `vararg(` var-arg-func -type `)` )?
1493+ // ( `vararg(` var-callee -type `)` )?
14571494// attribute-dict? `:` (type `,`)? function-type
14581495ParseResult InvokeOp::parse (OpAsmParser &parser, OperationState &result) {
14591496 SmallVector<OpAsmParser::UnresolvedOperand, 8 > operands;
14601497 SymbolRefAttr funcAttr;
1461- TypeAttr calleeType ;
1498+ TypeAttr varCalleeType ;
14621499 Block *normalDest, *unwindDest;
14631500 SmallVector<Value, 4 > normalOperands, unwindOperands;
14641501 Builder &builder = parser.getBuilder ();
@@ -1488,8 +1525,12 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
14881525
14891526 bool isVarArg = parser.parseOptionalKeyword (" vararg" ).succeeded ();
14901527 if (isVarArg) {
1528+ StringAttr varCalleeTypeAttrName =
1529+ InvokeOp::getVarCalleeTypeAttrName (result.name );
14911530 if (parser.parseLParen ().failed () ||
1492- parser.parseAttribute (calleeType, " callee_type" , result.attributes )
1531+ parser
1532+ .parseAttribute (varCalleeType, varCalleeTypeAttrName,
1533+ result.attributes )
14931534 .failed () ||
14941535 parser.parseRParen ().failed ())
14951536 return failure ();
@@ -1515,8 +1556,8 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
15151556}
15161557
15171558LLVMFunctionType InvokeOp::getCalleeFunctionType () {
1518- if (getCalleeType ())
1519- return *getCalleeType () ;
1559+ if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType ())
1560+ return *varCalleeType ;
15201561 return getLLVMFuncType (getContext (), getResultTypes (), getArgOperands ());
15211562}
15221563
0 commit comments