Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,12 @@ class ModuleImport {
/// Converts the callee's function type. For direct calls, it converts the
/// actual function type, which may differ from the called operand type in
/// variadic functions. For indirect calls, it converts the function type
/// associated with the call instruction. Returns failure when the call and
/// the callee are not compatible or when nested type conversions failed.
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
/// associated with the call instruction. When the call and the callee are not
/// compatible (or when nested type conversions failed), emit a warning but
/// attempt translation using a bitcast and an indirect call (in order
/// represent valid and verified LLVM IR).
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst,
Value &castResult);
/// Returns the callee name, or an empty symbol if the call is not direct.
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
/// Converts the parameter and result attributes attached to `func` and adds
Expand Down
62 changes: 46 additions & 16 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1721,8 +1721,8 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
/// Checks if `callType` and `calleeType` are compatible and can be represented
/// in MLIR.
static LogicalResult
verifyFunctionTypeCompatibility(LLVMFunctionType callType,
LLVMFunctionType calleeType) {
checkFunctionTypeCompatibility(LLVMFunctionType callType,
LLVMFunctionType calleeType) {
if (callType.getReturnType() != calleeType.getReturnType())
return failure();

Expand All @@ -1748,7 +1748,7 @@ verifyFunctionTypeCompatibility(LLVMFunctionType callType,
}

FailureOr<LLVMFunctionType>
ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
ModuleImport::convertFunctionType(llvm::CallBase *callInst, Value &castResult) {
auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
if (!funcTy)
Expand All @@ -1771,11 +1771,17 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
if (failed(calleeType))
return failure();

// Compare the types to avoid constructing illegal call/invoke operations.
if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) {
// Compare the types, if they are not compatible, avoid illegal call/invoke
// operations by casting to the callsite type and issuing an indirect call.
// LLVM IR currently supports this usage.
if (failed(checkFunctionTypeCompatibility(*callType, *calleeType))) {
Location loc = translateLoc(callInst->getDebugLoc());
return emitError(loc) << "incompatible call and callee types: " << *callType
<< " and " << *calleeType;
FlatSymbolRefAttr calleeSym = convertCalleeName(callInst);
castResult = builder.create<LLVM::AddressOfOp>(
loc, LLVM::LLVMPointerType::get(context), calleeSym);
emitWarning(loc) << "incompatible call and callee types: " << *callType
<< " and " << *calleeType;
return callType;
}

return calleeType;
Expand Down Expand Up @@ -1892,16 +1898,29 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
/*operand_attrs=*/nullptr)
.getOperation();
}
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst);
Value castResult;
FailureOr<LLVMFunctionType> funcTy =
convertFunctionType(callInst, castResult);
if (failed(funcTy))
return failure();

FlatSymbolRefAttr callee = convertCalleeName(callInst);
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
FlatSymbolRefAttr callee = nullptr;
// If no cast is needed, use the original callee name. Otherwise patch
// operands to include the indirect call target. Build indirect call by
// passing using a nullptr `callee`.
if (!castResult)
callee = convertCalleeName(callInst);
else
operands->insert(operands->begin(), castResult);
CallOp callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);

if (failed(convertCallAttributes(callInst, callOp)))
return failure();
// Handle parameter and result attributes.
convertParameterAttributes(callInst, callOp, builder);

// Handle parameter and result attributes. Don't bother if there's a
// type mismatch.
if (!castResult)
convertParameterAttributes(callInst, callOp, builder);
return callOp.getOperation();
}();

Expand Down Expand Up @@ -1966,11 +1985,20 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
unwindArgs)))
return failure();

FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst);
Value castResult;
FailureOr<LLVMFunctionType> funcTy =
convertFunctionType(invokeInst, castResult);
if (failed(funcTy))
return failure();

FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
FlatSymbolRefAttr calleeName = nullptr;
// If no cast is needed, use the original callee name. Otherwise patch
// operands to include the indirect call target. Build indirect call by
// passing using a nullptr `callee`.
if (!castResult)
calleeName = convertCalleeName(invokeInst);
else
operands->insert(operands->begin(), castResult);

// Create the invoke operation. Normal destination block arguments will be
// added later on to handle the case in which the operation result is
Expand All @@ -1982,8 +2010,10 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
return failure();

// Handle parameter and result attributes.
convertParameterAttributes(invokeInst, invokeOp, builder);
// Handle parameter and result attributes. Don't bother if there's a
// type mismatch.
if (!castResult)
convertParameterAttributes(invokeInst, invokeOp, builder);

if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Target/LLVMIR/Import/instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -739,3 +739,19 @@ bb2:
declare void @g(...)

declare i32 @__gxx_personality_v0(...)

; // -----

; CHECK-LABEL: llvm.func @incompatible_call_and_callee_types
define void @incompatible_call_and_callee_types() {
; CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i64) : i64
; CHECK: %[[TARGET:.*]] = llvm.mlir.addressof @callee : !llvm.ptr
; CHECK: llvm.call %[[TARGET]](%[[CST]]) : !llvm.ptr, (i64) -> ()
call void @callee(i64 0)
; CHECK: llvm.return
ret void
}

define void @callee({ptr, i64}, i32) {
ret void
}