Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,24 +316,32 @@ class ModuleImport {
LogicalResult convertBranchArgs(llvm::Instruction *branch,
llvm::BasicBlock *target,
SmallVectorImpl<Value> &blockArguments);
/// Appends the converted result type and operands of `callInst` to the
/// `types` and `operands` arrays. For indirect calls, the method additionally
/// inserts the called function at the beginning of the `operands` array.
/// If `allowInlineAsm` is set to false (the default), it will return failure
/// if the called operand is an inline asm which isn't convertible to MLIR as
/// a value.
LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
SmallVectorImpl<Type> &types,
SmallVectorImpl<Value> &operands,
bool allowInlineAsm = false);
/// Converts the parameter attributes attached to `func` and adds them to the
/// `funcOp`.
/// Convert `callInst` operands. For indirect calls, the method additionally
/// inserts the called function at the beginning of the returned `operands`
/// array. If `allowInlineAsm` is set to false (the default), it will return
/// failure if the called operand is an inline asm which isn't convertible to
/// MLIR as a value.
FailureOr<SmallVector<Value>>
convertCallOperands(llvm::CallBase *callInst, bool allowInlineAsm = false);
/// Converts the callee's function type. For direct calls, it converts the
/// actual function type, which may differ from the called operand type in
/// variadic functions. For indirect calls, it converts the function type
/// associated with the call instruction.
LLVMFunctionType convertFunctionType(llvm::CallBase *callInst);
/// Returns the callee name, or an empty symbol if the call is not direct.
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
/// Converts the parameter attributes attached to `func` and adds them to
/// the `funcOp`.
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
OpBuilder &builder);
/// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
/// DictionaryAttr for the LLVM dialect.
DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
OpBuilder &builder);
/// Converts the attributes attached to `inst` and adds them to the `op`.
LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
/// Converts the attributes attached to `inst` and adds them to the `op`.
LogicalResult convertInvokeAttributes(llvm::InvokeInst *inst, InvokeOp op);
/// Returns the builtin type equivalent to the given LLVM dialect type or
/// nullptr if there is no equivalent. The returned type can be used to create
/// an attribute for a GlobalOp or a ConstantOp.
Expand Down
208 changes: 115 additions & 93 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
if (iface.isConvertibleInstruction(inst->getOpcode()))
return iface.convertInstruction(odsBuilder, inst, llvmOperands,
moduleImport);
// TODO: Implement the `convertInstruction` hooks in the
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
// TODO: Implement the `convertInstruction` hooks in the
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
return failure();
}
Expand Down Expand Up @@ -1489,16 +1489,15 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
return success();
}

LogicalResult ModuleImport::convertCallTypeAndOperands(
llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
SmallVectorImpl<Value> &operands, bool allowInlineAsm) {
if (!callInst->getType()->isVoidTy())
types.push_back(convertType(callInst->getType()));

FailureOr<SmallVector<Value>>
ModuleImport::convertCallOperands(llvm::CallBase *callInst,
bool allowInlineAsm) {
bool isInlineAsm = callInst->isInlineAsm();
if (isInlineAsm && !allowInlineAsm)
return failure();

SmallVector<Value> operands;

// Cannot use isIndirectCall() here because we need to handle Constant callees
// that are not considered indirect calls by LLVM. However, in MLIR, they are
// treated as indirect calls to constant operands that need to be converted.
Expand All @@ -1515,8 +1514,29 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
FailureOr<SmallVector<Value>> arguments = convertValues(args);
if (failed(arguments))
return failure();

llvm::append_range(operands, *arguments);
return success();
return operands;
}

LLVMFunctionType ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
llvm::Value *calledOperand = callInst->getCalledOperand();
Type converted = [&] {
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
return convertType(callee->getFunctionType());
return convertType(callInst->getFunctionType());
}();

if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
return funcTy;
return {};
}

FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) {
llvm::Value *calledOperand = callInst->getCalledOperand();
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
return SymbolRefAttr::get(context, callee->getName());
return {};
}

LogicalResult ModuleImport::convertIntrinsic(llvm::CallInst *inst) {
Expand Down Expand Up @@ -1603,75 +1623,45 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
auto callInst = cast<llvm::CallInst>(inst);
llvm::Value *calledOperand = callInst->getCalledOperand();

SmallVector<Type> types;
SmallVector<Value> operands;
if (failed(convertCallTypeAndOperands(callInst, types, operands,
/*allowInlineAsm=*/true)))
FailureOr<SmallVector<Value>> operands =
convertCallOperands(callInst, /*allowInlineAsm=*/true);
if (failed(operands))
return failure();

if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
Type resultTy = convertType(callInst->getType());
if (!resultTy)
return failure();
auto callOp = builder.create<InlineAsmOp>(
loc, resultTy, operands, builder.getStringAttr(asmI->getAsmString()),
builder.getStringAttr(asmI->getConstraintString()),
/*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
/*operand_attrs=*/nullptr);
if (!callInst->getType()->isVoidTy())
mapValue(inst, callOp.getResult(0));
else
mapNoResultOp(inst, callOp);
} else {
auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
// Retrieve the real function type. For direct calls, use the callee's
// function type, as it may differ from the operand type in the case of
// variadic functions. For indirect calls, use the call function type.
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
return convertType(callee->getFunctionType());
return convertType(callInst->getFunctionType());
}());

