Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,11 @@ class ModuleImport {
/// Converts the callee's function type. For direct calls, it converts the
/// actual function type, which may differ from the called operand type in
/// variadic functions. For indirect calls, it converts the function type
/// associated with the call instruction. Returns failure when the call and
/// the callee are not compatible or when nested type conversions failed.
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
/// associated with the call instruction. When the call and the callee are not
/// compatible (or when nested type conversions failed), emit a warning and
/// update `isIncompatibleCall` to indicate it.
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst,
bool &isIncompatibleCall);
/// Returns the callee name, or an empty symbol if the call is not direct.
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
/// Converts the parameter and result attributes attached to `func` and adds
Expand Down
74 changes: 57 additions & 17 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1721,8 +1721,8 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
/// Checks if `callType` and `calleeType` are compatible and can be represented
/// in MLIR.
static LogicalResult
verifyFunctionTypeCompatibility(LLVMFunctionType callType,
LLVMFunctionType calleeType) {
checkFunctionTypeCompatibility(LLVMFunctionType callType,
LLVMFunctionType calleeType) {
if (callType.getReturnType() != calleeType.getReturnType())
return failure();

Expand All @@ -1748,7 +1748,9 @@ verifyFunctionTypeCompatibility(LLVMFunctionType callType,
}

FailureOr<LLVMFunctionType>
ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
ModuleImport::convertFunctionType(llvm::CallBase *callInst,
bool &isIncompatibleCall) {
isIncompatibleCall = false;
auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
if (!funcTy)
Expand All @@ -1771,11 +1773,14 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
if (failed(calleeType))
return failure();

// Compare the types to avoid constructing illegal call/invoke operations.
if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) {
// Compare the types and notify users via `isIncompatibleCall` if they are not
// compatible.
if (failed(checkFunctionTypeCompatibility(*callType, *calleeType))) {
isIncompatibleCall = true;
Location loc = translateLoc(callInst->getDebugLoc());
return emitError(loc) << "incompatible call and callee types: " << *callType
<< " and " << *calleeType;
emitWarning(loc) << "incompatible call and callee types: " << *callType
<< " and " << *calleeType;
return callType;
}

return calleeType;
Expand Down Expand Up @@ -1892,16 +1897,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
/*operand_attrs=*/nullptr)
.getOperation();
}
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst);
bool isIncompatibleCall;
FailureOr<LLVMFunctionType> funcTy =
convertFunctionType(callInst, isIncompatibleCall);
if (failed(funcTy))
return failure();

FlatSymbolRefAttr callee = convertCalleeName(callInst);
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
FlatSymbolRefAttr callee = nullptr;
if (isIncompatibleCall) {
// Use an indirect call (in order to represent valid and verifiable LLVM
// IR). Build the indirect call by passing an empty `callee` operand and
// insert into `operands` to include the indirect call target.
FlatSymbolRefAttr calleeSym = convertCalleeName(callInst);
Value indirectCallVal = builder.create<LLVM::AddressOfOp>(
translateLoc(callInst->getDebugLoc()),
LLVM::LLVMPointerType::get(context), calleeSym);
operands->insert(operands->begin(), indirectCallVal);
} else {
// Regular direct call using callee name.
callee = convertCalleeName(callInst);
}
CallOp callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);

if (failed(convertCallAttributes(callInst, callOp)))
return failure();
// Handle parameter and result attributes.
convertParameterAttributes(callInst, callOp, builder);

// Handle parameter and result attributes unless it's an incompatible
// call.
if (!isIncompatibleCall)
convertParameterAttributes(callInst, callOp, builder);
return callOp.getOperation();
}();

Expand Down Expand Up @@ -1966,12 +1990,26 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
unwindArgs)))
return failure();

FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst);
bool isIncompatibleInvoke;
FailureOr<LLVMFunctionType> funcTy =
convertFunctionType(invokeInst, isIncompatibleInvoke);
if (failed(funcTy))
return failure();

FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);

FlatSymbolRefAttr calleeName = nullptr;
if (isIncompatibleInvoke) {
// Use an indirect invoke (in order to represent valid and verifiable LLVM
// IR). Build the indirect invoke by passing an empty `callee` operand and
// insert into `operands` to include the indirect invoke target.
FlatSymbolRefAttr calleeSym = convertCalleeName(invokeInst);
Value indirectInvokeVal = builder.create<LLVM::AddressOfOp>(
translateLoc(invokeInst->getDebugLoc()),
LLVM::LLVMPointerType::get(context), calleeSym);
operands->insert(operands->begin(), indirectInvokeVal);
} else {
// Regular direct invoke using callee name.
calleeName = convertCalleeName(invokeInst);
}
// Create the invoke operation. Normal destination block arguments will be
// added later on to handle the case in which the operation result is
// included in this list.
Expand All @@ -1982,8 +2020,10 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
return failure();

// Handle parameter and result attributes.
convertParameterAttributes(invokeInst, invokeOp, builder);
// Handle parameter and result attributes unless it's an incompatible
// invoke.
if (!isIncompatibleInvoke)
convertParameterAttributes(invokeInst, invokeOp, builder);

if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Target/LLVMIR/Import/import-failure.ll
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ target datalayout = "e-m-i64:64"
; // -----

; CHECK: <unknown>
; CHECK-SAME: incompatible call and callee types: '!llvm.func<void (i64)>' and '!llvm.func<void (ptr)>'
; CHECK-SAME: warning: incompatible call and callee types: '!llvm.func<void (i64)>' and '!llvm.func<void (ptr)>'
define void @incompatible_call_and_callee_types() {
call void @callee(i64 0)
ret void
Expand All @@ -324,7 +324,7 @@ declare void @callee(ptr)
; // -----

; CHECK: <unknown>
; CHECK-SAME: incompatible call and callee types: '!llvm.func<void ()>' and '!llvm.func<i32 ()>'
; CHECK-SAME: warning: incompatible call and callee types: '!llvm.func<void ()>' and '!llvm.func<i32 ()>'
define void @f() personality ptr @__gxx_personality_v0 {
entry:
invoke void @g() to label %bb1 unwind label %bb2
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Target/LLVMIR/Import/instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -743,3 +743,19 @@ bb2:
declare void @g(...)

declare i32 @__gxx_personality_v0(...)

; // -----

; CHECK-LABEL: llvm.func @incompatible_call_and_callee_types
define void @incompatible_call_and_callee_types() {
; CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i64) : i64
; CHECK: %[[TARGET:.*]] = llvm.mlir.addressof @callee : !llvm.ptr
; CHECK: llvm.call %[[TARGET]](%[[CST]]) : !llvm.ptr, (i64) -> ()
call void @callee(i64 0)
; CHECK: llvm.return
ret void
}

define void @callee({ptr, i64}, i32) {
ret void
}