@@ -1335,55 +1335,66 @@ void CallOp::print(OpAsmPrinter &p) {
13351335 getVarCalleeTypeAttrName (), getCConvAttrName (),
13361336 getOperandSegmentSizesAttrName (),
13371337 getOpBundleSizesAttrName (),
1338- getOpBundleTagsAttrName ()});
1338+ getOpBundleTagsAttrName (), getArgAttrsAttrName (),
1339+ getResAttrsAttrName ()});
13391340
13401341 p << " : " ;
13411342 if (!isDirect)
13421343 p << getOperand (0 ).getType () << " , " ;
13431344
1344- // Reconstruct the function MLIR function type from operand and result types.
1345- p.printFunctionalType (args.getTypes (), getResultTypes ());
1345+ // Reconstruct the MLIR function type from operand and result types.
1346+ call_interface_impl::printFunctionSignature (
1347+ p, args.getTypes (), getArgAttrsAttr (),
1348+ /* isVariadic=*/ false , getResultTypes (), getResAttrsAttr ());
13461349}
13471350
13481351// / Parses the type of a call operation and resolves the operands if the parsing
13491352// / succeeds. Returns failure otherwise.
13501353static ParseResult parseCallTypeAndResolveOperands (
13511354 OpAsmParser &parser, OperationState &result, bool isDirect,
1352- ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
1355+ ArrayRef<OpAsmParser::UnresolvedOperand> operands,
1356+ SmallVectorImpl<DictionaryAttr> &argAttrs,
1357+ SmallVectorImpl<DictionaryAttr> &resultAttrs) {
13531358 SMLoc trailingTypesLoc = parser.getCurrentLocation ();
13541359 SmallVector<Type> types;
1355- if (parser.parseColonTypeList (types ))
1360+ if (parser.parseColon ( ))
13561361 return failure ();
1357-
1358- if (isDirect && types.size () != 1 )
1359- return parser.emitError (trailingTypesLoc,
1360- " expected direct call to have 1 trailing type" );
1361- if (!isDirect && types.size () != 2 )
1362- return parser.emitError (trailingTypesLoc,
1363- " expected indirect call to have 2 trailing types" );
1364-
1365- auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val ());
1366- if (!funcType)
1362+ if (!isDirect) {
1363+ types.emplace_back ();
1364+ if (parser.parseType (types.back ()))
1365+ return failure ();
1366+ if (parser.parseOptionalComma ())
1367+ return parser.emitError (
1368+ trailingTypesLoc, " expected indirect call to have 2 trailing types" );
1369+ }
1370+ SmallVector<Type> argTypes;
1371+ SmallVector<Type> resTypes;
1372+ if (call_interface_impl::parseFunctionSignature (parser, argTypes, argAttrs,
1373+ resTypes, resultAttrs)) {
1374+ if (isDirect)
1375+ return parser.emitError (trailingTypesLoc,
1376+ " expected direct call to have 1 trailing types" );
13671377 return parser.emitError (trailingTypesLoc,
13681378 " expected trailing function type" );
1369- if (funcType.getNumResults () > 1 )
1379+ }
1380+
1381+ if (resTypes.size () > 1 )
13701382 return parser.emitError (trailingTypesLoc,
13711383 " expected function with 0 or 1 result" );
1372- if (funcType.getNumResults () == 1 &&
1373- llvm::isa<LLVM::LLVMVoidType>(funcType.getResult (0 )))
1384+ if (resTypes.size () == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0 ]))
13741385 return parser.emitError (trailingTypesLoc,
13751386 " expected a non-void result type" );
13761387
13771388 // The head element of the types list matches the callee type for
13781389 // indirect calls, while the types list is emtpy for direct calls.
13791390 // Append the function input types to resolve the call operation
13801391 // operands.
1381- llvm::append_range (types, funcType. getInputs () );
1392+ llvm::append_range (types, argTypes );
13821393 if (parser.resolveOperands (operands, types, parser.getNameLoc (),
13831394 result.operands ))
13841395 return failure ();
1385- if (funcType. getNumResults () != 0 )
1386- result.addTypes (funcType. getResults () );
1396+ if (resTypes. size () != 0 )
1397+ result.addTypes (resTypes );
13871398
13881399 return success ();
13891400}
@@ -1497,8 +1508,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
14971508 return failure ();
14981509
14991510 // Parse the trailing type list and resolve the operands.
1500- if (parseCallTypeAndResolveOperands (parser, result, isDirect, operands))
1511+ SmallVector<DictionaryAttr> argAttrs;
1512+ SmallVector<DictionaryAttr> resultAttrs;
1513+ if (parseCallTypeAndResolveOperands (parser, result, isDirect, operands,
1514+ argAttrs, resultAttrs))
15011515 return failure ();
1516+ call_interface_impl::addArgAndResultAttrs (
1517+ parser.getBuilder (), result, argAttrs, resultAttrs,
1518+ getArgAttrsAttrName (result.name ), getResAttrsAttrName (result.name ));
15021519 if (resolveOpBundleOperands (parser, opBundlesLoc, result, opBundleOperands,
15031520 opBundleOperandTypes,
15041521 getOpBundleSizesAttrName (result.name )))
@@ -1643,14 +1660,16 @@ void InvokeOp::print(OpAsmPrinter &p) {
16431660 {getCalleeAttrName (), getOperandSegmentSizeAttr (),
16441661 getCConvAttrName (), getVarCalleeTypeAttrName (),
16451662 getOpBundleSizesAttrName (),
1646- getOpBundleTagsAttrName ()});
1663+ getOpBundleTagsAttrName (), getArgAttrsAttrName (),
1664+ getResAttrsAttrName ()});
16471665
16481666 p << " : " ;
16491667 if (!isDirect)
16501668 p << getOperand (0 ).getType () << " , " ;
1651- p.printFunctionalType (
1652- llvm::drop_begin (getCalleeOperands ().getTypes (), isDirect ? 0 : 1 ),
1653- getResultTypes ());
1669+ call_interface_impl::printFunctionSignature (
1670+ p, getCalleeOperands ().drop_front (isDirect ? 0 : 1 ).getTypes (),
1671+ getArgAttrsAttr (),
1672+ /* isVariadic=*/ false , getResultTypes (), getResAttrsAttr ());
16541673}
16551674
16561675// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
@@ -1659,7 +1678,8 @@ void InvokeOp::print(OpAsmPrinter &p) {
16591678// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
16601679// ( `vararg(` var-callee-type `)` )?
16611680// ( `[` op-bundles-list `]` )?
1662- // attribute-dict? `:` (type `,`)? function-type
1681+ // attribute-dict? `:` (type `,`)?
1682+ // function-type-with-argument-attributes
16631683ParseResult InvokeOp::parse (OpAsmParser &parser, OperationState &result) {
16641684 SmallVector<OpAsmParser::UnresolvedOperand, 8 > operands;
16651685 SymbolRefAttr funcAttr;
@@ -1721,8 +1741,15 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
17211741 return failure ();
17221742
17231743 // Parse the trailing type list and resolve the function operands.
1724- if (parseCallTypeAndResolveOperands (parser, result, isDirect, operands))
1744+ SmallVector<DictionaryAttr> argAttrs;
1745+ SmallVector<DictionaryAttr> resultAttrs;
1746+ if (parseCallTypeAndResolveOperands (parser, result, isDirect, operands,
1747+ argAttrs, resultAttrs))
17251748 return failure ();
1749+ call_interface_impl::addArgAndResultAttrs (
1750+ parser.getBuilder (), result, argAttrs, resultAttrs,
1751+ getArgAttrsAttrName (result.name ), getResAttrsAttrName (result.name ));
1752+
17261753 if (resolveOpBundleOperands (parser, opBundlesLoc, result, opBundleOperands,
17271754 opBundleOperandTypes,
17281755 getOpBundleSizesAttrName (result.name )))
0 commit comments