@@ -27,6 +27,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td"
2727include "triton/Dialect/Triton/IR/TritonTypes.td"
2828include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
2929include "triton/Dialect/Triton/IR/TritonInterfaces.td"
30+ include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
3031include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
3132include "triton/Dialect/TritonGPU/IR/TritonGPUTypeInterfaces.td"
3233include "mlir/IR/OpBase.td"
@@ -71,7 +72,7 @@ def TTNG_ClusterWaitOp : TTNG_Op<"cluster_wait", []> {
7172//
7273def TTNG_WarpGroupDotOp : TTNG_Op<"warp_group_dot", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
7374 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
74- DotLike ,
75+ DeclareOpInterfaceMethods<DotOpInterface> ,
7576 TypesMatchWith<"result's type matches accumulator's type",
7677 "d", "c", "$_self">]> {
7778 let summary = "warp group dot";
@@ -325,7 +326,7 @@ def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
325326 let assemblyFormat = "attr-dict";
326327}
327328
328- def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DotLike ]> {
329+ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface> ]> {
329330 let summary = "block level op mapping to tensorcore gen5 mma";
330331
331332 let description = [{
@@ -343,11 +344,12 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryE
343344 I1:$pred,
344345 Optional<TTG_MemDescType>:$barrier,
345346 OptionalAttr<UnitAttr>:$two_ctas);
347+
346348 // TODO: improve printing format.
347349 let assemblyFormat = "$a`,` $b`,` $d`,` $useD`,` $pred (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
348350}
349351
350- def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DotLike ]> {
352+ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<DotOpInterface> ]> {
351353 let summary = "block level op mapping to tensorcore gen5 mma";
352354
353355 let description = [{
@@ -366,6 +368,7 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMe
366368 I1:$useD,
367369 I1:$pred,
368370 Optional<TTG_MemDescType>:$barrier);
371+
369372 // TODO: improve printing format.
370373 let assemblyFormat = "$a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred `lhs` `=` $a_type `rhs` `=` $b_type (`,` $barrier^)? attr-dict `:` functional-type(operands, results)";
371374}
0 commit comments