Skip to content

Commit 28c1926

Browse files
committed
[MLIR][LLVM] Tail call support for inline asm op
1 parent 744a469 commit 28c1926

File tree

11 files changed

+58
-9
lines changed

11 files changed

+58
-9
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -758,9 +758,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
758758
the LLVM function type that uses an explicit void type to model functions
759759
that do not return a value.
760760

761-
If this operatin has the `no_inline` attribute, then this specific function call
762-
will never be inlined. The opposite behavior will occur if the call has `always_inline`
763-
attribute. The `inline_hint` attribute indicates that it is desirable to inline
761+
If this operatin has the `no_inline` attribute, then this specific function call
762+
will never be inlined. The opposite behavior will occur if the call has `always_inline`
763+
attribute. The `inline_hint` attribute indicates that it is desirable to inline
764764
this function call.
765765

766766
Examples:
@@ -2298,13 +2298,18 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
22982298
written, or referenced.
22992299
Attempting to define or reference any symbol or any global behavior is
23002300
considered undefined behavior at this time.
2301+
If `tail_call_kind` is used, the operation behaves like the specified
2302+
tail call kind. The `musttail` kind it's not available for this operation,
2303+
since it isn't supported by LLVM's inline asm.
23012304
}];
23022305
let arguments = (
23032306
ins Variadic<LLVM_Type>:$operands,
23042307
StrAttr:$asm_string,
23052308
StrAttr:$constraints,
23062309
UnitAttr:$has_side_effects,
23072310
UnitAttr:$is_align_stack,
2311+
OptionalAttr<
2312+
DefaultValuedAttr<TailCallKind, "TailCallKind::None">>:$tail_call_kind,
23082313
OptionalAttr<
23092314
DefaultValuedAttr<AsmATTOrIntel, "AsmDialect::AD_ATT">>:$asm_dialect,
23102315
OptionalAttr<ArrayAttr>:$operand_attrs);
@@ -2314,6 +2319,7 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
23142319
let assemblyFormat = [{
23152320
(`has_side_effects` $has_side_effects^)?
23162321
(`is_align_stack` $is_align_stack^)?
2322+
(`tail_call_kind` `=` $tail_call_kind^)?
23172323
(`asm_dialect` `=` $asm_dialect^)?
23182324
(`operand_attrs` `=` $operand_attrs^)?
23192325
attr-dict
@@ -2326,6 +2332,8 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
23262332
return "elementtype";
23272333
}
23282334
}];
2335+
2336+
let hasVerifier = 1;
23292337
}
23302338

23312339
//===--------------------------------------------------------------------===//

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,8 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
439439
op,
440440
/*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
441441
/*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
442-
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
442+
/*is_align_stack=*/false, /*tail_call_kind=*/nullptr,
443+
/*asm_dialect=*/asmDialectAttr,
443444
/*operand_attrs=*/ArrayAttr());
444445
return success();
445446
}

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,7 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
572572
/*constraints=*/constraintStr,
573573
/*has_side_effects=*/true,
574574
/*is_align_stack=*/false,
575+
/*tail_call_kind=*/nullptr,
575576
/*asm_dialect=*/asmDialectAttr,
576577
/*operand_attrs=*/ArrayAttr());
577578
}

mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
140140
/*constraints=*/registerConstraints.data(),
141141
/*has_side_effects=*/interfaceOp.hasSideEffect(),
142142
/*is_align_stack=*/false,
143+
/*tail_call_kind=*/nullptr,
143144
/*asm_dialect=*/asmDialectAttr,
144145
/*operand_attrs=*/ArrayAttr());
145146
}

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4042,6 +4042,21 @@ LogicalResult LLVM::masked_scatter::verify() {
40424042
return success();
40434043
}
40444044

