@@ -6,23 +6,61 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
66include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
77
88//===----------------------------------------------------------------------===//
9- // TritonGPU Attribute Definitions
9+ // Traits and Interfaces
1010//===----------------------------------------------------------------------===//
11- def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
12- let cppNamespace = "::mlir::triton::gpu";
1311
12+ def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
13+ def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
14+
15+ def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
16+ let cppNamespace = "::mlir::triton::gpu";
17+ let description = [{
18+ Common trait for all TTGIR layouts.
19+ }];
1420 let methods = [
21+ InterfaceMethod<"Get the shape of the CTAs per CGA.",
22+ "SmallVector<unsigned>",
23+ "getCTAsPerCGA", (ins), [{}], [{
24+ return llvm::to_vector($_attr.getCTALayout().getCTAsPerCGA());
25+ }]>,
26+ InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
27+ "SmallVector<unsigned>",
28+ "getCTAOrder", (ins), [{}], [{
29+ return llvm::to_vector($_attr.getCTALayout().getCTAOrder());
30+ }]>,
31+ InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
32+ "SmallVector<unsigned>",
33+ "getCTASplitNum", (ins), [{}], [{
34+ return llvm::to_vector($_attr.getCTALayout().getCTASplitNum());
35+ }]>,
36+ InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", (ins), [{}], [{
37+ return $_attr.getCTAOrder().size();
38+ }]>
1539 ];
1640}
41+ def DeclareLayoutEncodingMethods : DeclareAttrInterfaceMethods<
42+ LayoutEncodingTrait, ["getCTAsPerCGA", "getCTAOrder", "getCTASplitNum"]>;
1743
18- def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
44+ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
45+ let cppNamespace = "::mlir::triton::gpu";
1946
20- def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
47+ let description = [{
48+ Common trait describing shared memory.
49+ }];
50+ let methods = [
51+ InterfaceMethod<"Return the default alignment for the layout.",
52+ "int32_t", "getAlignment", (ins), [{}], [{ return 16; }]>,
53+ ];
54+ }
55+ def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods<
56+ SharedEncodingTrait, ["getAlignment"]>;
57+
58+ //===----------------------------------------------------------------------===//
59+ // Base Attribute
60+ //===----------------------------------------------------------------------===//
2161
22- class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
23- Dialect dialect = TritonGPU_Dialect,
24- string baseCppClass = "::mlir::Attribute">
25- : AttrDef<dialect, name, !listconcat([TritonGPU_AttrTrait], traits), baseCppClass> {
62+ class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [], Dialect dialect = TritonGPU_Dialect>
63+ : AttrDef<dialect, name, traits> {
2664
2765 let description = [{
2866TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines
@@ -123,51 +161,17 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
123161 CTAOrder.push_back(i);
124162 return get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
125163 }
126- unsigned getRank() const {
127- return getCTAOrder().size();
128- }
164+ unsigned getRank() const { return getCTAOrder().size(); }
129165 }];
130166
131167 let genVerifyDecl = 1;
132168 let skipDefaultBuilders = 1;
133169}
134170
135-
136- def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
137- let cppNamespace = "::mlir::triton::gpu";
138- let description = [{
139- Common trait for all TTGIR layouts.
140- }];
141- let methods = [
142- InterfaceMethod<"Get the shape of the CTAs per CGA.",
143- "SmallVector<unsigned>",
144- "getCTAsPerCGA">,
145- InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
146- "SmallVector<unsigned>",
147- "getCTAOrder">,
148- InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
149- "SmallVector<unsigned>",
150- "getCTASplitNum">,
151- ];
152- }
153-
154171//===----------------------------------------------------------------------===//
155172// Shared Layout Encoding
156173//===----------------------------------------------------------------------===//
157174
158- def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
159- let cppNamespace = "::mlir::triton::gpu";
160-
161- let description = [{
162- Common trait describing shared memory.
163- }];
164- let methods = [
165- InterfaceMethod<"Return the default alignment for the layout.",
166- "int32_t",
167- "getAlignment">,
168- ];
169- }
170-
171175def SwizzledSharedEncodingAttr
172176 : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding",
173177 [SharedEncodingTrait, LayoutEncodingTrait]> {
@@ -359,13 +363,6 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
359363 }]>,
360364 ];
361365
362- let extraClassDeclaration = extraBaseClassDeclaration # [{
363- unsigned getRank() const { return getCTAOrder().size(); }
364- int32_t getAlignment() const;
365- SmallVector<unsigned> getCTAsPerCGA() const;
366- SmallVector<unsigned> getCTAOrder() const;
367- SmallVector<unsigned> getCTASplitNum() const;
368- }];
369366 let hasCustomAssemblyFormat = 1;
370367 let genVerifyDecl = 1;
371368}
@@ -433,27 +430,19 @@ attributes too, for example,
433430 ];
434431
435432 let extraClassDeclaration = extraBaseClassDeclaration # [{
436- unsigned getRank() const { return getOrder().size(); }
437- int32_t getAlignment() const { return 16; }
438-
439433 unsigned getMinInterval() const {
440434 return *llvm::min_element(getIntervals());
441435 }
442436
443437 // Returns the total number of elements including padding given the input
444438 // tensor shape.
445439 int64_t getPaddedSize(ArrayRef<int64_t> shape) const;
446-
447- SmallVector<unsigned> getCTAsPerCGA() const;
448- SmallVector<unsigned> getCTAOrder() const;
449- SmallVector<unsigned> getCTASplitNum() const;
450440 }];
451441 let hasCustomAssemblyFormat = 1;
452442 let genVerifyDecl = 1;
453443}
454444
455- def NVMMASharedEncodingAttr :
456- TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
445+ def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [DeclareSharedEncodingMethods, LayoutEncodingTrait]> {
457446 let mnemonic = "nvmma_shared";
458447
459448 let description = [{
@@ -513,11 +502,6 @@ def NVMMASharedEncodingAttr :
513502 ];
514503
515504 let extraClassDeclaration = extraBaseClassDeclaration # [{
516- unsigned getRank() const { return getCTAOrder().size(); }
517- int32_t getAlignment() const;
518- SmallVector<unsigned> getCTAsPerCGA() const;
519- SmallVector<unsigned> getCTAOrder() const;
520- SmallVector<unsigned> getCTASplitNum() const;
521505 int getPerPhase() const;
522506 int getMaxPhase() const;
523507 int getVec() const;
@@ -619,20 +603,14 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1):
619603 "CTALayoutAttr":$CTALayout
620604 );
621605
622- let extraClassDeclaration = extraBaseClassDeclaration # [{
623- unsigned getRank() const { return getCTAOrder().size(); }
624- int32_t getAlignment() const;
625- SmallVector<unsigned> getCTAsPerCGA() const;
626- SmallVector<unsigned> getCTAOrder() const;
627- SmallVector<unsigned> getCTASplitNum() const;
628- }];
629606 let hasCustomAssemblyFormat = 1;
630607}
631608
632609
633610//===----------------------------------------------------------------------===//
634611// Distributed Layout Encoding
635612//===----------------------------------------------------------------------===//
613+
636614def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> {
637615 let cppNamespace = "::mlir::triton::gpu";
638616
@@ -719,12 +697,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
719697 }];
720698
721699 code extraDistributedDeclaration = extraBaseClassDeclaration # [{
722- unsigned getRank() const { return getCTAOrder().size(); }
723700 // Implemented in subclasses
724701 SmallVector<unsigned> getRepOrder() const;
725- SmallVector<unsigned> getCTAsPerCGA() const;
726- SmallVector<unsigned> getCTAOrder() const;
727- SmallVector<unsigned> getCTASplitNum() const;
728702
729703 LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
730704 }];
@@ -739,7 +713,7 @@ def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
739713 let cppAccessorType = "const LinearLayout &";
740714}
741715
742- def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
716+ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding", [DeclareLayoutEncodingMethods] > {
743717 let mnemonic = "linear";
744718
745719 let description = [{
@@ -1376,7 +1350,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
13761350 let hasCustomAssemblyFormat = 1;
13771351}
13781352
1379- def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
1353+ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [DeclareLayoutEncodingMethods] > {
13801354 let mnemonic = "slice";
13811355
13821356 let description = [{
@@ -1419,9 +1393,10 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
14191393 }];
14201394
14211395 let hasCustomAssemblyFormat = 1;
1396+ let genVerifyDecl = 1;
14221397}
14231398
1424- def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> {
1399+ def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding", [DeclareLayoutEncodingMethods] > {
14251400 let mnemonic = "dot_op";
14261401
14271402 let description = [{