Skip to content

Commit 8cdc16d

Browse files
authored
[MLIR][LLVM] Avoid importing broken calls and invokes (#125041)
This commit adds a check to catch calls/invokes that do not satisfy the type constraints of their callee. This is not verified in LLVM IR but is considered UB. Importing this into MLIR will lead to verification errors, thus we should avoid this early on.
1 parent 4b57236 commit 8cdc16d

File tree

3 files changed

+107
-28
lines changed

3 files changed

+107
-28
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,9 @@ class ModuleImport {
326326
/// Converts the callee's function type. For direct calls, it converts the
327327
/// actual function type, which may differ from the called operand type in
328328
/// variadic functions. For indirect calls, it converts the function type
329-
/// associated with the call instruction.
330-
LLVMFunctionType convertFunctionType(llvm::CallBase *callInst);
329+
/// associated with the call instruction. Returns failure when the call and
330+
/// the callee are not compatible or when nested type conversions failed.
331+
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
331332
/// Returns the callee name, or an empty symbol if the call is not direct.
332333
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
333334
/// Converts the parameter attributes attached to `func` and adds them to

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,22 +1519,72 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
15191519
return operands;
15201520
}
15211521

1522-
LLVMFunctionType ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
1523-
llvm::Value *calledOperand = callInst->getCalledOperand();
1524-
Type converted = [&] {
1525-
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1526-
return convertType(callee->getFunctionType());
1527-
return convertType(callInst->getFunctionType());
1528-
}();
1522+
/// Checks if `callType` and `calleeType` are compatible and can be represented
1523+
/// in MLIR.
1524+
static LogicalResult
1525+
verifyFunctionTypeCompatibility(LLVMFunctionType callType,
1526+
LLVMFunctionType calleeType) {
1527+
if (callType.getReturnType() != calleeType.getReturnType())
1528+
return failure();
1529+
1530+
if (calleeType.isVarArg()) {
1531+
// For variadic functions, the call can have more types than the callee
1532+
// specifies.
1533+
if (callType.getNumParams() < calleeType.getNumParams())
1534+
return failure();
1535+
} else {
1536+
// For non-variadic functions, the number of parameters needs to be the
1537+
// same.
1538+
if (callType.getNumParams() != calleeType.getNumParams())
1539+
return failure();
1540+
}
1541+
1542+
// Check that all operands match.
1543+
for (auto [operandType, argumentType] :
1544+
llvm::zip(callType.getParams(), calleeType.getParams()))
1545+
if (operandType != argumentType)
1546+
return failure();
1547+
1548+
return success();
1549+
}
15291550

1530-
if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
1551+
FailureOr<LLVMFunctionType>
1552+
ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
1553+
auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
1554+
auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
1555+
if (!funcTy)
1556+
return failure();
15311557
return funcTy;
1532-
return {};
1558+
};
1559+
1560+
llvm::Value *calledOperand = callInst->getCalledOperand();
1561+
FailureOr<LLVMFunctionType> callType =
1562+
castOrFailure(convertType(callInst->getFunctionType()));
1563+
if (failed(callType))
1564+
return failure();
1565+
auto *callee = dyn_cast<llvm::Function>(calledOperand);
1566+
// For indirect calls, return the type of the call itself.
1567+
if (!callee)
1568+
return callType;
1569+
1570+
FailureOr<LLVMFunctionType> calleeType =
1571+
castOrFailure(convertType(callee->getFunctionType()));
1572+
if (failed(calleeType))
1573+
return failure();
1574+
1575+
// Compare the types to avoid constructing illegal call/invoke operations.
1576+
if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) {
1577+
Location loc = translateLoc(callInst->getDebugLoc());
1578+
return emitError(loc) << "incompatible call and callee types: " << *callType
1579+
<< " and " << *calleeType;
1580+
}
1581+
1582+
return calleeType;
15331583
}
15341584

