@@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
139139 if (iface.isConvertibleInstruction (inst->getOpcode ()))
140140 return iface.convertInstruction (odsBuilder, inst, llvmOperands,
141141 moduleImport);
142- // TODO: Implement the `convertInstruction` hooks in the
143- // `LLVMDialectLLVMIRImportInterface` and move the following include there.
142+ // TODO: Implement the `convertInstruction` hooks in the
143+ // `LLVMDialectLLVMIRImportInterface` and move the following include there.
144144#include " mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
145145 return failure ();
146146}
@@ -1489,16 +1489,15 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
14891489 return success ();
14901490}
14911491
1492- LogicalResult ModuleImport::convertCallTypeAndOperands (
1493- llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
1494- SmallVectorImpl<Value> &operands, bool allowInlineAsm) {
1495- if (!callInst->getType ()->isVoidTy ())
1496- types.push_back (convertType (callInst->getType ()));
1497-
1492+ FailureOr<SmallVector<Value>>
1493+ ModuleImport::convertCallOperands (llvm::CallBase *callInst,
1494+ bool allowInlineAsm) {
14981495 bool isInlineAsm = callInst->isInlineAsm ();
14991496 if (isInlineAsm && !allowInlineAsm)
15001497 return failure ();
15011498
1499+ SmallVector<Value> operands;
1500+
15021501 // Cannot use isIndirectCall() here because we need to handle Constant callees
15031502 // that are not considered indirect calls by LLVM. However, in MLIR, they are
15041503 // treated as indirect calls to constant operands that need to be converted.
@@ -1515,8 +1514,29 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
15151514 FailureOr<SmallVector<Value>> arguments = convertValues (args);
15161515 if (failed (arguments))
15171516 return failure ();
1517+
15181518 llvm::append_range (operands, *arguments);
1519- return success ();
1519+ return operands;
1520+ }
1521+
1522+ LLVMFunctionType ModuleImport::convertFunctionType (llvm::CallBase *callInst) {
1523+ llvm::Value *calledOperand = callInst->getCalledOperand ();
1524+ Type converted = [&] {
1525+ if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1526+ return convertType (callee->getFunctionType ());
1527+ return convertType (callInst->getFunctionType ());
1528+ }();
1529+
1530+ if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
1531+ return funcTy;
1532+ return {};
1533+ }
1534+
1535+ FlatSymbolRefAttr ModuleImport::convertCalleeName (llvm::CallBase *callInst) {
1536+ llvm::Value *calledOperand = callInst->getCalledOperand ();
1537+ if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1538+ return SymbolRefAttr::get (context, callee->getName ());
1539+ return {};
15201540}
15211541
15221542LogicalResult ModuleImport::convertIntrinsic (llvm::CallInst *inst) {
@@ -1603,75 +1623,45 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16031623 auto callInst = cast<llvm::CallInst>(inst);
16041624 llvm::Value *calledOperand = callInst->getCalledOperand ();
16051625
1606- SmallVector<Type> types;
1607- SmallVector<Value> operands;
1608- if (failed (convertCallTypeAndOperands (callInst, types, operands,
1609- /* allowInlineAsm=*/ true )))
1626+ FailureOr<SmallVector<Value>> operands =
1627+ convertCallOperands (callInst, /* allowInlineAsm=*/ true );
1628+ if (failed (operands))
16101629 return failure ();
16111630
1612- if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1613- Type resultTy = convertType (callInst->getType ());
1614- if (!resultTy)
1615- return failure ();
1616- auto callOp = builder.create <InlineAsmOp>(
1617- loc, resultTy, operands, builder.getStringAttr (asmI->getAsmString ()),
1618- builder.getStringAttr (asmI->getConstraintString ()),
1619- /* has_side_effects=*/ true ,
1620- /* is_align_stack=*/ false , /* asm_dialect=*/ nullptr ,
1621- /* operand_attrs=*/ nullptr );
1622- if (!callInst->getType ()->isVoidTy ())
1623- mapValue (inst, callOp.getResult (0 ));
1624- else
1625- mapNoResultOp (inst, callOp);
1626- } else {
1627- auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
1628- // Retrieve the real function type. For direct calls, use the callee's
1629- // function type, as it may differ from the operand type in the case of
1630- // variadic functions. For indirect calls, use the call function type.
1631- if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1632- return convertType (callee->getFunctionType ());
1633- return convertType (callInst->getFunctionType ());
1634- }());
1635-
1636- if (!funcTy)
1637- return failure ();
1631+ auto callOp = [&]() -> FailureOr<Operation *> {
1632+ if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1633+ Type resultTy = convertType (callInst->getType ());
1634+ if (!resultTy)
1635+ return failure ();
1636+ return builder
1637+ .create <InlineAsmOp>(
1638+ loc, resultTy, *operands,
1639+ builder.getStringAttr (asmI->getAsmString ()),
1640+ builder.getStringAttr (asmI->getConstraintString ()),
1641+ /* has_side_effects=*/ true ,
1642+ /* is_align_stack=*/ false , /* asm_dialect=*/ nullptr ,
1643+ /* operand_attrs=*/ nullptr )
1644+ .getOperation ();
1645+ } else {
1646+ LLVMFunctionType funcTy = convertFunctionType (callInst);
1647+ if (!funcTy)
1648+ return failure ();
16381649
1639- auto callOp = [&]() -> CallOp {
1640- if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
1641- auto name = SymbolRefAttr::get (context, callee->getName ());
1642- return builder.create <CallOp>(loc, funcTy, name, operands);
1643- }
1644- return builder.create <CallOp>(loc, funcTy, operands);
1645- }();
1646-
1647- // Handle function attributes.
1648- callOp.setCConv (convertCConvFromLLVM (callInst->getCallingConv ()));
1649- callOp.setTailCallKind (
1650- convertTailCallKindFromLLVM (callInst->getTailCallKind ()));
1651- setFastmathFlagsAttr (inst, callOp);
1652-
1653- callOp.setConvergent (callInst->isConvergent ());
1654- callOp.setNoUnwind (callInst->doesNotThrow ());
1655- callOp.setWillReturn (callInst->hasFnAttr (llvm::Attribute::WillReturn));
1656-
1657- llvm::MemoryEffects memEffects = callInst->getMemoryEffects ();
1658- ModRefInfo othermem = convertModRefInfoFromLLVM (
1659- memEffects.getModRef (llvm::MemoryEffects::Location::Other));
1660- ModRefInfo argMem = convertModRefInfoFromLLVM (
1661- memEffects.getModRef (llvm::MemoryEffects::Location::ArgMem));
1662- ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM (
1663- memEffects.getModRef (llvm::MemoryEffects::Location::InaccessibleMem));
1664- auto memAttr = MemoryEffectsAttr::get (callOp.getContext (), othermem,
1665- argMem, inaccessibleMem);
1666- // Only set the attribute when it does not match the default value.
1667- if (!memAttr.isReadWrite ())
1668- callOp.setMemoryEffectsAttr (memAttr);
1669-
1670- if (!callInst->getType ()->isVoidTy ())
1671- mapValue (inst, callOp.getResult ());
1672- else
1673- mapNoResultOp (inst, callOp);
1674- }
1650+ FlatSymbolRefAttr callee = convertCalleeName (callInst);
1651+ auto callOp = builder.create <CallOp>(loc, funcTy, callee, *operands);
1652+ if (failed (convertCallAttributes (callInst, callOp)))
1653+ return failure ();
1654+ return callOp.getOperation ();
1655+ }
1656+ }();
1657+
1658+ if (failed (callOp))
1659+ return failure ();
1660+
1661+ if (!callInst->getType ()->isVoidTy ())
1662+ mapValue (inst, (*callOp)->getResult (0 ));
1663+ else
1664+ mapNoResultOp (inst, *callOp);
16751665 return success ();
16761666 }
16771667 if (inst->getOpcode () == llvm::Instruction::LandingPad) {
@@ -1695,9 +1685,11 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16951685 if (inst->getOpcode () == llvm::Instruction::Invoke) {
16961686 auto *invokeInst = cast<llvm::InvokeInst>(inst);
16971687
1698- SmallVector<Type> types;
1699- SmallVector<Value> operands;
1700- if (failed (convertCallTypeAndOperands (invokeInst, types, operands)))
1688+ if (invokeInst->isInlineAsm ())
1689+ return emitError (loc) << " invoke of inline assembly is not supported" ;
1690+
1691+ FailureOr<SmallVector<Value>> operands = convertCallOperands (invokeInst);
1692+ if (failed (operands))
17011693 return failure ();
17021694
17031695 // Check whether the invoke result is an argument to the normal destination
@@ -1724,27 +1716,22 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
17241716 unwindArgs)))
17251717 return failure ();
17261718
1727- auto funcTy =
1728- dyn_cast<LLVMFunctionType>(convertType (invokeInst->getFunctionType ()));
1719+ auto funcTy = convertFunctionType (invokeInst);
17291720 if (!funcTy)
17301721 return failure ();
17311722
1723+ FlatSymbolRefAttr calleeName = convertCalleeName (invokeInst);
1724+
17321725 // Create the invoke operation. Normal destination block arguments will be
17331726 // added later on to handle the case in which the operation result is
17341727 // included in this list.
1735- InvokeOp invokeOp;
1736- if (llvm::Function *callee = invokeInst->getCalledFunction ()) {
1737- invokeOp = builder.create <InvokeOp>(
1738- loc, funcTy,
1739- SymbolRefAttr::get (builder.getContext (), callee->getName ()), operands,
1740- directNormalDest, ValueRange (),
1741- lookupBlock (invokeInst->getUnwindDest ()), unwindArgs);
1742- } else {
1743- invokeOp = builder.create <InvokeOp>(
1744- loc, funcTy, /* callee=*/ nullptr , operands, directNormalDest,
1745- ValueRange (), lookupBlock (invokeInst->getUnwindDest ()), unwindArgs);
1746- }
1747- invokeOp.setCConv (convertCConvFromLLVM (invokeInst->getCallingConv ()));
1728+ auto invokeOp = builder.create <InvokeOp>(
1729+ loc, funcTy, calleeName, *operands, directNormalDest, ValueRange (),
1730+ lookupBlock (invokeInst->getUnwindDest ()), unwindArgs);
1731+
1732+ if (failed (convertInvokeAttributes (invokeInst, invokeOp)))
1733+ return failure ();
1734+
17481735 if (!invokeInst->getType ()->isVoidTy ())
17491736 mapValue (inst, invokeOp.getResults ().front ());
17501737 else
@@ -2097,6 +2084,41 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
20972084 builder.getArrayAttr (convertParameterAttribute (llvmResAttr, builder)));
20982085}
20992086
2087+ template <typename Op>
2088+ static LogicalResult convertCallBaseAttributes (llvm::CallBase *inst, Op op) {
2089+ op.setCConv (convertCConvFromLLVM (inst->getCallingConv ()));
2090+ return success ();
2091+ }
2092+
2093+ LogicalResult ModuleImport::convertInvokeAttributes (llvm::InvokeInst *inst,
2094+ InvokeOp op) {
2095+ return convertCallBaseAttributes (inst, op);
2096+ }
2097+
2098+ LogicalResult ModuleImport::convertCallAttributes (llvm::CallInst *inst,
2099+ CallOp op) {
2100+ setFastmathFlagsAttr (inst, op.getOperation ());
2101+ op.setTailCallKind (convertTailCallKindFromLLVM (inst->getTailCallKind ()));
2102+ op.setConvergent (inst->isConvergent ());
2103+ op.setNoUnwind (inst->doesNotThrow ());
2104+ op.setWillReturn (inst->hasFnAttr (llvm::Attribute::WillReturn));
2105+
2106+ llvm::MemoryEffects memEffects = inst->getMemoryEffects ();
2107+ ModRefInfo othermem = convertModRefInfoFromLLVM (
2108+ memEffects.getModRef (llvm::MemoryEffects::Location::Other));
2109+ ModRefInfo argMem = convertModRefInfoFromLLVM (
2110+ memEffects.getModRef (llvm::MemoryEffects::Location::ArgMem));
2111+ ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM (
2112+ memEffects.getModRef (llvm::MemoryEffects::Location::InaccessibleMem));
2113+ auto memAttr = MemoryEffectsAttr::get (op.getContext (), othermem, argMem,
2114+ inaccessibleMem);
2115+ // Only set the attribute when it does not match the default value.
2116+ if (!memAttr.isReadWrite ())
2117+ op.setMemoryEffectsAttr (memAttr);
2118+
2119+ return convertCallBaseAttributes (inst, op);
2120+ }
2121+
21002122LogicalResult ModuleImport::processFunction (llvm::Function *func) {
21012123 clearRegionState ();
21022124
0 commit comments