Skip to content

Commit 2b6ab6a

Browse files
committed
MLIR: Enable importing inlineasm calls
1 parent 34f0611 commit 2b6ab6a

File tree

4 files changed

+73
-56
lines changed

4 files changed

+73
-56
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,9 +319,12 @@ class ModuleImport {
319319
/// Appends the converted result type and operands of `callInst` to the
320320
/// `types` and `operands` arrays. For indirect calls, the method additionally
321321
/// inserts the called function at the beginning of the `operands` array.
322+
/// If `handleAsm` is set to false (the default), it will err if the handler
323+
/// is an inline asm which isn't convertible to MLIR as a value.
322324
LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
323325
SmallVectorImpl<Type> &types,
324-
SmallVectorImpl<Value> &operands);
326+
SmallVectorImpl<Value> &operands,
327+
bool handleAsm = false);
325328
/// Converts the parameter attributes attached to `func` and adds them to the
326329
/// `funcOp`.
327330
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 59 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,18 +1473,19 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
14731473
return success();
14741474
}
14751475

1476-
LogicalResult
1477-
ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
1478-
SmallVectorImpl<Type> &types,
1479-
SmallVectorImpl<Value> &operands) {
1476+
LogicalResult ModuleImport::convertCallTypeAndOperands(
1477+
llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
1478+
SmallVectorImpl<Value> &operands, bool handleAsm) {
14801479
if (!callInst->getType()->isVoidTy())
14811480
types.push_back(convertType(callInst->getType()));
14821481

14831482
if (!callInst->getCalledFunction()) {
1484-
FailureOr<Value> called = convertValue(callInst->getCalledOperand());
1485-
if (failed(called))
1486-
return failure();
1487-
operands.push_back(*called);
1483+
if (!handleAsm || !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
1484+
FailureOr<Value> called = convertValue(callInst->getCalledOperand());
1485+
if (failed(called))
1486+
return failure();
1487+
operands.push_back(*called);
1488+
}
14881489
}
14891490
SmallVector<llvm::Value *> args(callInst->args());
14901491
FailureOr<SmallVector<Value>> arguments = convertValues(args);
@@ -1579,53 +1580,65 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
15791580

15801581
SmallVector<Type> types;
15811582
SmallVector<Value> operands;
1582-
if (failed(convertCallTypeAndOperands(callInst, types, operands)))
1583+
if (failed(convertCallTypeAndOperands(callInst, types, operands, true)))
15831584
return failure();
15841585

15851586
auto funcTy =
15861587
dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
15871588
if (!funcTy)
15881589
return failure();
15891590

1590-
CallOp callOp;
1591-
1592-
if (llvm::Function *callee = callInst->getCalledFunction()) {
1593-
callOp = builder.create<CallOp>(
1594-
loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
1595-
operands);
1591+
if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
1592+
InlineAsmOp callOp = builder.create<InlineAsmOp>(
1593+
loc, funcTy.getReturnType(), operands,
1594+
builder.getStringAttr(asmI->getAsmString()),
1595+
builder.getStringAttr(asmI->getConstraintString()), nullptr, nullptr,
1596+
nullptr, nullptr);
1597+
if (!callInst->getType()->isVoidTy())
1598+
mapValue(inst, callOp.getResult(0));
1599+
else
1600+
mapNoResultOp(inst, callOp);
15961601
} else {
1597-
callOp = builder.create<CallOp>(loc, funcTy, operands);
1602+
CallOp callOp;
1603+
1604+
if (llvm::Function *callee = callInst->getCalledFunction()) {
1605+
callOp = builder.create<CallOp>(
1606+
loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
1607+
operands);
1608+
} else {
1609+
callOp = builder.create<CallOp>(loc, funcTy, operands);
1610+
}
1611+
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
1612+
callOp.setTailCallKind(
1613+
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
1614+
setFastmathFlagsAttr(inst, callOp);
1615+
1616+
// Handle function attributes.
1617+
if (callInst->hasFnAttr(llvm::Attribute::Convergent))
1618+
callOp.setConvergent(true);
1619+
if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
1620+
callOp.setNoUnwind(true);
1621+
if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
1622+
callOp.setWillReturn(true);
1623+
1624+
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
1625+
ModRefInfo othermem = convertModRefInfoFromLLVM(
1626+
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
1627+
ModRefInfo argMem = convertModRefInfoFromLLVM(
1628+
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
1629+
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
1630+
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
1631+
auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
1632+
argMem, inaccessibleMem);
1633+
// Only set the attribute when it does not match the default value.
1634+
if (!memAttr.isReadWrite())
1635+
callOp.setMemoryEffectsAttr(memAttr);
1636+
1637+
if (!callInst->getType()->isVoidTy())
1638+
mapValue(inst, callOp.getResult());
1639+
else
1640+
mapNoResultOp(inst, callOp);
15981641
}
1599-
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
1600-
callOp.setTailCallKind(
1601-
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
1602-
setFastmathFlagsAttr(inst, callOp);
1603-
1604-
// Handle function attributes.
1605-
if (callInst->hasFnAttr(llvm::Attribute::Convergent))
1606-
callOp.setConvergent(true);
1607-
if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
1608-
callOp.setNoUnwind(true);
1609-
if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
1610-
callOp.setWillReturn(true);
1611-
1612-
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
1613-
ModRefInfo othermem = convertModRefInfoFromLLVM(
1614-
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
1615-
ModRefInfo argMem = convertModRefInfoFromLLVM(
1616-
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
1617-
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
1618-
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
1619-
auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, argMem,
1620-
inaccessibleMem);
1621-
// Only set the attribute when it does not match the default value.
1622-
if (!memAttr.isReadWrite())
1623-
callOp.setMemoryEffectsAttr(memAttr);
1624-
1625-
if (!callInst->getType()->isVoidTy())
1626-
mapValue(inst, callOp.getResult());
1627-
else
1628-
mapNoResultOp(inst, callOp);
16291642
return success();
16301643
}
16311644
if (inst->getOpcode() == llvm::Instruction::LandingPad) {

mlir/test/Target/LLVMIR/Import/import-failure.ll

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,6 @@ bb2:
1212

1313
; // -----
1414

15-
; CHECK: <unknown>
16-
; CHECK-SAME: error: unhandled value: ptr asm "bswap $0", "=r,r"
17-
define i32 @unhandled_value(i32 %arg1) {
18-
%1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
19-
ret i32 %1
20-
}
21-
22-
; // -----
23-
2415
; CHECK: <unknown>
2516
; CHECK-SAME: unhandled constant: ptr blockaddress(@unhandled_constant, %bb1) since blockaddress(...) is unsupported
2617
; CHECK: <unknown>

mlir/test/Target/LLVMIR/Import/instructions.ll

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,16 @@ define void @indirect_vararg_call(ptr addrspace(42) %fn) {
535535

536536
; // -----
537537

538+
; CHECK-LABEL: @inlineasm
539+
; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
540+
define i32 @inlineasm(i32 %arg1) {
541+
%1 = call i32 asm "bswap $0", "=r,r"(i32 %arg1)
542+
; CHECK: llvm.inline_asm "bswap $0", "=r,r" %[[ARG1]] : (i32) -> i32
543+
ret i32 %1
544+
}
545+
546+
; // -----
547+
538548
; CHECK-LABEL: @gep_static_idx
539549
; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
540550
define void @gep_static_idx(ptr %ptr) {

0 commit comments

Comments
 (0)