Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,13 @@ class ModuleImport {
/// Converts the callee's function type. For direct calls, it converts the
/// actual function type, which may differ from the called operand type in
/// variadic functions. For indirect calls, it converts the function type
/// associated with the call instruction. Returns failure when the call and
/// the callee are not compatible or when nested type conversions failed.
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
/// associated with the call instruction. When the call and the callee are not
/// compatible (or when nested type conversions failed), emit a warning but
/// attempt translation using an indirect call (in order to represent valid
/// and verified LLVM IR). The `indirectCallVal` is updated to hold the
/// address for the indirect call.
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst,
Value &indirectCallVal);
/// Returns the callee name, or an empty symbol if the call is not direct.
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
/// Converts the parameter and result attributes attached to `func` and adds
Expand Down
62 changes: 46 additions & 16 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1721,8 +1721,8 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
/// Checks if `callType` and `calleeType` are compatible and can be represented
/// in MLIR.
static LogicalResult
verifyFunctionTypeCompatibility(LLVMFunctionType callType,
LLVMFunctionType calleeType) {
checkFunctionTypeCompatibility(LLVMFunctionType callType,
LLVMFunctionType calleeType) {
if (callType.getReturnType() != calleeType.getReturnType())
return failure();

Expand All @@ -1748,7 +1748,9 @@ verifyFunctionTypeCompatibility(LLVMFunctionType callType,
}

FailureOr<LLVMFunctionType>
ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
ModuleImport::convertFunctionType(llvm::CallBase *callInst,
Value &indirectCallVal) {
indirectCallVal = nullptr;
auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
if (!funcTy)
Expand All @@ -1771,11 +1773,17 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
if (failed(calleeType))
return failure();

// Compare the types to avoid constructing illegal call/invoke operations.
if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) {
// Compare the types, if they are not compatible, avoid illegal call/invoke
// operations by issuing an indirect call. Note that LLVM IR currently
// supports this usage.
if (failed(checkFunctionTypeCompatibility(*callType, *calleeType))) {
Location loc = translateLoc(callInst->getDebugLoc());
return emitError(loc) << "incompatible call and callee types: " << *callType
<< " and " << *calleeType;
FlatSymbolRefAttr calleeSym = convertCalleeName(callInst);
indirectCallVal = builder.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(context), calleeSym);
emitWarning(loc) << "incompatible call and callee types: " << *callType
<< " and " << *calleeType;
return callType;
}

return calleeType;
Expand Down Expand Up @@ -1892,16 +1900,28 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
/*operand_attrs=*/nullptr)
.getOperation();
}
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst);
Value indirectCallVal;
FailureOr<LLVMFunctionType> funcTy =
convertFunctionType(callInst, indirectCallVal);
if (failed(funcTy))
return failure();

FlatSymbolRefAttr callee = convertCalleeName(callInst);
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
FlatSymbolRefAttr callee = nullptr;
// If `indirectCallVal` is available emit an indirect call, otherwise use
// the callee name. Build an indirect call by passing an empty `callee`
// operand and insert into `operands` to include the indirect call target.
if (indirectCallVal)
operands->insert(operands->begin(), indirectCallVal);
else
callee = convertCalleeName(callInst);
CallOp callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);

if (failed(convertCallAttributes(callInst, callOp)))
return failure();
// Handle parameter and result attributes.
convertParameterAttributes(callInst, callOp, builder);

// Handle parameter and result attributes unless it's an indirect call.
if (!indirectCallVal)
convertParameterAttributes(callInst, callOp, builder);
return callOp.getOperation();
}();

Expand Down Expand Up @@ -1966,11 +1986,20 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
unwindArgs)))
return failure();

FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst);
Value indirectCallVal;
FailureOr<LLVMFunctionType> funcTy =
convertFunctionType(invokeInst, indirectCallVal);
if (failed(funcTy))
return failure();

FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
FlatSymbolRefAttr calleeName = nullptr;
// If `indirectCallVal` is available emit an indirect call, otherwise use
// the callee name. Build an indirect call by passing an empty `callee`
// operand and insert into `operands` to include the indirect call target.
if (!indirectCallVal)
calleeName = convertCalleeName(invokeInst);
else
operands->insert(operands->begin(), indirectCallVal);

// Create the invoke operation. Normal destination block arguments will be
// added later on to handle the case in which the operation result is
Expand All @@ -1982,8 +2011,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
return failure();

// Handle parameter and result attributes.
convertParameterAttributes(invokeInst, invokeOp, builder);
// Handle parameter and result attributes unless it's an indirect call.
if (!indirectCallVal)
convertParameterAttributes(invokeInst, invokeOp, builder);

if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Target/LLVMIR/Import/instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -739,3 +739,19 @@ bb2:
declare void @g(...)

declare i32 @__gxx_personality_v0(...)

; // -----

; CHECK-LABEL: llvm.func @incompatible_call_and_callee_types
define void @incompatible_call_and_callee_types() {
; CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i64) : i64
; CHECK: %[[TARGET:.*]] = llvm.mlir.addressof @callee : !llvm.ptr
; CHECK: llvm.call %[[TARGET]](%[[CST]]) : !llvm.ptr, (i64) -> ()
call void @callee(i64 0)
; CHECK: llvm.return
ret void
}

define void @callee({ptr, i64}, i32) {
ret void
}