if (!funcTy)
return failure();
auto callOp = [&]() -> FailureOr<Operation *> {
if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
Type resultTy = convertType(callInst->getType());
if (!resultTy)
return failure();
return builder
.create<InlineAsmOp>(
loc, resultTy, *operands,
builder.getStringAttr(asmI->getAsmString()),
builder.getStringAttr(asmI->getConstraintString()),
/*has_side_effects=*/true,
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
/*operand_attrs=*/nullptr)
.getOperation();
} else {
LLVMFunctionType funcTy = convertFunctionType(callInst);
if (!funcTy)
return failure();

auto callOp = [&]() -> CallOp {
if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
auto name = SymbolRefAttr::get(context, callee->getName());
return builder.create<CallOp>(loc, funcTy, name, operands);
}
return builder.create<CallOp>(loc, funcTy, operands);
}();

// Handle function attributes.
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
callOp.setTailCallKind(
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
setFastmathFlagsAttr(inst, callOp);

callOp.setConvergent(callInst->isConvergent());
callOp.setNoUnwind(callInst->doesNotThrow());
callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));

llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
ModRefInfo othermem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
ModRefInfo argMem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
argMem, inaccessibleMem);
// Only set the attribute when it does not match the default value.
if (!memAttr.isReadWrite())
callOp.setMemoryEffectsAttr(memAttr);

if (!callInst->getType()->isVoidTy())
mapValue(inst, callOp.getResult());
else
mapNoResultOp(inst, callOp);
}
FlatSymbolRefAttr callee = convertCalleeName(callInst);
auto callOp = builder.create<CallOp>(loc, funcTy, callee, *operands);
if (failed(convertCallAttributes(callInst, callOp)))
return failure();
return callOp.getOperation();
}
}();

if (failed(callOp))
return failure();

if (!callInst->getType()->isVoidTy())
mapValue(inst, (*callOp)->getResult(0));
else
mapNoResultOp(inst, *callOp);
return success();
}
if (inst->getOpcode() == llvm::Instruction::LandingPad) {
Expand All @@ -1695,9 +1685,11 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (inst->getOpcode() == llvm::Instruction::Invoke) {
auto *invokeInst = cast<llvm::InvokeInst>(inst);

SmallVector<Type> types;
SmallVector<Value> operands;
if (failed(convertCallTypeAndOperands(invokeInst, types, operands)))
if (invokeInst->isInlineAsm())
return emitError(loc) << "invoke of inline assembly is not supported";

FailureOr<SmallVector<Value>> operands = convertCallOperands(invokeInst);
if (failed(operands))
return failure();

// Check whether the invoke result is an argument to the normal destination
Expand All @@ -1724,27 +1716,22 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
unwindArgs)))
return failure();

auto funcTy =
dyn_cast<LLVMFunctionType>(convertType(invokeInst->getFunctionType()));
auto funcTy = convertFunctionType(invokeInst);
if (!funcTy)
return failure();

FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);

// Create the invoke operation. Normal destination block arguments will be
// added later on to handle the case in which the operation result is
// included in this list.
InvokeOp invokeOp;
if (llvm::Function *callee = invokeInst->getCalledFunction()) {
invokeOp = builder.create<InvokeOp>(
loc, funcTy,
SymbolRefAttr::get(builder.getContext(), callee->getName()), operands,
directNormalDest, ValueRange(),
lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
} else {
invokeOp = builder.create<InvokeOp>(
loc, funcTy, /*callee=*/nullptr, operands, directNormalDest,
ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
}
invokeOp.setCConv(convertCConvFromLLVM(invokeInst->getCallingConv()));
auto invokeOp = builder.create<InvokeOp>(
loc, funcTy, calleeName, *operands, directNormalDest, ValueRange(),
lookupBlock(invokeInst->getUnwindDest()), unwindArgs);

if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
return failure();

if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
else
Expand Down Expand Up @@ -2097,6 +2084,41 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
}

template <typename Op>
static LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) {
op.setCConv(convertCConvFromLLVM(inst->getCallingConv()));
return success();
}

LogicalResult ModuleImport::convertInvokeAttributes(llvm::InvokeInst *inst,
InvokeOp op) {
return convertCallBaseAttributes(inst, op);
}

LogicalResult ModuleImport::convertCallAttributes(llvm::CallInst *inst,
CallOp op) {
setFastmathFlagsAttr(inst, op.getOperation());
op.setTailCallKind(convertTailCallKindFromLLVM(inst->getTailCallKind()));
op.setConvergent(inst->isConvergent());
op.setNoUnwind(inst->doesNotThrow());
op.setWillReturn(inst->hasFnAttr(llvm::Attribute::WillReturn));

llvm::MemoryEffects memEffects = inst->getMemoryEffects();
ModRefInfo othermem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
ModRefInfo argMem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
auto memAttr = MemoryEffectsAttr::get(op.getContext(), othermem, argMem,
inaccessibleMem);
// Only set the attribute when it does not match the default value.
if (!memAttr.isReadWrite())
op.setMemoryEffectsAttr(memAttr);

return convertCallBaseAttributes(inst, op);
}

LogicalResult ModuleImport::processFunction(llvm::Function *func) {
clearRegionState();

Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Target/LLVMIR/Import/instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -702,3 +702,21 @@ define void @fence() {
fence syncscope("") seq_cst
ret void
}

; // -----

; CHECK-LABEL: @f
define void @f() personality ptr @__gxx_personality_v0 {
entry:
; CHECK: llvm.invoke @g() to ^bb1 unwind ^bb2 vararg(!llvm.func<void (...)>) : () -> ()
invoke void @g() to label %bb1 unwind label %bb2
bb1:
ret void
bb2:
%0 = landingpad i32 cleanup
unreachable
}

declare void @g(...)

declare i32 @__gxx_personality_v0(...)