1515
1616include "mlir/Dialect/Rock/IR/RockAttrDefs.td"
1717include "mlir/Dialect/Rock/IR/RockConvInterface.td"
18+ include "mlir/Dialect/Rock/IR/RockGemmFeaturesInterface.td"
1819include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td"
1920include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.td"
2021include "mlir/Dialect/Rock/IR/RockAcceptingViewOpInterface.td"
@@ -61,11 +62,13 @@ class TensorOrMemRefOf<list<Type> allowedTypes> :
6162
// Attribute constraint: an IndexArrayAttr confined by ArrayMinCount<n>.
// NOTE(review): despite the "Length" in the name, the constraint applied here
// is ArrayMinCount (presumably "at least n elements", not "exactly n") —
// confirm against the MLIR attribute-constraint definitions.
6263class IndexArrayLength<int n> : ConfinedAttr<IndexArrayAttr, [ArrayMinCount<n>]>;
6364
64- class Rock_ConvOpBase<string mnemonic, list<Type> inputTypes=[F32, F16, BF16], list<Type> outputTypes=[F32, F16, BF16]> :
65- Rock_Op<mnemonic, [DeclareOpInterfaceMethods<RockGemmWrapperInterface>,
66- DeclareOpInterfaceMethods<RockConvInterface>,
67- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
68- RockFusionRoot]>{
65+ class Rock_ConvOpBase<string mnemonic, list<Type> inputTypes = [F32, F16, BF16],
66+ list<Type> outputTypes = [F32, F16, BF16]>
67+ : Rock_Op<mnemonic, [DeclareOpInterfaceMethods<RockGemmWrapperInterface>,
68+ DeclareOpInterfaceMethods<RockConvInterface>,
69+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
70+ DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
71+ RockFusionRoot]> {
6972 dag commonConvArgs = (ins TensorOrMemRefOf<inputTypes>:$filter,
7073 TensorOrMemRefOf<inputTypes>:$input,
7174 TensorOrMemRefOf<outputTypes>:$output,
@@ -132,6 +135,7 @@ def Rock_ConvBwdWeightOp : Rock_ConvOpBase<"conv_bwd_weight">
132135
133136def Rock_GemmOp
134137 : Rock_Op<"gemm", [DeclareOpInterfaceMethods<RockGemmWrapperInterface>,
138+ DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
135139 RockFusionRoot]>,
136140 Arguments<(
137141 ins Arg<TensorOrMemRefOf<GemmInputTypes>, "matrix A", [MemRead]>:$a,
@@ -187,12 +191,17 @@ def Rock_ReduceOp
187191 }];
188192 let extraClassDeclaration = [{
189193 ::mlir::OpOperand* getOutArgument() { return &(*this)->getOpOperand(1); }
194+
195+ SmallVector<::mlir::Type> getTypesForFeature() {
196+ return {getIn().getType()};
197+ }
190198 }];
191199}
192200def Rock_AttentionOp
193201 : Rock_Op<
194202 "attention", [DeclareOpInterfaceMethods<RockGemmGemmWrapperInterface>,
195203 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
204+ DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
196205 RockFusionRoot, AttrSizedOperandSegments,
197206 AttrSizedResultSegments]>,
198207 Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
@@ -260,6 +269,7 @@ def Rock_GemmElementwiseGemmOp
260269 : Rock_Op<"gemm_elementwise_gemm",
261270 [DeclareOpInterfaceMethods<RockGemmGemmWrapperInterface>,
262271 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
272+ DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
263273 RockFusionRoot]>,
264274 AllElementTypesMatch<["a", "b"]>,
265275 Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16]>:$a,
@@ -310,6 +320,7 @@ def Rock_ConvElementwiseGemmOp
310320 : Rock_Op<"conv_elementwise_gemm",
311321 [DeclareOpInterfaceMethods<RockGemmGemmWrapperInterface>,
312322 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
323+ DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
313324 RockFusionRoot]>,
314325 AllElementTypesMatch<["filter", "input"]>,
315326 Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16]>:$filter,
@@ -486,7 +497,9 @@ def Rock_TensorUntransformCastOp :
486497}
487498
488499def Rock_GridwiseGemmOp
489- : Rock_Op<"gridwise_gemm", [RockFusionRoot]>,
500+ : Rock_Op<"gridwise_gemm", [DeclareOpInterfaceMethods<
501+ RockGemmFeaturesInterface>,
502+ RockFusionRoot]>,
490503 Arguments<(ins Arg<MemRefRankOf<GemmInputTypes, [3]>,
491504 "matrix A view", [MemRead]>:$a,
492505 Arg<MemRefRankOf<GemmInputTypes, [3]>, "matrix B view", [MemRead]>:$b,
@@ -503,19 +516,13 @@ def Rock_GridwiseGemmOp
503516 $c `=` $a `*` $b `storeMethod` `(` $storeMethod `)` (`features` `=` $features^)? attr-dict `:` type($c) `=` type($a) `*` type($b)
504517 }];
505518 let hasVerifier = 1;
506-
507- // Return the type from this op that is needed to calculate GemmFeatures
508- let extraClassDeclaration = [{
509- SmallVector<::mlir::Type> getTypesForFeature() {
510- SmallVector<::mlir::Type> types = {getA().getType()};
511- return types;
512- }
513- }];
514519}
515520
516521// gridwise_gemm_accel
517522def Rock_GridwiseGemmAccelOp
518- : Rock_Op<"gridwise_gemm_accel", [RockFusionRoot]>,
523+ : Rock_Op<"gridwise_gemm_accel", [DeclareOpInterfaceMethods<
524+ RockGemmFeaturesInterface>,
525+ RockFusionRoot]>,
519526 Arguments<(ins Arg<MemRefRankOf<GemmInputTypes, [3]>,
520527 "matrix A view", [MemRead]>:$a,
521528 Arg<MemRefRankOf<GemmInputTypes, [3]>, "matrix B view", [MemRead]>:$b,
@@ -532,21 +539,15 @@ def Rock_GridwiseGemmAccelOp
532539 `(` operands `)` `storeMethod` `(` $storeMethod `)` (`features` `=` $features^)? attr-dict `:` type(operands)
533540 }];
534541 let hasVerifier = 1;
535-
536- // Return the type from this op that is needed to calculate GemmFeatures
537- let extraClassDeclaration = [{
538- SmallVector<::mlir::Type> getTypesForFeature() {
539- SmallVector<::mlir::Type> types = {getA().getType()};
540- return types;
541- }
542- }];
543542}
544543
545544// gridwise_attention_accel
546545def Rock_GridwiseAttentionAccelOp
547546 : Rock_Op<"gridwise_attention_accel",
548547 [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
549- RockFusionRoot, AttrSizedOperandSegments]>,
548+ DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
549+ RockFusionRoot, AttrSizedOperandSegments,
550+ ]>,
550551 Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries,
551552 MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys,
552553 MemRefRankOf<[F32, F16, BF16], [3]>:$values,
@@ -572,15 +573,6 @@ def Rock_GridwiseAttentionAccelOp
572573 `(` operands `)` (`features` `=` $features^)? `preSoftmaxOps` `=` $preSoftmaxBody attr-dict `:` type(operands)
573574 }];
574575 let hasVerifier = 1;
575-
576- // Return the type from this op that is needed to calculate GemmFeatures
577- let extraClassDeclaration = [{
578- SmallVector<::mlir::Type> getTypesForFeature() {
579- SmallVector<::mlir::Type> types = {getKeys().getType(),
580- getValues().getType()};
581- return types;
582- }
583- }];
584576}
585577
586578// Memory allocation on GPU memory hierarchy.
@@ -1370,7 +1362,8 @@ defvar AccelResTypes = [VectorOfLengthAndType<[4, 8, 16, 32], [F32, I32, F16, BF
13701362
13711363// blockwise_gemm_accel
13721364def Rock_BlockwiseGemmAccelOp
1373- : Rock_Op<"blockwise_gemm_accel">,
1365+ : Rock_Op<"blockwise_gemm_accel", [DeclareOpInterfaceMethods<
1366+ RockGemmFeaturesInterface>]>,
13741367 Arguments<(ins MemRefOf<LdsBufferTypes>:$matrixA,
13751368 MemRefOf<LdsBufferTypes>:$matrixB, I32Attr:$inMPerThread,
13761369 I32Attr:$inNPerThread, UnitAttr:$rotateMWithK, UnitAttr:$rotateNWithK,
@@ -1393,14 +1386,6 @@ def Rock_BlockwiseGemmAccelOp
13931386 `:` type($matrixC) `+` `` `=` type($bufferA) `from` type($matrixA) `*`
13941387 type($bufferB) `from` type($matrixB)
13951388 }];
1396-
1397- // Return the type from this op that is needed to calculate GemmFeatures
1398- let extraClassDeclaration = [{
1399- SmallVector<::mlir::Type> getTypesForFeature() {
1400- SmallVector<::mlir::Type> types = {getMatrixA().getType()};
1401- return types;
1402- }
1403- }];
14041389}
14051390
14061391// threadwise_gemm
@@ -1426,7 +1411,8 @@ def Rock_ThreadwiseGemmOp:
14261411}
14271412// threadwise_accel_gemm
14281413def Rock_ThreadwiseAccelGemmOp
1429- : Rock_Op<"threadwise_accel_gemm">,
1414+ : Rock_Op<"threadwise_accel_gemm", [DeclareOpInterfaceMethods<
1415+ RockGemmFeaturesInterface>]>,
14301416 Arguments<(ins Arg<MemRefOf<NativeMemoryOpTypes>,
14311417 "source register view A", [MemRead]>:$matrixA,
14321418 Arg<MemRefOf<NativeMemoryOpTypes>,
@@ -1448,14 +1434,6 @@ def Rock_ThreadwiseAccelGemmOp
14481434 `:` type($matrixC) `+` `` `=` type($matrixA) `*` type($matrixB)
14491435 }];
14501436 let hasVerifier = 1;
1451-
1452- // Return the type from this op that is needed to calculate GemmFeatures
1453- let extraClassDeclaration = [{
1454- SmallVector<::mlir::Type> getTypesForFeature() {
1455- SmallVector<::mlir::Type> types = {getMatrixA().getType()};
1456- return types;
1457- }
1458- }];
14591437}
14601438
14611439// blockwise_broadcasting_reduction
0 commit comments