Skip to content

Commit 668dfb2

Browse files
committed
[mlir][LLVM] Add operand bundle support
1 parent 0de1e3e commit 668dfb2

File tree

7 files changed

+401
-43
lines changed

7 files changed

+401
-43
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,15 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
551551
Variadic<LLVM_Type>:$normalDestOperands,
552552
Variadic<LLVM_Type>:$unwindDestOperands,
553553
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
554-
DefaultValuedAttr<CConv, "CConv::C">:$CConv);
554+
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
555+
VariadicOfVariadic<LLVM_Type,
556+
"op_bundle_sizes">:$op_bundle_operands,
557+
DenseI32ArrayAttr:$op_bundle_sizes,
558+
DefaultValuedProperty<
559+
ArrayProperty<StringProperty, "operand bundle tags">,
560+
"ArrayRef<std::string>{}",
561+
"SmallVector<std::string>{}"
562+
>:$op_bundle_tags);
555563
let results = (outs Optional<LLVM_Type>:$result);
556564
let successors = (successor AnySuccessor:$normalDest,
557565
AnySuccessor:$unwindDest);
@@ -607,7 +615,8 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> {
607615
//===----------------------------------------------------------------------===//
608616

609617
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
610-
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
618+
[AttrSizedOperandSegments,
619+
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
611620
DeclareOpInterfaceMethods<CallOpInterface>,
612621
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
613622
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
@@ -661,8 +670,15 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
661670
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
662671
OptionalAttr<UnitAttr>:$convergent,
663672
OptionalAttr<UnitAttr>:$no_unwind,
664-
OptionalAttr<UnitAttr>:$will_return
665-
);
673+
OptionalAttr<UnitAttr>:$will_return,
674+
VariadicOfVariadic<LLVM_Type,
675+
"op_bundle_sizes">:$op_bundle_operands,
676+
DenseI32ArrayAttr:$op_bundle_sizes,
677+
DefaultValuedProperty<
678+
ArrayProperty<StringProperty, "operand bundle tags">,
679+
"ArrayRef<std::string>{}",
680+
"SmallVector<std::string>{}"
681+
>:$op_bundle_tags);
666682
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
667683
let arguments = !con(args, aliasAttrs);
668684
let results = (outs Optional<LLVM_Type>:$result);
@@ -682,6 +698,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
682698
OpBuilder<(ins "LLVMFunctionType":$calleeType, "StringRef":$callee,
683699
CArg<"ValueRange", "{}">:$args)>
684700
];
701+
let hasVerifier = 1;
685702
let hasCustomAssemblyFormat = 1;
686703
let extraClassDeclaration = [{
687704
/// Returns the callee function type.
@@ -1895,21 +1912,34 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", [DeclareOpInterfaceMethods<MemoryEf
18951912

18961913
def LLVM_CallIntrinsicOp
18971914
: LLVM_Op<"call_intrinsic",
1898-
[DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
1915+
[AttrSizedOperandSegments,
1916+
DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
18991917
let summary = "Call to an LLVM intrinsic function.";
19001918
let description = [{
19011919
Call the specified llvm intrinsic. If the intrinsic is overloaded, use
19021920
the MLIR function type of this op to determine which intrinsic to call.
19031921
}];
19041922
let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args,
19051923
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
1906-
"{}">:$fastmathFlags);
1924+
"{}">:$fastmathFlags,
1925+
VariadicOfVariadic<LLVM_Type,
1926+
"op_bundle_sizes">:$op_bundle_operands,
1927+
DenseI32ArrayAttr:$op_bundle_sizes,
1928+
DefaultValuedProperty<
1929+
ArrayProperty<StringProperty, "operand bundle tags">,
1930+
"ArrayRef<std::string>{}",
1931+
"SmallVector<std::string>{}"
1932+
>:$op_bundle_tags);
19071933
let results = (outs Optional<LLVM_Type>:$results);
19081934
let llvmBuilder = [{
19091935
return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation);
19101936
}];
19111937
let assemblyFormat = [{
1912-
$intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
1938+
$intrin `(` $args `)`
1939+
( custom<OpBundles>($op_bundle_operands, type($op_bundle_operands),
1940+
$op_bundle_tags)^ )?
1941+
`:` functional-type($args, $results)
1942+
attr-dict
19131943
}];
19141944

19151945
let hasVerifier = 1;

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,10 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
544544
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
545545
promoted, callOp->getAttrs());
546546

547+
newOp.getProperties().operandSegmentSizes = {
548+
static_cast<int32_t>(promoted.size()), 0};
549+
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
550+
547551
SmallVector<Value, 4> results;
548552
if (numResults < 2) {
549553
// If < 2 results, packing did not do anything and we can just return.

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -837,17 +837,23 @@ class FunctionCallPattern
837837
matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
838838
ConversionPatternRewriter &rewriter) const override {
839839
if (callOp.getNumResults() == 0) {
840-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
840+
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
841841
callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
842+
newOp.getProperties().operandSegmentSizes = {
843+
static_cast<int32_t>(adaptor.getOperands().size()), 0};
844+
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
842845
return success();
843846
}
844847

845848
// Function returns a single result.
846849
auto dstType = typeConverter.convertType(callOp.getType(0));
847850
if (!dstType)
848851
return rewriter.notifyMatchFailure(callOp, "type conversion failed");
849-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
852+
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
850853
callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
854+
newOp.getProperties().operandSegmentSizes = {
855+
static_cast<int32_t>(adaptor.getOperands().size()), 0};
856+
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
851857
return success();
852858
}
853859
};

0 commit comments

Comments
 (0)