|
11 | 11 |
|
12 | 12 | include "MicrokernelDialect.td" |
13 | 13 | include "gc/Dialect/Microkernel/MicrokernelEnum.td" |
| 14 | +include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" |
| 15 | +include "mlir/Interfaces/DestinationStyleOpInterface.td" |
14 | 16 | include "mlir/Interfaces/SideEffectInterfaces.td" |
15 | 17 |
|
// Type constraint matching a tensor that
//   * has an element type drawn from `allowedTypes`,
//   * has a rank listed in `ranks`, and
//   * has a fully static shape.
// The summary reads "<r0>D/<r1>D/... static " followed by the summary of
// TensorOf<allowedTypes>.
class StaticTensorRankOf<list<Type> allowedTypes, list<int> ranks>
    : Type<
        And<[
          TensorOf<allowedTypes>.predicate,
          HasAnyRankOfPred<ranks>,
          HasStaticShapePred
        ]>,
        !interleave(!foreach(r, ranks, r # "D"), "/") # " static "
            # TensorOf<allowedTypes>.summary,
        "::mlir::TensorType">;
| 23 | + |
// Type constraint matching a memref that
//   * has an element type drawn from `allowedTypes`,
//   * has a rank listed in `ranks`, and
//   * has a fully static shape.
// The summary reads "<r0>D/<r1>D/... static " followed by the summary of
// MemRefOf<allowedTypes>.
class StaticMemRefRankOf<list<Type> allowedTypes, list<int> ranks>
    : Type<
        And<[
          MemRefOf<allowedTypes>.predicate,
          HasAnyRankOfPred<ranks>,
          HasStaticShapePred
        ]>,
        !interleave(!foreach(r, ranks, r # "D"), "/") # " static "
            # MemRefOf<allowedTypes>.summary,
        "::mlir::MemRefType">;
21 | 29 |
|
// Statically shaped 2-D, 3-D or 4-D tensor whose element type is one of
// f32, bf16, si32, si8 or ui8.
def BrgemmTensor : StaticTensorRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>;

// Either a BrgemmTensor or a memref under the same element-type, rank and
// static-shape restrictions. The tensor alternative reuses BrgemmTensor so
// the two constraints cannot drift apart.
def BrgemmTensorOrMemRef : AnyTypeOf<[
  BrgemmTensor,
  StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>
]>;
| 34 | + |
// `microkernel.brgemm`: brgemm op in destination-passing style whose operands
// may be tensors or memrefs. It carries the bufferization interface so it can
// be rewritten from tensor form to memref form.
def Microkernel_BrgemmOp : Microkernel_Op<"brgemm",
    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     BufferizableOpInterface,
     DestinationStyleOpInterface]> {
  let summary = "Abstract Op that executes brgemm kernel on tensors or memrefs.";
  let description = [{
    The operation has the following arguments:
      1) `inputs`: Tensors or MemRefs of operand A/B, in that order;
      2) `init`: Tensor or MemRef of operand C (the destination operand);
      3) `batchDims` / `leadingDims`: The batch dims and leading dims of
         operand A/B, in that order;
      4) `flags`: An array of brgemm flags;
    And has the following outputs:
      1) `output`: Tensor of operand C;
  }];

  let arguments = (ins Variadic<BrgemmTensorOrMemRef>:$inputs,
                       BrgemmTensorOrMemRef:$init,
                       ConfinedAttr<DenseI64ArrayAttr,
                         [DenseArrayNonNegative<DenseI64ArrayAttr>]>:$batchDims,
                       ConfinedAttr<DenseI64ArrayAttr,
                         [DenseArrayNonNegative<DenseI64ArrayAttr>]>:$leadingDims,
                       TypedArrayAttrBase<Microkernel_BrgemmFlags, "brgemm flags">:$flags);
  // Variadic: the op may carry zero or more tensor results.
  let results = (outs Variadic<BrgemmTensor>:$output);

  let extraClassDeclaration = [{
    // Operand accessors. A and B are the first two entries of `inputs`;
    // C is the `init` operand.
    // NOTE(review): these index without a size check; presumably the custom
    // verifier guarantees at least two inputs -- confirm.
    Value getOperandA() { return getInputs()[0]; }
    Value getOperandB() { return getInputs()[1]; }
    Value getOperandC() { return getInit(); }

    // Entry 0 of `batchDims` / `leadingDims` describes operand A.
    int64_t getBatchDimA() { return getBatchDims()[0]; }
    int64_t getLeadingDimA() { return getLeadingDims()[0]; }

    // Entry 1 of `batchDims` / `leadingDims` describes operand B.
    int64_t getBatchDimB() { return getBatchDims()[1]; }
    int64_t getLeadingDimB() { return getLeadingDims()[1]; }

    // DestinationStyleOpInterface hook: `init` is the only DPS init operand.
    MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

    // Bufferization methods, declared by hand because BufferizableOpInterface
    // is listed as a plain trait above. Only the declarations live here; the
    // definitions are not part of this file.
    bool bufferizesToMemoryRead(OpOperand &,
                                const bufferization::AnalysisState &);
    bool bufferizesToMemoryWrite(OpOperand &,
                                 const bufferization::AnalysisState &);
    bool bufferizesToElementwiseAccess(const bufferization::AnalysisState &,
                                       ArrayRef<OpOperand *>);
    bufferization::AliasingValueList getAliasingValues(OpOperand &opOperand,
                                const bufferization::AnalysisState &state);
    LogicalResult bufferize(RewriterBase &,
                            const bufferization::BufferizationOptions &);
  }];

  // The verifier, parser/printer and folder are all hand-written in C++.
  let hasVerifier = 1;
  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
}
| 86 | + |
22 | 87 | def Microkernel_BrgemmDispatchOp : Microkernel_Op<"brgemm.dispatch", [Pure]> { |
23 | 88 | let summary = "JIT the brgemm microkernel given the parameters"; |
24 | 89 | let description = [{ |
@@ -80,7 +145,7 @@ def Microkernel_BrgemmEpilogueOp : Microkernel_Op<"brgemm.epilogue"> { |
80 | 145 | */ |
// Either a statically shaped 2-D, 3-D or 4-D memref whose element type is one
// of f32, bf16, si32, si8 or ui8, or an i64 value.
def BrgemmMemRefOrI64 : AnyTypeOf<[
  StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>,
  I64
]>;
82 | 147 |
|
83 | | -def Microkernel_BrgemmOp : Microkernel_Op<"brgemm"> { |
| 148 | +def Microkernel_BrgemmExecuteOp : Microkernel_Op<"brgemm.execute"> { |
84 | 149 | let summary = "execute the JITed brgemm kernel."; |
85 | 150 | let description = [{ |
86 | 151 | The operation has the following arguments: |
|
0 commit comments