diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h index 3c1221e20afbc..84aecbd4373e0 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h @@ -326,8 +326,9 @@ class ModuleImport { /// Converts the callee's function type. For direct calls, it converts the /// actual function type, which may differ from the called operand type in /// variadic functions. For indirect calls, it converts the function type - /// associated with the call instruction. - LLVMFunctionType convertFunctionType(llvm::CallBase *callInst); + /// associated with the call instruction. Returns failure when the call and + /// the callee are not compatible or when nested type conversions failed. + FailureOr convertFunctionType(llvm::CallBase *callInst); /// Returns the callee name, or an empty symbol if the call is not direct. FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst); /// Converts the parameter attributes attached to `func` and adds them to diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 1d1a985c46fb5..e23ffdedd9a60 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -1519,22 +1519,72 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst, return operands; } -LLVMFunctionType ModuleImport::convertFunctionType(llvm::CallBase *callInst) { - llvm::Value *calledOperand = callInst->getCalledOperand(); - Type converted = [&] { - if (auto callee = dyn_cast(calledOperand)) - return convertType(callee->getFunctionType()); - return convertType(callInst->getFunctionType()); - }(); +/// Checks if `callType` and `calleeType` are compatible and can be represented +/// in MLIR. +static LogicalResult +verifyFunctionTypeCompatibility(LLVMFunctionType callType, + LLVMFunctionType calleeType) { + if (callType.getReturnType() != calleeType.getReturnType()) + return failure(); + + if (calleeType.isVarArg()) { + // For variadic functions, the call can have more types than the callee + // specifies. + if (callType.getNumParams() < calleeType.getNumParams()) + return failure(); + } else { + // For non-variadic functions, the number of parameters needs to be the + // same. + if (callType.getNumParams() != calleeType.getNumParams()) + return failure(); + } + + // Check that all operands match. + for (auto [operandType, argumentType] : + llvm::zip(callType.getParams(), calleeType.getParams())) + if (operandType != argumentType) + return failure(); + + return success(); +} - if (auto funcTy = dyn_cast_or_null(converted)) +FailureOr +ModuleImport::convertFunctionType(llvm::CallBase *callInst) { + auto castOrFailure = [](Type convertedType) -> FailureOr { + auto funcTy = dyn_cast_or_null(convertedType); + if (!funcTy) + return failure(); return funcTy; - return {}; + }; + + llvm::Value *calledOperand = callInst->getCalledOperand(); + FailureOr callType = + castOrFailure(convertType(callInst->getFunctionType())); + if (failed(callType)) + return failure(); + auto *callee = dyn_cast(calledOperand); + // For indirect calls, return the type of the call itself. + if (!callee) + return callType; + + FailureOr calleeType = + castOrFailure(convertType(callee->getFunctionType())); + if (failed(calleeType)) + return failure(); + + // Compare the types to avoid constructing illegal call/invoke operations. + if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) { + Location loc = translateLoc(callInst->getDebugLoc()); + return emitError(loc) << "incompatible call and callee types: " << *callType + << " and " << *calleeType; + } + + return calleeType; } FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) { llvm::Value *calledOperand = callInst->getCalledOperand(); - if (auto callee = dyn_cast(calledOperand)) + if (auto *callee = dyn_cast(calledOperand)) return SymbolRefAttr::get(context, callee->getName()); return {}; } @@ -1620,7 +1670,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { return success(); } if (inst->getOpcode() == llvm::Instruction::Call) { - auto callInst = cast(inst); + auto *callInst = cast(inst); llvm::Value *calledOperand = callInst->getCalledOperand(); FailureOr> operands = @@ -1629,7 +1679,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { return failure(); auto callOp = [&]() -> FailureOr { - if (auto asmI = dyn_cast(calledOperand)) { + if (auto *asmI = dyn_cast(calledOperand)) { Type resultTy = convertType(callInst->getType()); if (!resultTy) return failure(); @@ -1642,17 +1692,16 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { /*is_align_stack=*/false, /*asm_dialect=*/nullptr, /*operand_attrs=*/nullptr) .getOperation(); - } else { - LLVMFunctionType funcTy = convertFunctionType(callInst); - if (!funcTy) - return failure(); - - FlatSymbolRefAttr callee = convertCalleeName(callInst); - auto callOp = builder.create(loc, funcTy, callee, *operands); - if (failed(convertCallAttributes(callInst, callOp))) - return failure(); - return callOp.getOperation(); } + FailureOr funcTy = convertFunctionType(callInst); + if (failed(funcTy)) + return failure(); + + FlatSymbolRefAttr callee = convertCalleeName(callInst); + auto callOp = builder.create(loc, *funcTy, callee, *operands); + if (failed(convertCallAttributes(callInst, callOp))) + return failure(); + return callOp.getOperation(); }(); if (failed(callOp)) @@ -1716,8 +1765,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { unwindArgs))) return failure(); - auto funcTy = convertFunctionType(invokeInst); - if (!funcTy) + FailureOr funcTy = convertFunctionType(invokeInst); + if (failed(funcTy)) return failure(); FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst); @@ -1726,7 +1775,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // added later on to handle the case in which the operation result is // included in this list. auto invokeOp = builder.create( - loc, funcTy, calleeName, *operands, directNormalDest, ValueRange(), + loc, *funcTy, calleeName, *operands, directNormalDest, ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs); if (failed(convertInvokeAttributes(invokeInst, invokeOp))) diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll index b616cb81e0a8a..d929a59284762 100644 --- a/mlir/test/Target/LLVMIR/Import/import-failure.ll +++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll @@ -1,4 +1,4 @@ -; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 | FileCheck %s +; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 -o /dev/null | FileCheck %s ; CHECK: ; CHECK-SAME: error: unhandled instruction: indirectbr ptr %dst, [label %bb1, label %bb2] @@ -353,3 +353,32 @@ declare void @llvm.experimental.noalias.scope.decl(metadata) ; CHECK: import-failure.ll ; CHECK-SAME: warning: unhandled data layout token: ni:42 target datalayout = "e-ni:42-i64:64" + +; // ----- + +; CHECK: +; CHECK-SAME: incompatible call and callee types: '!llvm.func' and '!llvm.func' +define void @incompatible_call_and_callee_types() { + call void @callee(i64 0) + ret void +} + +declare void @callee(ptr) + +; // ----- + +; CHECK: +; CHECK-SAME: incompatible call and callee types: '!llvm.func' and '!llvm.func' +define void @f() personality ptr @__gxx_personality_v0 { +entry: + invoke void @g() to label %bb1 unwind label %bb2 +bb1: + ret void +bb2: + %0 = landingpad i32 cleanup + unreachable +} + +declare i32 @g() + +declare i32 @__gxx_personality_v0(...)