@@ -6,23 +6,61 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
66include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
77
88//===----------------------------------------------------------------------===//
9- // TritonGPU Attribute Definitions
9+ // Traits and Interfaces
1010//===----------------------------------------------------------------------===//
11- def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
12- let cppNamespace = "::mlir::triton::gpu";
1311
12+ def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
13+ def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
14+
15+ def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
16+ let cppNamespace = "::mlir::triton::gpu";
17+ let description = [{
18+ Common trait for all TTGIR layouts.
19+ }];
1420 let methods = [
21+ InterfaceMethod<"Get the shape of the CTAs per CGA.",
22+ "SmallVector<unsigned>",
23+ "getCTAsPerCGA", (ins), [{}], [{
24+ return llvm::to_vector($_attr.getCTALayout().getCTAsPerCGA());
25+ }]>,
26+ InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
27+ "SmallVector<unsigned>",
28+ "getCTAOrder", (ins), [{}], [{
29+ return llvm::to_vector($_attr.getCTALayout().getCTAOrder());
30+ }]>,
31+ InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
32+ "SmallVector<unsigned>",
33+ "getCTASplitNum", (ins), [{}], [{
34+ return llvm::to_vector($_attr.getCTALayout().getCTASplitNum());
35+ }]>,
36+ InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", (ins), [{}], [{
37+ return $_attr.getCTAOrder().size();
38+ }]>
1539 ];
1640}
41+ def DeclareLayoutEncodingMethods : DeclareAttrInterfaceMethods<
42+ LayoutEncodingTrait, ["getCTAsPerCGA", "getCTAOrder", "getCTASplitNum"]>;
1743
18- def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
44+ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
45+ let cppNamespace = "::mlir::triton::gpu";
1946
20- def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
47+ let description = [{
48+ Common trait describing shared memory.
49+ }];
50+ let methods = [
51+ InterfaceMethod<"Return the default alignment for the layout.",
52+ "int32_t", "getAlignment", (ins), [{}], [{ return 16; }]>,
53+ ];
54+ }
55+ def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods<
56+ SharedEncodingTrait, ["getAlignment"]>;
57+
58+ //===----------------------------------------------------------------------===//
59+ // Base Attribute
60+ //===----------------------------------------------------------------------===//
2161
22- class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
23- Dialect dialect = TritonGPU_Dialect,
24- string baseCppClass = "::mlir::Attribute">
25- : AttrDef<dialect, name, !listconcat([TritonGPU_AttrTrait], traits), baseCppClass> {
62+ class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = []>
63+ : AttrDef<TritonGPU_Dialect, name, traits> {
2664
2765 let description = [{
2866TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines
@@ -123,51 +161,17 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
123161 CTAOrder.push_back(i);
124162 return get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
125163 }
126- unsigned getRank() const {
127- return getCTAOrder().size();
128- }
164+ unsigned getRank() const { return getCTAOrder().size(); }
129165 }];
130166
131167 let genVerifyDecl = 1;
132168 let skipDefaultBuilders = 1;
133169}
134170
135-
136- def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
137- let cppNamespace = "::mlir::triton::gpu";
138- let description = [{
139- Common trait for all TTGIR layouts.
140- }];
141- let methods = [
142- InterfaceMethod<"Get the shape of the CTAs per CGA.",
143- "SmallVector<unsigned>",
144- "getCTAsPerCGA">,
145- InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
146- "SmallVector<unsigned>",
147- "getCTAOrder">,
148- InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
149- "SmallVector<unsigned>",
150- "getCTASplitNum">,
151- ];
152- }
153-
154171//===----------------------------------------------------------------------===//
155172// Shared Layout Encoding
156173//===----------------------------------------------------------------------===//
157174
158- def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
159- let cppNamespace = "::mlir::triton::gpu";
160-
161- let description = [{
162- Common trait describing shared memory.
163- }];
164- let methods = [
165- InterfaceMethod<"Return the default alignment for the layout.",
166- "int32_t",
167- "getAlignment">,
168- ];
169- }
170-
171175def SwizzledSharedEncodingAttr
172176 : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding",
173177 [SharedEncodingTrait, LayoutEncodingTrait]> {
@@ -359,13 +363,6 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
359363 }]>,
360364 ];
361365
362- let extraClassDeclaration = extraBaseClassDeclaration # [{
363- unsigned getRank() const { return getCTAOrder().size(); }
364- int32_t getAlignment() const;
365- SmallVector<unsigned> getCTAsPerCGA() const;
366- SmallVector<unsigned> getCTAOrder() const;
367- SmallVector<unsigned> getCTASplitNum() const;
368- }];
369366 let hasCustomAssemblyFormat = 1;
370367 let genVerifyDecl = 1;
371368}
@@ -433,27 +430,19 @@ attributes too, for example,
433430 ];
434431
435432 let extraClassDeclaration = extraBaseClassDeclaration # [{
436- unsigned getRank() const { return getOrder().size(); }
437- int32_t getAlignment() const { return 16; }
438-
439433 unsigned getMinInterval() const {
440434 return *llvm::min_element(getIntervals());
441435 }
442436
443437 // Returns the total number of elements including padding given the input
444438 // tensor shape.
445439 int64_t getPaddedSize(ArrayRef<int64_t> shape) const;
446-
447- SmallVector<unsigned> getCTAsPerCGA() const;
448- SmallVector<unsigned> getCTAOrder() const;
449- SmallVector<unsigned> getCTASplitNum() const;
450440 }];
451441 let hasCustomAssemblyFormat = 1;
452442 let genVerifyDecl = 1;
453443}
454444
455- def NVMMASharedEncodingAttr :
456- TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
445+ def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [DeclareSharedEncodingMethods, LayoutEncodingTrait]> {
457446 let mnemonic = "nvmma_shared";
458447
459448 let description = [{
@@ -513,11 +502,6 @@ def NVMMASharedEncodingAttr :
513502 ];
514503
515504 let extraClassDeclaration = extraBaseClassDeclaration # [{
516- unsigned getRank() const { return getCTAOrder().size(); }
517- int32_t getAlignment() const;
518- SmallVector<unsigned> getCTAsPerCGA() const;
519- SmallVector<unsigned> getCTAOrder() const;
520- SmallVector<unsigned> getCTASplitNum() const;
521505 int getPerPhase() const;
522506 int getMaxPhase() const;
523507 int getVec() const;
@@ -619,20 +603,14 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1):
619603 "CTALayoutAttr":$CTALayout
620604 );
621605
622- let extraClassDeclaration = extraBaseClassDeclaration # [{
623- unsigned getRank() const { return getCTAOrder().size(); }
624- int32_t getAlignment() const;
625- SmallVector<unsigned> getCTAsPerCGA() const;
626- SmallVector<unsigned> getCTAOrder() const;
627- SmallVector<unsigned> getCTASplitNum() const;
628- }];
629606 let hasCustomAssemblyFormat = 1;
630607}
631608
632609
633610//===----------------------------------------------------------------------===//
634611// Distributed Layout Encoding
635612//===----------------------------------------------------------------------===//
613+
636614def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> {
637615 let cppNamespace = "::mlir::triton::gpu";
638616
@@ -681,9 +659,8 @@ We call each individual tile "rep".
681659 ];
682660}
683661
684- class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = [],
685- Dialect dialect = TritonGPU_Dialect>
686- : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits), dialect> {
662+ class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = []>
663+ : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits)> {
687664
688665 let description = [{
689666Distributed encodings have a layout function L that is entirely characterized
@@ -719,12 +696,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
719696 }];
720697
721698 code extraDistributedDeclaration = extraBaseClassDeclaration # [{
722- unsigned getRank() const { return getCTAOrder().size(); }
723699 // Implemented in subclasses
724700 SmallVector<unsigned> getRepOrder() const;
725- SmallVector<unsigned> getCTAsPerCGA() const;
726- SmallVector<unsigned> getCTAOrder() const;
727- SmallVector<unsigned> getCTASplitNum() const;
728701
729702 LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
730703 }];
@@ -739,7 +712,7 @@ def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
739712 let cppAccessorType = "const LinearLayout &";
740713}
741714
742- def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
715+ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding", [DeclareLayoutEncodingMethods] > {
743716 let mnemonic = "linear";
744717
745718 let description = [{
@@ -1376,7 +1349,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
13761349 let hasCustomAssemblyFormat = 1;
13771350}
13781351
1379- def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
1352+ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [DeclareLayoutEncodingMethods] > {
13801353 let mnemonic = "slice";
13811354
13821355 let description = [{
@@ -1419,9 +1392,10 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
14191392 }];
14201393
14211394 let hasCustomAssemblyFormat = 1;
1395+ let genVerifyDecl = 1;
14221396}
14231397
1424- def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> {
1398+ def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding", [DeclareLayoutEncodingMethods] > {
14251399 let mnemonic = "dot_op";
14261400
14271401 let description = [{
0 commit comments