@@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
14951495 if (!callInst->getType ()->isVoidTy ())
14961496 types.push_back (convertType (callInst->getType ()));
14971497
1498- if (!callInst->getCalledFunction ()) {
1499- if (!allowInlineAsm ||
1500- !isa<llvm::InlineAsm>(callInst->getCalledOperand ())) {
1501- FailureOr<Value> called = convertValue (callInst->getCalledOperand ());
1502- if (failed (called))
1503- return failure ();
1504- operands.push_back (*called);
1505- }
1498+ bool isInlineAsm = callInst->isInlineAsm ();
1499+ if (isInlineAsm && !allowInlineAsm)
1500+ return failure ();
1501+
1502+ // Cannot use isIndirectCall() here because we need to handle Constant callees
1503+ // that are not considered indirect calls by LLVM. However, in MLIR, they are
1504+ // treated as indirect calls to constant operands that need to be converted.
1505+ // Skip the callee operand if it's inline assembly, as it's handled separately
1506+ // in InlineAsmOp.
1507+ if (!isa<llvm::Function>(callInst->getCalledOperand ()) && !isInlineAsm) {
1508+ FailureOr<Value> called = convertValue (callInst->getCalledOperand ());
1509+ if (failed (called))
1510+ return failure ();
1511+ operands.push_back (*called);
15061512 }
1513+
15071514 SmallVector<llvm::Value *> args (callInst->args ());
15081515 FailureOr<SmallVector<Value>> arguments = convertValues (args);
15091516 if (failed (arguments))
@@ -1593,23 +1600,21 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
15931600 return success ();
15941601 }
15951602 if (inst->getOpcode () == llvm::Instruction::Call) {
1596- auto *callInst = cast<llvm::CallInst>(inst);
1603+ auto callInst = cast<llvm::CallInst>(inst);
1604+ llvm::Value *calledOperand = callInst->getCalledOperand ();
15971605
15981606 SmallVector<Type> types;
15991607 SmallVector<Value> operands;
16001608 if (failed (convertCallTypeAndOperands (callInst, types, operands,
16011609 /* allowInlineAsm=*/ true )))
16021610 return failure ();
16031611
1604- auto funcTy =
1605- dyn_cast<LLVMFunctionType>(convertType (callInst->getFunctionType ()));
1606- if (!funcTy)
1607- return failure ();
1608-
1609- if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand ())) {
1612+ if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1613+ Type resultTy = convertType (callInst->getType ());
1614+ if (!resultTy)
1615+ return failure ();
16101616 auto callOp = builder.create <InlineAsmOp>(
1611- loc, funcTy.getReturnType (), operands,
1612- builder.getStringAttr (asmI->getAsmString ()),
1617+ loc, resultTy, operands, builder.getStringAttr (asmI->getAsmString ()),
16131618 builder.getStringAttr (asmI->getConstraintString ()),
16141619 /* has_side_effects=*/ true ,
16151620 /* is_align_stack=*/ false , /* asm_dialect=*/ nullptr ,
@@ -1619,27 +1624,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16191624 else
16201625 mapNoResultOp (inst, callOp);
16211626 } else {
1622- CallOp callOp;
1627+ auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
1628+ // Retrieve the real function type. For direct calls, use the callee's
1629+ // function type, as it may differ from the operand type in the case of
1630+ // variadic functions. For indirect calls, use the call function type.
1631+ if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1632+ return convertType (callee->getFunctionType ());
1633+ return convertType (callInst->getFunctionType ());
1634+ }());
1635+
1636+ if (!funcTy)
1637+ return failure ();
16231638
1624- if (llvm::Function *callee = callInst->getCalledFunction ()) {
1625- callOp = builder.create <CallOp>(
1626- loc, funcTy, SymbolRefAttr::get (context, callee->getName ()),
1627- operands);
1628- } else {
1629- callOp = builder.create <CallOp>(loc, funcTy, operands);
1630- }
1639+ auto callOp = [&]() -> CallOp {
1640+ if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
1641+ auto name = SymbolRefAttr::get (context, callee->getName ());
1642+ return builder.create <CallOp>(loc, funcTy, name, operands);
1643+ }
1644+ return builder.create <CallOp>(loc, funcTy, operands);
1645+ }();
1646+
1647+ // Handle function attributes.
16311648 callOp.setCConv (convertCConvFromLLVM (callInst->getCallingConv ()));
16321649 callOp.setTailCallKind (
16331650 convertTailCallKindFromLLVM (callInst->getTailCallKind ()));
16341651 setFastmathFlagsAttr (inst, callOp);
16351652
1636- // Handle function attributes.
1637- if (callInst->hasFnAttr (llvm::Attribute::Convergent))
1638- callOp.setConvergent (true );
1639- if (callInst->hasFnAttr (llvm::Attribute::NoUnwind))
1640- callOp.setNoUnwind (true );
1641- if (callInst->hasFnAttr (llvm::Attribute::WillReturn))
1642- callOp.setWillReturn (true );
1653+ callOp.setConvergent (callInst->isConvergent ());
1654+ callOp.setNoUnwind (callInst->doesNotThrow ());
1655+ callOp.setWillReturn (callInst->hasFnAttr (llvm::Attribute::WillReturn));
16431656
16441657 llvm::MemoryEffects memEffects = callInst->getMemoryEffects ();
16451658 ModRefInfo othermem = convertModRefInfoFromLLVM (
0 commit comments