@@ -1519,22 +1519,72 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
15191519 return operands;
15201520}
15211521
1522- LLVMFunctionType ModuleImport::convertFunctionType (llvm::CallBase *callInst) {
1523- llvm::Value *calledOperand = callInst->getCalledOperand ();
1524- Type converted = [&] {
1525- if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1526- return convertType (callee->getFunctionType ());
1527- return convertType (callInst->getFunctionType ());
1528- }();
1522+ // / Checks if `callType` and `calleeType` are compatible and can be represented
1523+ // / in MLIR.
1524+ static LogicalResult
1525+ verifyFunctionTypeCompatibility (LLVMFunctionType callType,
1526+ LLVMFunctionType calleeType) {
1527+ if (callType.getReturnType () != calleeType.getReturnType ())
1528+ return failure ();
1529+
1530+ if (calleeType.isVarArg ()) {
1531+ // For variadic functions, the call can have more types than the callee
1532+ // specifies.
1533+ if (callType.getNumParams () < calleeType.getNumParams ())
1534+ return failure ();
1535+ } else {
1536+ // For non-variadic functions, the number of parameters needs to be the
1537+ // same.
1538+ if (callType.getNumParams () != calleeType.getNumParams ())
1539+ return failure ();
1540+ }
1541+
1542+ // Check that all operands match.
1543+ for (auto [operandType, argumentType] :
1544+ llvm::zip (callType.getParams (), calleeType.getParams ()))
1545+ if (operandType != argumentType)
1546+ return failure ();
1547+
1548+ return success ();
1549+ }
15291550
1530- if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
1551+ FailureOr<LLVMFunctionType>
1552+ ModuleImport::convertFunctionType (llvm::CallBase *callInst) {
1553+ auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
1554+ auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
1555+ if (!funcTy)
1556+ return failure ();
15311557 return funcTy;
1532- return {};
1558+ };
1559+
1560+ llvm::Value *calledOperand = callInst->getCalledOperand ();
1561+ FailureOr<LLVMFunctionType> callType =
1562+ castOrFailure (convertType (callInst->getFunctionType ()));
1563+ if (failed (callType))
1564+ return failure ();
1565+ auto *callee = dyn_cast<llvm::Function>(calledOperand);
1566+ // For indirect calls, return the type of the call itself.
1567+ if (!callee)
1568+ return callType;
1569+
1570+ FailureOr<LLVMFunctionType> calleeType =
1571+ castOrFailure (convertType (callee->getFunctionType ()));
1572+ if (failed (calleeType))
1573+ return failure ();
1574+
1575+ // Compare the types to avoid constructing illegal call/invoke operations.
1576+ if (failed (verifyFunctionTypeCompatibility (*callType, *calleeType))) {
1577+ Location loc = translateLoc (callInst->getDebugLoc ());
1578+ return emitError (loc) << " incompatible call and callee types: " << *callType
1579+ << " and " << *calleeType;
1580+ }
1581+
1582+ return calleeType;
15331583}
15341584
15351585FlatSymbolRefAttr ModuleImport::convertCalleeName (llvm::CallBase *callInst) {
15361586 llvm::Value *calledOperand = callInst->getCalledOperand ();
1537- if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1587+ if (auto * callee = dyn_cast<llvm::Function>(calledOperand))
15381588 return SymbolRefAttr::get (context, callee->getName ());
15391589 return {};
15401590}
@@ -1620,7 +1670,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16201670 return success ();
16211671 }
16221672 if (inst->getOpcode () == llvm::Instruction::Call) {
1623- auto callInst = cast<llvm::CallInst>(inst);
1673+ auto * callInst = cast<llvm::CallInst>(inst);
16241674 llvm::Value *calledOperand = callInst->getCalledOperand ();
16251675
16261676 FailureOr<SmallVector<Value>> operands =
@@ -1629,7 +1679,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16291679 return failure ();
16301680
16311681 auto callOp = [&]() -> FailureOr<Operation *> {
1632- if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1682+ if (auto * asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
16331683 Type resultTy = convertType (callInst->getType ());
16341684 if (!resultTy)
16351685 return failure ();
@@ -1642,17 +1692,16 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16421692 /* is_align_stack=*/ false , /* asm_dialect=*/ nullptr ,
16431693 /* operand_attrs=*/ nullptr )
16441694 .getOperation ();
1645- } else {
1646- LLVMFunctionType funcTy = convertFunctionType (callInst);
1647- if (!funcTy)
1648- return failure ();
1649-
1650- FlatSymbolRefAttr callee = convertCalleeName (callInst);
1651- auto callOp = builder.create <CallOp>(loc, funcTy, callee, *operands);
1652- if (failed (convertCallAttributes (callInst, callOp)))
1653- return failure ();
1654- return callOp.getOperation ();
16551695 }
1696+ FailureOr<LLVMFunctionType> funcTy = convertFunctionType (callInst);
1697+ if (failed (funcTy))
1698+ return failure ();
1699+
1700+ FlatSymbolRefAttr callee = convertCalleeName (callInst);
1701+ auto callOp = builder.create <CallOp>(loc, *funcTy, callee, *operands);
1702+ if (failed (convertCallAttributes (callInst, callOp)))
1703+ return failure ();
1704+ return callOp.getOperation ();
16561705 }();
16571706
16581707 if (failed (callOp))
@@ -1716,8 +1765,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
17161765 unwindArgs)))
17171766 return failure ();
17181767
1719- auto funcTy = convertFunctionType (invokeInst);
1720- if (! funcTy)
1768+ FailureOr<LLVMFunctionType> funcTy = convertFunctionType (invokeInst);
1769+ if (failed ( funcTy) )
17211770 return failure ();
17221771
17231772 FlatSymbolRefAttr calleeName = convertCalleeName (invokeInst);
@@ -1726,7 +1775,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
17261775 // added later on to handle the case in which the operation result is
17271776 // included in this list.
17281777 auto invokeOp = builder.create <InvokeOp>(
1729- loc, funcTy, calleeName, *operands, directNormalDest, ValueRange (),
1778+ loc, * funcTy, calleeName, *operands, directNormalDest, ValueRange (),
17301779 lookupBlock (invokeInst->getUnwindDest ()), unwindArgs);
17311780
17321781 if (failed (convertInvokeAttributes (invokeInst, invokeOp)))
0 commit comments