Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,9 @@ class ModuleImport {
/// Converts the callee's function type. For direct calls, it converts the
/// actual function type, which may differ from the called operand type in
/// variadic functions. For indirect calls, it converts the function type
/// associated with the call instruction.
LLVMFunctionType convertFunctionType(llvm::CallBase *callInst);
/// associated with the call instruction. Returns failure when the call and
/// the callee are not compatible or when nested type conversions failed.
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
/// Returns the callee name, or an empty symbol if the call is not direct.
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
/// Converts the parameter attributes attached to `func` and adds them to
Expand Down
99 changes: 74 additions & 25 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1519,22 +1519,72 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
return operands;
}

LLVMFunctionType ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
llvm::Value *calledOperand = callInst->getCalledOperand();
Type converted = [&] {
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
return convertType(callee->getFunctionType());
return convertType(callInst->getFunctionType());
}();
/// Checks if `callType` and `calleeType` are compatible and can be represented
/// in MLIR.
static LogicalResult
verifyFunctionTypeCompatibility(LLVMFunctionType callType,
LLVMFunctionType calleeType) {
if (callType.getReturnType() != calleeType.getReturnType())
return failure();

if (calleeType.isVarArg()) {
// For variadic functions, the call can have more types than the callee
// specifies.
if (callType.getNumParams() < calleeType.getNumParams())
return failure();
} else {
// For non-variadic functions, the number of parameters needs to be the
// same.
if (callType.getNumParams() != calleeType.getNumParams())
return failure();
}

// Check that all operands match.
for (auto [operandType, argumentType] :
llvm::zip(callType.getParams(), calleeType.getParams()))
if (operandType != argumentType)
return failure();

return success();
}

if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
FailureOr<LLVMFunctionType>
ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
if (!funcTy)
return failure();
return funcTy;
return {};
};

llvm::Value *calledOperand = callInst->getCalledOperand();
FailureOr<LLVMFunctionType> callType =
castOrFailure(convertType(callInst->getFunctionType()));
if (failed(callType))
return failure();
auto *callee = dyn_cast<llvm::Function>(calledOperand);
// For indirect calls, return the type of the call itself.
if (!callee)
return callType;

FailureOr<LLVMFunctionType> calleeType =
castOrFailure(convertType(callee->getFunctionType()));
if (failed(calleeType))
return failure();

// Compare the types to avoid constructing illegal call/invoke operations.
if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) {
Location loc = translateLoc(callInst->getDebugLoc());
return emitError(loc) << "incompatible call and callee types: " << *callType
<< " and " << *calleeType;
}

return calleeType;
}

FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) {
llvm::Value *calledOperand = callInst->getCalledOperand();
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
if (auto *callee = dyn_cast<llvm::Function>(calledOperand))
return SymbolRefAttr::get(context, callee->getName());
return {};
}
Expand Down Expand Up @@ -1620,7 +1670,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
return success();
}
if (inst->getOpcode() == llvm::Instruction::Call) {
auto callInst = cast<llvm::CallInst>(inst);
auto *callInst = cast<llvm::CallInst>(inst);
llvm::Value *calledOperand = callInst->getCalledOperand();

FailureOr<SmallVector<Value>> operands =
Expand All @@ -1629,7 +1679,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
return failure();

auto callOp = [&]() -> FailureOr<Operation *> {
if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
if (auto *asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
Type resultTy = convertType(callInst->getType());
if (!resultTy)
return failure();
Expand All @@ -1642,17 +1692,16 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
/*operand_attrs=*/nullptr)
.getOperation();
} else {
LLVMFunctionType funcTy = convertFunctionType(callInst);
if (!funcTy)
return failure();

FlatSymbolRefAttr callee = convertCalleeName(callInst);
auto callOp = builder.create<CallOp>(loc, funcTy, callee, *operands);
if (failed(convertCallAttributes(callInst, callOp)))
return failure();
return callOp.getOperation();
}
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst);
if (failed(funcTy))
return failure();

FlatSymbolRefAttr callee = convertCalleeName(callInst);
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
if (failed(convertCallAttributes(callInst, callOp)))
return failure();
return callOp.getOperation();
}();

if (failed(callOp))
Expand Down Expand Up @@ -1716,8 +1765,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
unwindArgs)))
return failure();

auto funcTy = convertFunctionType(invokeInst);
if (!funcTy)
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst);
if (failed(funcTy))
return failure();

FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
Expand All @@ -1726,7 +1775,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
// added later on to handle the case in which the operation result is
// included in this list.
auto invokeOp = builder.create<InvokeOp>(
loc, funcTy, calleeName, *operands, directNormalDest, ValueRange(),
loc, *funcTy, calleeName, *operands, directNormalDest, ValueRange(),
lookupBlock(invokeInst->getUnwindDest()), unwindArgs);

if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
Expand Down
31 changes: 30 additions & 1 deletion mlir/test/Target/LLVMIR/Import/import-failure.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 | FileCheck %s
; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 -o /dev/null | FileCheck %s

; CHECK: <unknown>
; CHECK-SAME: error: unhandled instruction: indirectbr ptr %dst, [label %bb1, label %bb2]
Expand Down Expand Up @@ -353,3 +353,32 @@ declare void @llvm.experimental.noalias.scope.decl(metadata)
; CHECK: import-failure.ll
; CHECK-SAME: warning: unhandled data layout token: ni:42
target datalayout = "e-ni:42-i64:64"

; // -----

; CHECK: <unknown>
; CHECK-SAME: incompatible call and callee types: '!llvm.func<void (i64)>' and '!llvm.func<void (ptr)>'
define void @incompatible_call_and_callee_types() {
call void @callee(i64 0)
ret void
}

declare void @callee(ptr)

; // -----

; CHECK: <unknown>
; CHECK-SAME: incompatible call and callee types: '!llvm.func<void ()>' and '!llvm.func<i32 ()>'
define void @f() personality ptr @__gxx_personality_v0 {
entry:
invoke void @g() to label %bb1 unwind label %bb2
bb1:
ret void
bb2:
%0 = landingpad i32 cleanup
unreachable
}

declare i32 @g()

declare i32 @__gxx_personality_v0(...)