Skip to content

Commit b5f2167

Browse files
authored
MLIR: Enable importing inlineasm calls (#121624)
1 parent f48884d commit b5f2167

File tree

4 files changed

+79
-56
lines changed

4 files changed

+79
-56
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,9 +319,13 @@ class ModuleImport {
319319
/// Appends the converted result type and operands of `callInst` to the
320320
/// `types` and `operands` arrays. For indirect calls, the method additionally
321321
/// inserts the called function at the beginning of the `operands` array.
322+
/// If `allowInlineAsm` is set to false (the default), it will return failure
323+
/// if the called operand is an inline asm which isn't convertible to MLIR as
324+
/// a value.
322325
LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
323326
SmallVectorImpl<Type> &types,
324-
SmallVectorImpl<Value> &operands);
327+
SmallVectorImpl<Value> &operands,
328+
bool allowInlineAsm = false);
325329
/// Converts the parameter attributes attached to `func` and adds them to the
326330
/// `funcOp`.
327331
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,18 +1473,20 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
14731473
return success();
14741474
}
14751475

1476-
LogicalResult
1477-
ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
1478-
SmallVectorImpl<Type> &types,
1479-
SmallVectorImpl<Value> &operands) {
1476+
LogicalResult ModuleImport::convertCallTypeAndOperands(
1477+
llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
1478+
SmallVectorImpl<Value> &operands, bool allowInlineAsm) {
14801479
if (!callInst->getType()->isVoidTy())
14811480
types.push_back(convertType(callInst->getType()));
14821481

14831482
if (!callInst->getCalledFunction()) {
1484-
FailureOr<Value> called = convertValue(callInst->getCalledOperand());
1485-
if (failed(called))
1486-
return failure();
1487-
operands.push_back(*called);
1483+
if (!allowInlineAsm ||
1484+
!isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
1485+
FailureOr<Value> called = convertValue(callInst->getCalledOperand());
1486+
if (failed(called))
1487+
return failure();
1488+
operands.push_back(*called);
1489+
}
14881490
}
14891491
SmallVector<llvm::Value *> args(callInst->args());
14901492
FailureOr<SmallVector<Value>> arguments = convertValues(args);
@@ -1579,53 +1581,68 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
15791581

15801582
SmallVector<Type> types;
15811583
SmallVector<Value> operands;
1582-
if (failed(convertCallTypeAndOperands(callInst, types, operands)))
1584+
if (failed(convertCallTypeAndOperands(callInst, types, operands,
1585+
/*allowInlineAsm=*/true)))
15831586
return failure();
15841587

15851588
auto funcTy =
15861589
dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
15871590
if (!funcTy)
15881591
return failure();
15891592

1590-
CallOp callOp;
1591-
1592-
if (llvm::Function *callee = callInst->getCalledFunction()) {
1593-
callOp = builder.create<CallOp>(
1594-
loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
1595-
operands);
1593+
if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
1594+
auto callOp = builder.create<InlineAsmOp>(
1595+
loc, funcTy.getReturnType(), operands,
1596+
builder.getStringAttr(asmI->getAsmString()),
1597+
builder.getStringAttr(asmI->getConstraintString()),
1598+
/*has_side_effects=*/true,
1599+
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
1600+
/*operand_attrs=*/nullptr);
1601+
if (!callInst->getType()->isVoidTy())
1602+
mapValue(inst, callOp.getResult(0));
1603+
else
1604+
mapNoResultOp(inst, callOp);
15961605
} else {
1597-
callOp = builder.create<CallOp>(loc, funcTy, operands);
1606+
CallOp callOp;
1607+
1608+
if (llvm::Function *callee = callInst->getCalledFunction()) {
1609+
callOp = builder.create<CallOp>(
1610+
loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
1611+
operands);
1612+
} else {
1613+
callOp = builder.create<CallOp>(loc, funcTy, operands);
1614+
}
1615+
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
1616+
callOp.setTailCallKind(
1617+
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
1618+
setFastmathFlagsAttr(inst, callOp);
1619+
1620+
// Handle function attributes.
1621+
if (callInst->hasFnAttr(llvm::Attribute::Convergent))
1622+
callOp.setConvergent(true);
1623+
if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
1624+
callOp.setNoUnwind(true);
1625+
if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
1626+
callOp.setWillReturn(true);
1627+
1628+
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
1629+
ModRefInfo othermem = convertModRefInfoFromLLVM(
1630+
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
1631+
ModRefInfo argMem = convertModRefInfoFromLLVM(
1632+
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
1633+
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
1634+
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
1635+
auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
1636+
argMem, inaccessibleMem);
1637+
// Only set the attribute when it does not match the default value.
1638+
if (!memAttr.isReadWrite())
1639+
callOp.setMemoryEffectsAttr(memAttr);
1640+
1641+
if (!callInst->getType()->isVoidTy())
1642+
mapValue(inst, callOp.getResult());
1643+
else
1644+
mapNoResultOp(inst, callOp);
15981645
}
1599-
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
1600-
callOp.setTailCallKind(
1601-
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
1602-
setFastmathFlagsAttr(inst, callOp);
1603-
1604-
// Handle function attributes.
1605-
if (callInst->hasFnAttr(llvm::Attribute::Convergent))
1606-
callOp.setConvergent(true);
1607-
if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
1608-
callOp.setNoUnwind(true);
1609-
if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
1610-
callOp.setWillReturn(true);
1611-
1612-
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
1613-
ModRefInfo othermem = convertModRefInfoFromLLVM(
1614-
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
1615-
ModRefInfo argMem = convertModRefInfoFromLLVM(
1616-
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
1617-
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
1618-
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
1619-
auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, argMem,
1620-
inaccessibleMem);
1621-
// Only set the attribute when it does not match the default value.
1622-
if (!memAttr.isReadWrite())
1623-
callOp.setMemoryEffectsAttr(memAttr);
1624-
1625-
if (!callInst->getType()->isVoidTy())
1626-
mapValue(inst, callOp.getResult());
1627-
else
1628-
mapNoResultOp(inst, callOp);
16291646
return success();
16301647
}
16311648
if (inst->getOpcode() == llvm::Instruction::LandingPad) {

mlir/test/Target/LLVMIR/Import/import-failure.ll

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,6 @@ bb2:
1212

1313
; // -----
1414

15-
; CHECK: <unknown>
16-
; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r"
17-
define i32 @unhandled_value(i32 %arg1) {
18-
%1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
19-
ret i32 %1
20-
}
21-
22-
; // -----
23-
2415
; CHECK: <unknown>
2516
; CHECK-SAME: unhandled constant: ptr blockaddress(@unhandled_constant, %bb1) since blockaddress(...) is unsupported
2617
; CHECK: <unknown>

mlir/test/Target/LLVMIR/Import/instructions.ll

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,17 @@ define void @indirect_vararg_call(ptr addrspace(42) %fn) {
535535

536536
; // -----
537537

538+
; CHECK-LABEL: @inlineasm
539+
; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
540+
define i32 @inlineasm(i32 %arg1) {
541+
; CHECK: %[[RES:.+]] = llvm.inline_asm has_side_effects "bswap $0", "=r,r" %[[ARG1]] : (i32) -> i32
542+
%1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
543+
; CHECK: return %[[RES]]
544+
ret i32 %1
545+
}
546+
547+
; // -----
548+
538549
; CHECK-LABEL: @gep_static_idx
539550
; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
540551
define void @gep_static_idx(ptr %ptr) {

0 commit comments

Comments
 (0)