@@ -8,16 +8,13 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
88include "mlir/IR/OpBase.td"
99include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
1010include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
11- include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
12- include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
1311include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface
1412include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
1513include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
1614include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
17- include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
18- include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
1915include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
2016include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td"
17+ include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
2118
2219
2320//
@@ -44,8 +41,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
4441def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
4542 SameOperandsAndResultShape,
4643 SameOperandsAndResultEncoding,
47- Pure,
48- /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
44+ Pure]> {
4945 let summary = "Cast int64 to pointer";
5046
5147 let arguments = (ins TT_I64Like:$src);
@@ -58,8 +54,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
5854def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,
5955 SameOperandsAndResultShape,
6056 SameOperandsAndResultEncoding,
61- Pure,
62- /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
57+ Pure]> {
6358 let summary = "Cast pointer to int64";
6459
6560 let arguments = (ins TT_PtrLike:$src);
@@ -73,8 +68,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,
7368def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
7469 SameOperandsAndResultShape,
7570 SameOperandsAndResultEncoding,
76- Pure,
77- /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
71+ Pure]> {
7872 let summary = "Cast between types of the same bitwidth";
7973
8074 let arguments = (ins TT_Type:$src);
@@ -89,8 +83,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
8983def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
9084 SameOperandsAndResultShape,
9185 SameOperandsAndResultEncoding,
92- Pure,
93- /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
86+ Pure]> {
9487 let summary = "Floating point casting for custom types";
9588
9689 let description = [{
@@ -118,8 +111,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
118111//
119112
120113def TT_ClampFOp : TT_Op<"clampf", [Elementwise,
121- SameOperandsAndResultType,
122- Pure]> {
114+ SameOperandsAndResultType,
115+ Pure]> {
123116 let summary = "Clamp operation for floating point types";
124117
125118 let description = [{
@@ -149,8 +142,8 @@ def TT_ClampFOp : TT_Op<"clampf", [Elementwise,
149142//
150143
151144def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise,
152- SameOperandsAndResultType,
153- Pure]> {
145+ SameOperandsAndResultType,
146+ Pure]> {
154147 let summary = "Precise sqrt for floating point types";
155148
156149 let description = [{
@@ -165,8 +158,8 @@ def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise,
165158}
166159
167160def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise,
168- SameOperandsAndResultType,
169- Pure]> {
161+ SameOperandsAndResultType,
162+ Pure]> {
170163 let summary = "Precise div for floating point types";
171164
172165 let description = [{
@@ -181,8 +174,8 @@ def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise,
181174}
182175
183176def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise,
184- SameOperandsAndResultType,
185- Pure]> {
177+ SameOperandsAndResultType,
178+ Pure]> {
186179 let summary = "Most significant N bits of the 2N-bit product of two integers";
187180
188181 let description = [{
@@ -200,12 +193,12 @@ def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise,
200193// Pointer Arith Ops
201194//
202195def TT_AddPtrOp : TT_Op<"addptr",
203- [Pure,
204- Elementwise,
205- SameOperandsAndResultShape,
206- SameOperandsAndResultEncoding,
207- TypesMatchWith<"result type matches ptr type",
208- "result", "ptr", "$_self">]> {
196+ [Pure,
197+ Elementwise,
198+ SameOperandsAndResultShape,
199+ SameOperandsAndResultEncoding,
200+ TypesMatchWith<"result type matches ptr type",
201+ "result", "ptr", "$_self">]> {
209202 let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);
210203
211204 let results = (outs TT_PtrLike:$result);
@@ -546,6 +539,7 @@ def TT_SplitOp : TT_Op<"split", [
546539}
547540
548541def TT_TransOp : TT_Op<"trans", [Pure,
542+ TransposeOpInterface,
549543 DeclareOpInterfaceMethods<InferTypeOpInterface>,
550544 SameOperandsAndResultElementType]> {
551545
@@ -579,16 +573,15 @@ def TT_TransOp : TT_Op<"trans", [Pure,
579573 }];
580574
581575 let arguments = (
582- ins TT_TensorOrMemDesc :$src,
576+ ins TT_Tensor :$src,
583577 DenseI32ArrayAttr:$order
584578 );
585579
586- let results = (outs TT_TensorOrMemDesc :$result);
580+ let results = (outs TT_Tensor :$result);
587581
588582 let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
589583
590584 let hasFolder = 1;
591- let hasVerifier = 1;
592585}
593586
594587//
@@ -677,10 +670,10 @@ def TT_DotOp : TT_Op<"dot", [Pure,
677670// DotScaled Op
678671//
679672def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
680- AttrSizedOperandSegments,
681- DotLike,
682- TypesMatchWith<"result's type matches accumulator's type",
683- "d", "c", "$_self">]> {
673+ AttrSizedOperandSegments,
674+ DotLike,
675+ TypesMatchWith<"result's type matches accumulator's type",
676+ "d", "c", "$_self">]> {
684677 let summary = "dot_scaled";
685678
686679 let description = [{
@@ -783,10 +776,10 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
783776// External Elementwise op
784777//
785778def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
786- SameOperandsAndResultEncoding,
787- SameVariadicOperandSize,
788- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
789- ConditionallySpeculatable]> {
779+ SameOperandsAndResultEncoding,
780+ SameVariadicOperandSize,
781+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
782+ ConditionallySpeculatable]> {
790783
791784 let description = [{
792785 call an external function $symbol implemented in $libpath/$libname with $args
0 commit comments