4045+
//===----------------------------------------------------------------------===//
4046+
// InlineAsmOp
4047+
//===----------------------------------------------------------------------===//
4048+
4049+
LogicalResult InlineAsmOp::verify() {
4050+
if (!getTailCallKindAttr())
4051+
return success();
4052+
4053+
if (getTailCallKindAttr().getTailCallKind() == TailCallKind::MustTail)
4054+
return emitOpError(
4055+
"tail call kind 'musttail' is not supported by this operation");
4056+
4057+
return success();
4058+
}
4059+
40454060
//===----------------------------------------------------------------------===//
40464061
// LLVMDialect initialization, type parsing, and registration.
40474062
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
4141
auto asmOp = b.create<LLVM::InlineAsmOp>(
4242
v1.getType(), /*operands=*/asmVals, /*asm_string=*/asmStr,
4343
/*constraints=*/asmCstr, /*has_side_effects=*/false,
44-
/*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
44+
/*is_align_stack=*/false, /*tail_call_kind=*/nullptr,
45+
/*asm_dialect=*/asmDialectAttr,
4546
/*operand_attrs=*/ArrayAttr());
4647
return asmOp.getResult(0);
4748
}

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/ADT/TypeSwitch.h"
2020
#include "llvm/IR/IRBuilder.h"
2121
#include "llvm/IR/InlineAsm.h"
22+
#include "llvm/IR/Instructions.h"
2223
#include "llvm/IR/MDBuilder.h"
2324
#include "llvm/IR/MatrixBuilder.h"
2425
#include "llvm/IR/Operator.h"
@@ -507,6 +508,9 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
507508
llvm::CallInst *inst = builder.CreateCall(
508509
inlineAsmInst,
509510
moduleTranslation.lookupValues(inlineAsmOp.getOperands()));
511+
if (inlineAsmOp.getTailCallKindAttr())
512+
inst->setTailCallKind(convertTailCallKindToLLVM(
513+
inlineAsmOp.getTailCallKindAttr().getTailCallKind()));
510514
if (auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) {
511515
llvm::AttributeList attrList;
512516
for (const auto &it : llvm::enumerate(*maybeOperandAttrs)) {

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,6 +2201,10 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
22012201
builder.getStringAttr(asmI->getAsmString()),
22022202
builder.getStringAttr(asmI->getConstraintString()),
22032203
asmI->hasSideEffects(), asmI->isAlignStack(),
2204+
callInst->isTailCall()
2205+
? TailCallKindAttr::get(mlirModule.getContext(),
2206+
TailCallKind::Tail)
2207+
: nullptr,
22042208
AsmDialectAttr::get(
22052209
mlirModule.getContext(),
22062210
convertAsmDialectFromLLVM(asmI->getDialect())),

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1882,3 +1882,11 @@ llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2
18821882
%0 = llvm.mlir.constant([2.5, 7.4]) : !llvm.array<2 x f64>
18831883
llvm.return %0 : !llvm.array<2 x f64>
18841884
}
1885+
1886+
// ----
1887+
1888+
llvm.func @inlineAsmMustTail(%arg0: i32, %arg1 : !llvm.ptr) {
1889+
// expected-error@+1 {{op tail call kind 'musttail' is not supported}}
1890+
%8 = llvm.inline_asm tail_call_kind = <musttail> "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)>
1891+
llvm.return
1892+
}

mlir/test/Target/LLVMIR/Import/instructions.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,8 @@ define i32 @inlineasm(i32 %arg1) {
554554
define void @inlineasm2() {
555555
%p = alloca ptr, align 8
556556
; CHECK: {{.*}} = llvm.alloca %0 x !llvm.ptr {alignment = 8 : i64} : (i32) -> !llvm.ptr
557-
; CHECK-NEXT: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" {{.*}} : (!llvm.ptr) -> !llvm.void
558-
call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %p)
557+
; CHECK-NEXT: llvm.inline_asm has_side_effects tail_call_kind = <tail> asm_dialect = att operand_attrs = [{elementtype = !llvm.ptr}] "", "*m,~{memory}" {{.*}} : (!llvm.ptr) -> !llvm.void
558+
tail call void asm sideeffect "", "*m,~{memory}"(ptr elementtype(ptr) %p)
559559
ret void
560560
}
561561

0 commit comments

Comments
 (0)