15351585
FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) {
15361586
llvm::Value *calledOperand = callInst->getCalledOperand();
1537-
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1587+
if (auto *callee = dyn_cast<llvm::Function>(calledOperand))
15381588
return SymbolRefAttr::get(context, callee->getName());
15391589
return {};
15401590
}
@@ -1620,7 +1670,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16201670
return success();
16211671
}
16221672
if (inst->getOpcode() == llvm::Instruction::Call) {
1623-
auto callInst = cast<llvm::CallInst>(inst);
1673+
auto *callInst = cast<llvm::CallInst>(inst);
16241674
llvm::Value *calledOperand = callInst->getCalledOperand();
16251675

16261676
FailureOr<SmallVector<Value>> operands =
@@ -1629,7 +1679,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16291679
return failure();
16301680

16311681
auto callOp = [&]() -> FailureOr<Operation *> {
1632-
if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1682+
if (auto *asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
16331683
Type resultTy = convertType(callInst->getType());
16341684
if (!resultTy)
16351685
return failure();
@@ -1642,17 +1692,16 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16421692
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
16431693
/*operand_attrs=*/nullptr)
16441694
.getOperation();
1645-
} else {
1646-
LLVMFunctionType funcTy = convertFunctionType(callInst);
1647-
if (!funcTy)
1648-
return failure();
1649-
1650-
FlatSymbolRefAttr callee = convertCalleeName(callInst);
1651-
auto callOp = builder.create<CallOp>(loc, funcTy, callee, *operands);
1652-
if (failed(convertCallAttributes(callInst, callOp)))
1653-
return failure();
1654-
return callOp.getOperation();
16551695
}
1696+
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst);
1697+
if (failed(funcTy))
1698+
return failure();
1699+
1700+
FlatSymbolRefAttr callee = convertCalleeName(callInst);
1701+
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
1702+
if (failed(convertCallAttributes(callInst, callOp)))
1703+
return failure();
1704+
return callOp.getOperation();
16561705
}();
16571706

16581707
if (failed(callOp))
@@ -1716,8 +1765,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
17161765
unwindArgs)))
17171766
return failure();
17181767

1719-
auto funcTy = convertFunctionType(invokeInst);
1720-
if (!funcTy)
1768+
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst);
1769+
if (failed(funcTy))
17211770
return failure();
17221771

17231772
FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
@@ -1726,7 +1775,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
17261775
// added later on to handle the case in which the operation result is
17271776
// included in this list.
17281777
auto invokeOp = builder.create<InvokeOp>(
1729-
loc, funcTy, calleeName, *operands, directNormalDest, ValueRange(),
1778+
loc, *funcTy, calleeName, *operands, directNormalDest, ValueRange(),
17301779
lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
17311780

17321781
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))

mlir/test/Target/LLVMIR/Import/import-failure.ll

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 | FileCheck %s
1+
; RUN: not mlir-translate -import-llvm -emit-expensive-warnings -split-input-file %s 2>&1 -o /dev/null | FileCheck %s
22

33
; CHECK: <unknown>
44
; CHECK-SAME: error: unhandled instruction: indirectbr ptr %dst, [label %bb1, label %bb2]
@@ -353,3 +353,32 @@ declare void @llvm.experimental.noalias.scope.decl(metadata)
353353
; CHECK: import-failure.ll
354354
; CHECK-SAME: warning: unhandled data layout token: ni:42
355355
target datalayout = "e-ni:42-i64:64"
356+
357+
; // -----
358+
359+
; CHECK: <unknown>
360+
; CHECK-SAME: incompatible call and callee types: '!llvm.func<void (i64)>' and '!llvm.func<void (ptr)>'
361+
define void @incompatible_call_and_callee_types() {
362+
call void @callee(i64 0)
363+
ret void
364+
}
365+
366+
declare void @callee(ptr)
367+
368+
; // -----
369+
370+
; CHECK: <unknown>
371+
; CHECK-SAME: incompatible call and callee types: '!llvm.func<void ()>' and '!llvm.func<i32 ()>'
372+
define void @f() personality ptr @__gxx_personality_v0 {
373+
entry:
374+
invoke void @g() to label %bb1 unwind label %bb2
375+
bb1:
376+
ret void
377+
bb2:
378+
%0 = landingpad i32 cleanup
379+
unreachable
380+
}
381+
382+
declare i32 @g()
383+
384+
declare i32 @__gxx_personality_v0(...)

0 commit comments

Comments
 (0)