Skip to content

Commit 1c441a8

Browse files
committed
[MLIR][NFC] Refactor common code for call and invoke attributes import
1 parent a7fc4de commit 1c441a8

File tree

2 files changed

+77
-52
lines changed

2 files changed

+77
-52
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,10 @@ class ModuleImport {
338338
/// DictionaryAttr for the LLVM dialect.
339339
DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
340340
OpBuilder &builder);
341+
/// Converts the attributes attached to `inst` and adds them to the `op`.
342+
LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
343+
/// Converts the attributes attached to `inst` and adds them to the `op`.
344+
LogicalResult convertInvokeAttributes(llvm::InvokeInst *inst, InvokeOp op);
341345
/// Returns the builtin type equivalent to the given LLVM dialect type or
342346
/// nullptr if there is no equivalent. The returned type can be used to create
343347
/// an attribute for a GlobalOp or a ConstantOp.

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 73 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
139139
if (iface.isConvertibleInstruction(inst->getOpcode()))
140140
return iface.convertInstruction(odsBuilder, inst, llvmOperands,
141141
moduleImport);
142-
// TODO: Implement the `convertInstruction` hooks in the
143-
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
142+
// TODO: Implement the `convertInstruction` hooks in the
143+
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
144144
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
145145
return failure();
146146
}
@@ -1628,56 +1628,40 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16281628
if (failed(operands))
16291629
return failure();
16301630

1631-
if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1632-
Type resultTy = convertType(callInst->getType());
1633-
if (!resultTy)
1634-
return failure();
1635-
auto callOp = builder.create<InlineAsmOp>(
1636-
loc, resultTy, *operands, builder.getStringAttr(asmI->getAsmString()),
1637-
builder.getStringAttr(asmI->getConstraintString()),
1638-
/*has_side_effects=*/true,
1639-
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
1640-
/*operand_attrs=*/nullptr);
1641-
if (!callInst->getType()->isVoidTy())
1642-
mapValue(inst, callOp.getResult(0));
1643-
else
1644-
mapNoResultOp(inst, callOp);
1645-
} else {
1646-
auto funcTy = convertFunctionType(callInst);
1647-
if (!funcTy)
1648-
return failure();
1631+
auto callOp = [&]() -> FailureOr<Operation *> {
1632+
if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1633+
Type resultTy = convertType(callInst->getType());
1634+
if (!resultTy)
1635+
return failure();
1636+
return builder
1637+
.create<InlineAsmOp>(
1638+
loc, resultTy, *operands,
1639+
builder.getStringAttr(asmI->getAsmString()),
1640+
builder.getStringAttr(asmI->getConstraintString()),
1641+
/*has_side_effects=*/true,
1642+
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
1643+
/*operand_attrs=*/nullptr)
1644+
.getOperation();
1645+
} else {
1646+
LLVMFunctionType funcTy = convertFunctionType(callInst);
1647+
if (!funcTy)
1648+
return failure();
16491649

1650-
FlatSymbolRefAttr calleeName = convertCalleeName(callInst);
1651-
auto callOp = builder.create<CallOp>(loc, funcTy, calleeName, *operands);
1652-
1653-
// Handle function attributes.
1654-
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
1655-
callOp.setTailCallKind(
1656-
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
1657-
setFastmathFlagsAttr(inst, callOp);
1658-
1659-
callOp.setConvergent(callInst->isConvergent());
1660-
callOp.setNoUnwind(callInst->doesNotThrow());
1661-
callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
1662-
1663-
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
1664-
ModRefInfo othermem = convertModRefInfoFromLLVM(
1665-
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
1666-
ModRefInfo argMem = convertModRefInfoFromLLVM(
1667-
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
1668-
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
1669-
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
1670-
auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
1671-
argMem, inaccessibleMem);
1672-
// Only set the attribute when it does not match the default value.
1673-
if (!memAttr.isReadWrite())
1674-
callOp.setMemoryEffectsAttr(memAttr);
1675-
1676-
if (!callInst->getType()->isVoidTy())
1677-
mapValue(inst, callOp.getResult());
1678-
else
1679-
mapNoResultOp(inst, callOp);
1680-
}
1650+
FlatSymbolRefAttr callee = convertCalleeName(callInst);
1651+
auto callOp = builder.create<CallOp>(loc, funcTy, callee, *operands);
1652+
if (failed(convertCallAttributes(callInst, callOp)))
1653+
return failure();
1654+
return callOp.getOperation();
1655+
}
1656+
}();
1657+
1658+
if (failed(callOp))
1659+
return failure();
1660+
1661+
if (!callInst->getType()->isVoidTy())
1662+
mapValue(inst, (*callOp)->getResult(0));
1663+
else
1664+
mapNoResultOp(inst, *callOp);
16811665
return success();
16821666
}
16831667
if (inst->getOpcode() == llvm::Instruction::LandingPad) {
@@ -1745,7 +1729,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
17451729
loc, funcTy, calleeName, *operands, directNormalDest, ValueRange(),
17461730
lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
17471731

1748-
invokeOp.setCConv(convertCConvFromLLVM(invokeInst->getCallingConv()));
1732+
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
1733+
return failure();
1734+
17491735
if (!invokeInst->getType()->isVoidTy())
17501736
mapValue(inst, invokeOp.getResults().front());
17511737
else
@@ -2098,6 +2084,41 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
20982084
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
20992085
}
21002086

2087+
template <typename Op>
2088+
static LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) {
2089+
op.setCConv(convertCConvFromLLVM(inst->getCallingConv()));
2090+
return success();
2091+
}
2092+
2093+
LogicalResult ModuleImport::convertInvokeAttributes(llvm::InvokeInst *inst,
2094+
InvokeOp op) {
2095+
return convertCallBaseAttributes(inst, op);
2096+
}
2097+
2098+
LogicalResult ModuleImport::convertCallAttributes(llvm::CallInst *inst,
2099+
CallOp op) {
2100+
setFastmathFlagsAttr(inst, op.getOperation());
2101+
op.setTailCallKind(convertTailCallKindFromLLVM(inst->getTailCallKind()));
2102+
op.setConvergent(inst->isConvergent());
2103+
op.setNoUnwind(inst->doesNotThrow());
2104+
op.setWillReturn(inst->hasFnAttr(llvm::Attribute::WillReturn));
2105+
2106+
llvm::MemoryEffects memEffects = inst->getMemoryEffects();
2107+
ModRefInfo othermem = convertModRefInfoFromLLVM(
2108+
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
2109+
ModRefInfo argMem = convertModRefInfoFromLLVM(
2110+
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
2111+
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
2112+
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
2113+
auto memAttr = MemoryEffectsAttr::get(op.getContext(), othermem, argMem,
2114+
inaccessibleMem);
2115+
// Only set the attribute when it does not match the default value.
2116+
if (!memAttr.isReadWrite())
2117+
op.setMemoryEffectsAttr(memAttr);
2118+
2119+
return convertCallBaseAttributes(inst, op);
2120+
}
2121+
21012122
LogicalResult ModuleImport::processFunction(llvm::Function *func) {
21022123
clearRegionState();
21032124

0 commit comments

Comments
 (0)