@@ -6,61 +6,23 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
6
6
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
7
7
8
8
//===----------------------------------------------------------------------===//
9
- // Traits and Interfaces
9
+ // TritonGPU Attribute Definitions
10
10
//===----------------------------------------------------------------------===//
11
-
12
- def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
13
- def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
14
-
15
- def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
16
- let cppNamespace = "::mlir::triton::gpu";
17
- let description = [{
18
- Common trait for all TTGIR layouts.
19
- }];
20
- let methods = [
21
- InterfaceMethod<"Get the shape of the CTAs per CGA.",
22
- "SmallVector<unsigned>",
23
- "getCTAsPerCGA", (ins), [{}], [{
24
- return llvm::to_vector($_attr.getCTALayout().getCTAsPerCGA());
25
- }]>,
26
- InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
27
- "SmallVector<unsigned>",
28
- "getCTAOrder", (ins), [{}], [{
29
- return llvm::to_vector($_attr.getCTALayout().getCTAOrder());
30
- }]>,
31
- InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
32
- "SmallVector<unsigned>",
33
- "getCTASplitNum", (ins), [{}], [{
34
- return llvm::to_vector($_attr.getCTALayout().getCTASplitNum());
35
- }]>,
36
- InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", (ins), [{}], [{
37
- return $_attr.getCTAOrder().size();
38
- }]>
39
- ];
40
- }
41
- def DeclareLayoutEncodingMethods : DeclareAttrInterfaceMethods<
42
- LayoutEncodingTrait, ["getCTAsPerCGA", "getCTAOrder", "getCTASplitNum"]>;
43
-
44
- def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
11
+ def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
45
12
let cppNamespace = "::mlir::triton::gpu";
46
13
47
- let description = [{
48
- Common trait describing shared memory.
49
- }];
50
14
let methods = [
51
- InterfaceMethod<"Return the default alignment for the layout.",
52
- "int32_t", "getAlignment", (ins), [{}], [{ return 16; }]>,
53
15
];
54
16
}
55
- def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods<
56
- SharedEncodingTrait, ["getAlignment"]>;
57
17
58
- //===----------------------------------------------------------------------===//
59
- // Base Attribute
60
- //===----------------------------------------------------------------------===//
18
+ def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
19
+
20
+ def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
61
21
62
- class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = []>
63
- : AttrDef<TritonGPU_Dialect, name, traits> {
22
+ class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
23
+ Dialect dialect = TritonGPU_Dialect,
24
+ string baseCppClass = "::mlir::Attribute">
25
+ : AttrDef<dialect, name, !listconcat([TritonGPU_AttrTrait], traits), baseCppClass> {
64
26
65
27
let description = [{
66
28
TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines
@@ -161,17 +123,51 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
161
123
CTAOrder.push_back(i);
162
124
return get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
163
125
}
164
- unsigned getRank() const { return getCTAOrder().size(); }
126
+ unsigned getRank() const {
127
+ return getCTAOrder().size();
128
+ }
165
129
}];
166
130
167
131
let genVerifyDecl = 1;
168
132
let skipDefaultBuilders = 1;
169
133
}
170
134
135
+
136
+ def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
137
+ let cppNamespace = "::mlir::triton::gpu";
138
+ let description = [{
139
+ Common trait for all TTGIR layouts.
140
+ }];
141
+ let methods = [
142
+ InterfaceMethod<"Get the shape of the CTAs per CGA.",
143
+ "SmallVector<unsigned>",
144
+ "getCTAsPerCGA">,
145
+ InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
146
+ "SmallVector<unsigned>",
147
+ "getCTAOrder">,
148
+ InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
149
+ "SmallVector<unsigned>",
150
+ "getCTASplitNum">,
151
+ ];
152
+ }
153
+
171
154
//===----------------------------------------------------------------------===//
172
155
// Shared Layout Encoding
173
156
//===----------------------------------------------------------------------===//
174
157
158
+ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
159
+ let cppNamespace = "::mlir::triton::gpu";
160
+
161
+ let description = [{
162
+ Common trait describing shared memory.
163
+ }];
164
+ let methods = [
165
+ InterfaceMethod<"Return the default alignment for the layout.",
166
+ "int32_t",
167
+ "getAlignment">,
168
+ ];
169
+ }
170
+
175
171
def SwizzledSharedEncodingAttr
176
172
: TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding",
177
173
[SharedEncodingTrait, LayoutEncodingTrait]> {
@@ -363,6 +359,13 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
363
359
}]>,
364
360
];
365
361
362
+ let extraClassDeclaration = extraBaseClassDeclaration # [{
363
+ unsigned getRank() const { return getCTAOrder().size(); }
364
+ int32_t getAlignment() const;
365
+ SmallVector<unsigned> getCTAsPerCGA() const;
366
+ SmallVector<unsigned> getCTAOrder() const;
367
+ SmallVector<unsigned> getCTASplitNum() const;
368
+ }];
366
369
let hasCustomAssemblyFormat = 1;
367
370
let genVerifyDecl = 1;
368
371
}
@@ -430,19 +433,27 @@ attributes too, for example,
430
433
];
431
434
432
435
let extraClassDeclaration = extraBaseClassDeclaration # [{
436
+ unsigned getRank() const { return getOrder().size(); }
437
+ int32_t getAlignment() const { return 16; }
438
+
433
439
unsigned getMinInterval() const {
434
440
return *llvm::min_element(getIntervals());
435
441
}
436
442
437
443
// Returns the total number of elements including padding given the input
438
444
// tensor shape.
439
445
int64_t getPaddedSize(ArrayRef<int64_t> shape) const;
446
+
447
+ SmallVector<unsigned> getCTAsPerCGA() const;
448
+ SmallVector<unsigned> getCTAOrder() const;
449
+ SmallVector<unsigned> getCTASplitNum() const;
440
450
}];
441
451
let hasCustomAssemblyFormat = 1;
442
452
let genVerifyDecl = 1;
443
453
}
444
454
445
- def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [DeclareSharedEncodingMethods, LayoutEncodingTrait]> {
455
+ def NVMMASharedEncodingAttr :
456
+ TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
446
457
let mnemonic = "nvmma_shared";
447
458
448
459
let description = [{
@@ -502,6 +513,11 @@ def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_share
502
513
];
503
514
504
515
let extraClassDeclaration = extraBaseClassDeclaration # [{
516
+ unsigned getRank() const { return getCTAOrder().size(); }
517
+ int32_t getAlignment() const;
518
+ SmallVector<unsigned> getCTAsPerCGA() const;
519
+ SmallVector<unsigned> getCTAOrder() const;
520
+ SmallVector<unsigned> getCTASplitNum() const;
505
521
int getPerPhase() const;
506
522
int getMaxPhase() const;
507
523
int getVec() const;
@@ -603,14 +619,20 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1):
603
619
"CTALayoutAttr":$CTALayout
604
620
);
605
621
622
+ let extraClassDeclaration = extraBaseClassDeclaration # [{
623
+ unsigned getRank() const { return getCTAOrder().size(); }
624
+ int32_t getAlignment() const;
625
+ SmallVector<unsigned> getCTAsPerCGA() const;
626
+ SmallVector<unsigned> getCTAOrder() const;
627
+ SmallVector<unsigned> getCTASplitNum() const;
628
+ }];
606
629
let hasCustomAssemblyFormat = 1;
607
630
}
608
631
609
632
610
633
//===----------------------------------------------------------------------===//
611
634
// Distributed Layout Encoding
612
635
//===----------------------------------------------------------------------===//
613
-
614
636
def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> {
615
637
let cppNamespace = "::mlir::triton::gpu";
616
638
@@ -659,8 +681,9 @@ We call each individual tile "rep".
659
681
];
660
682
}
661
683
662
- class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = []>
663
- : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits)> {
684
+ class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = [],
685
+ Dialect dialect = TritonGPU_Dialect>
686
+ : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits), dialect> {
664
687
665
688
let description = [{
666
689
Distributed encodings have a layout function L that is entirely characterized
@@ -696,8 +719,12 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
696
719
}];
697
720
698
721
code extraDistributedDeclaration = extraBaseClassDeclaration # [{
722
+ unsigned getRank() const { return getCTAOrder().size(); }
699
723
// Implemented in subclasses
700
724
SmallVector<unsigned> getRepOrder() const;
725
+ SmallVector<unsigned> getCTAsPerCGA() const;
726
+ SmallVector<unsigned> getCTAOrder() const;
727
+ SmallVector<unsigned> getCTASplitNum() const;
701
728
702
729
LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
703
730
}];
@@ -712,7 +739,7 @@ def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
712
739
let cppAccessorType = "const LinearLayout &";
713
740
}
714
741
715
- def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding", [DeclareLayoutEncodingMethods] > {
742
+ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
716
743
let mnemonic = "linear";
717
744
718
745
let description = [{
@@ -1349,7 +1376,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
1349
1376
let hasCustomAssemblyFormat = 1;
1350
1377
}
1351
1378
1352
- def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [DeclareLayoutEncodingMethods] > {
1379
+ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
1353
1380
let mnemonic = "slice";
1354
1381
1355
1382
let description = [{
@@ -1392,10 +1419,9 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [
1392
1419
}];
1393
1420
1394
1421
let hasCustomAssemblyFormat = 1;
1395
- let genVerifyDecl = 1;
1396
1422
}
1397
1423
1398
- def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding", [DeclareLayoutEncodingMethods] > {
1424
+ def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> {
1399
1425
let mnemonic = "dot_op";
1400
1426
1401
1427
let description = [{
0 commit comments