@@ -1335,55 +1335,66 @@ void CallOp::print(OpAsmPrinter &p) {
1335
1335
getVarCalleeTypeAttrName (), getCConvAttrName (),
1336
1336
getOperandSegmentSizesAttrName (),
1337
1337
getOpBundleSizesAttrName (),
1338
- getOpBundleTagsAttrName ()});
1338
+ getOpBundleTagsAttrName (), getArgAttrsAttrName (),
1339
+ getResAttrsAttrName ()});
1339
1340
1340
1341
p << " : " ;
1341
1342
if (!isDirect)
1342
1343
p << getOperand (0 ).getType () << " , " ;
1343
1344
1344
- // Reconstruct the function MLIR function type from operand and result types.
1345
- p.printFunctionalType (args.getTypes (), getResultTypes ());
1345
+ // Reconstruct the MLIR function type from operand and result types.
1346
+ call_interface_impl::printFunctionSignature (
1347
+ p, args.getTypes (), getArgAttrsAttr (),
1348
+ /* isVariadic=*/ false , getResultTypes (), getResAttrsAttr ());
1346
1349
}
1347
1350
1348
1351
// / Parses the type of a call operation and resolves the operands if the parsing
1349
1352
// / succeeds. Returns failure otherwise.
1350
1353
static ParseResult parseCallTypeAndResolveOperands (
1351
1354
OpAsmParser &parser, OperationState &result, bool isDirect,
1352
- ArrayRef<OpAsmParser::UnresolvedOperand> operands) {
1355
+ ArrayRef<OpAsmParser::UnresolvedOperand> operands,
1356
+ SmallVectorImpl<DictionaryAttr> &argAttrs,
1357
+ SmallVectorImpl<DictionaryAttr> &resultAttrs) {
1353
1358
SMLoc trailingTypesLoc = parser.getCurrentLocation ();
1354
1359
SmallVector<Type> types;
1355
- if (parser.parseColonTypeList (types ))
1360
+ if (parser.parseColon ( ))
1356
1361
return failure ();
1357
-
1358
- if (isDirect && types.size () != 1 )
1359
- return parser.emitError (trailingTypesLoc,
1360
- " expected direct call to have 1 trailing type" );
1361
- if (!isDirect && types.size () != 2 )
1362
- return parser.emitError (trailingTypesLoc,
1363
- " expected indirect call to have 2 trailing types" );
1364
-
1365
- auto funcType = llvm::dyn_cast<FunctionType>(types.pop_back_val ());
1366
- if (!funcType)
1362
+ if (!isDirect) {
1363
+ types.emplace_back ();
1364
+ if (parser.parseType (types.back ()))
1365
+ return failure ();
1366
+ if (parser.parseOptionalComma ())
1367
+ return parser.emitError (
1368
+ trailingTypesLoc, " expected indirect call to have 2 trailing types" );
1369
+ }
1370
+ SmallVector<Type> argTypes;
1371
+ SmallVector<Type> resTypes;
1372
+ if (call_interface_impl::parseFunctionSignature (parser, argTypes, argAttrs,
1373
+ resTypes, resultAttrs)) {
1374
+ if (isDirect)
1375
+ return parser.emitError (trailingTypesLoc,
1376
+ " expected direct call to have 1 trailing types" );
1367
1377
return parser.emitError (trailingTypesLoc,
1368
1378
" expected trailing function type" );
1369
- if (funcType.getNumResults () > 1 )
1379
+ }
1380
+
1381
+ if (resTypes.size () > 1 )
1370
1382
return parser.emitError (trailingTypesLoc,
1371
1383
" expected function with 0 or 1 result" );
1372
- if (funcType.getNumResults () == 1 &&
1373
- llvm::isa<LLVM::LLVMVoidType>(funcType.getResult (0 )))
1384
+ if (resTypes.size () == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0 ]))
1374
1385
return parser.emitError (trailingTypesLoc,
1375
1386
" expected a non-void result type" );
1376
1387
1377
1388
// The head element of the types list matches the callee type for
1378
1389
// indirect calls, while the types list is emtpy for direct calls.
1379
1390
// Append the function input types to resolve the call operation
1380
1391
// operands.
1381
- llvm::append_range (types, funcType. getInputs () );
1392
+ llvm::append_range (types, argTypes );
1382
1393
if (parser.resolveOperands (operands, types, parser.getNameLoc (),
1383
1394
result.operands ))
1384
1395
return failure ();
1385
- if (funcType. getNumResults () != 0 )
1386
- result.addTypes (funcType. getResults () );
1396
+ if (resTypes. size () != 0 )
1397
+ result.addTypes (resTypes );
1387
1398
1388
1399
return success ();
1389
1400
}
@@ -1497,8 +1508,14 @@ ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1497
1508
return failure ();
1498
1509
1499
1510
// Parse the trailing type list and resolve the operands.
1500
- if (parseCallTypeAndResolveOperands (parser, result, isDirect, operands))
1511
+ SmallVector<DictionaryAttr> argAttrs;
1512
+ SmallVector<DictionaryAttr> resultAttrs;
1513
+ if (parseCallTypeAndResolveOperands (parser, result, isDirect, operands,
1514
+ argAttrs, resultAttrs))
1501
1515
return failure ();
1516
+ call_interface_impl::addArgAndResultAttrs (
1517
+ parser.getBuilder (), result, argAttrs, resultAttrs,
1518
+ getArgAttrsAttrName (result.name ), getResAttrsAttrName (result.name ));
1502
1519
if (resolveOpBundleOperands (parser, opBundlesLoc, result, opBundleOperands,
1503
1520
opBundleOperandTypes,
1504
1521
getOpBundleSizesAttrName (result.name )))
@@ -1643,14 +1660,16 @@ void InvokeOp::print(OpAsmPrinter &p) {
1643
1660
{getCalleeAttrName (), getOperandSegmentSizeAttr (),
1644
1661
getCConvAttrName (), getVarCalleeTypeAttrName (),
1645
1662
getOpBundleSizesAttrName (),
1646
- getOpBundleTagsAttrName ()});
1663
+ getOpBundleTagsAttrName (), getArgAttrsAttrName (),
1664
+ getResAttrsAttrName ()});
1647
1665
1648
1666
p << " : " ;
1649
1667
if (!isDirect)
1650
1668
p << getOperand (0 ).getType () << " , " ;
1651
- p.printFunctionalType (
1652
- llvm::drop_begin (getCalleeOperands ().getTypes (), isDirect ? 0 : 1 ),
1653
- getResultTypes ());
1669
+ call_interface_impl::printFunctionSignature (
1670
+ p, getCalleeOperands ().drop_front (isDirect ? 0 : 1 ).getTypes (),
1671
+ getArgAttrsAttr (),
1672
+ /* isVariadic=*/ false , getResultTypes (), getResAttrsAttr ());
1654
1673
}
1655
1674
1656
1675
// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
@@ -1659,7 +1678,8 @@ void InvokeOp::print(OpAsmPrinter &p) {
1659
1678
// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1660
1679
// ( `vararg(` var-callee-type `)` )?
1661
1680
// ( `[` op-bundles-list `]` )?
1662
- // attribute-dict? `:` (type `,`)? function-type
1681
+ // attribute-dict? `:` (type `,`)?
1682
+ // function-type-with-argument-attributes
1663
1683
ParseResult InvokeOp::parse (OpAsmParser &parser, OperationState &result) {
1664
1684
SmallVector<OpAsmParser::UnresolvedOperand, 8 > operands;
1665
1685
SymbolRefAttr funcAttr;
@@ -1721,8 +1741,15 @@ ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1721
1741
return failure ();
1722
1742
1723
1743
// Parse the trailing type list and resolve the function operands.
1724
- if (parseCallTypeAndResolveOperands (parser, result, isDirect, operands))
1744
+ SmallVector<DictionaryAttr> argAttrs;
1745
+ SmallVector<DictionaryAttr> resultAttrs;
1746
+ if (parseCallTypeAndResolveOperands (parser, result, isDirect, operands,
1747
+ argAttrs, resultAttrs))
1725
1748
return failure ();
1749
+ call_interface_impl::addArgAndResultAttrs (
1750
+ parser.getBuilder (), result, argAttrs, resultAttrs,
1751
+ getArgAttrsAttrName (result.name ), getResAttrsAttrName (result.name ));
1752
+
1726
1753
if (resolveOpBundleOperands (parser, opBundlesLoc, result, opBundleOperands,
1727
1754
opBundleOperandTypes,
1728
1755
getOpBundleSizesAttrName (result.name )))
0 commit comments