@@ -6,23 +6,61 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"

//===----------------------------------------------------------------------===//
- // TritonGPU Attribute Definitions
+ // Traits and Interfaces
//===----------------------------------------------------------------------===//
- def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
-   let cppNamespace = "::mlir::triton::gpu";
+ def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
+ def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
+
+ def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
+   let cppNamespace = "::mlir::triton::gpu";
+   let description = [{
+     Common trait for all TTGIR layouts.
+   }];
  let methods = [
+     InterfaceMethod<"Get the shape of the CTAs per CGA.",
+       "SmallVector<unsigned>",
+       "getCTAsPerCGA", (ins), [{}], [{
+         return llvm::to_vector($_attr.getCTALayout().getCTAsPerCGA());
+       }]>,
+     InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
+       "SmallVector<unsigned>",
+       "getCTAOrder", (ins), [{}], [{
+         return llvm::to_vector($_attr.getCTALayout().getCTAOrder());
+       }]>,
+     InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
+       "SmallVector<unsigned>",
+       "getCTASplitNum", (ins), [{}], [{
+         return llvm::to_vector($_attr.getCTALayout().getCTASplitNum());
+       }]>,
+     InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", (ins), [{}], [{
+         return $_attr.getCTAOrder().size();
+       }]>
  ];
}
+ def DeclareLayoutEncodingMethods : DeclareAttrInterfaceMethods<
+   LayoutEncodingTrait, ["getCTAsPerCGA", "getCTAOrder", "getCTASplitNum"]>;
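For orientation, an illustrative sketch (not part of this change; the attribute names are hypothetical): an attribute that lists LayoutEncodingTrait and carries a CTALayout parameter picks up the default bodies above, which simply forward to $_attr.getCTALayout(), while wrapping the interface in DeclareLayoutEncodingMethods makes ODS declare the three CTA queries on the generated class so they can be defined by hand in C++.

  // Hypothetical encoding that relies on the default implementations; the
  // defaults resolve the CTA queries through the attribute's getCTALayout().
  def MyEncodingAttr : TritonGPU_Attr<"MyEncoding", "my_encoding",
                                      [LayoutEncodingTrait]> {
    let parameters = (ins "CTALayoutAttr":$CTALayout);
  }

  // Hypothetical encoding that overrides the CTA queries instead: the
  // declarations are generated here, the definitions live in a .cpp file.
  def MyCustomEncodingAttr : TritonGPU_Attr<"MyCustomEncoding", "my_custom_encoding",
                                            [DeclareLayoutEncodingMethods]> {
    let parameters = (ins "CTALayoutAttr":$CTALayout);
  }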
- def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
+ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
+   let cppNamespace = "::mlir::triton::gpu";
- def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
+   let description = [{
+     Common trait describing shared memory.
+   }];
+   let methods = [
+     InterfaceMethod<"Return the default alignment for the layout.",
+       "int32_t", "getAlignment", (ins), [{}], [{ return 16; }]>,
+   ];
+ }
+ def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods<
+   SharedEncodingTrait, ["getAlignment"]>;
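Again illustrative rather than part of the diff: a shared-memory encoding that is happy with the defaulted 16-byte alignment can attach SharedEncodingTrait directly, whereas an encoding that needs its own value lists DeclareSharedEncodingMethods (as NVMMASharedEncodingAttr does further down) and defines getAlignment() in C++. The name below is hypothetical:

  // Hypothetical shared encoding that keeps the default getAlignment() == 16.
  def MySharedEncodingAttr : TritonGPU_Attr<"MySharedEncoding", "my_shared_encoding",
                                            [SharedEncodingTrait, LayoutEncodingTrait]> {
    let parameters = (ins "CTALayoutAttr":$CTALayout);
  }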
+
+ //===----------------------------------------------------------------------===//
+ // Base Attribute
+ //===----------------------------------------------------------------------===//
- class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
-                      Dialect dialect = TritonGPU_Dialect,
-                      string baseCppClass = "::mlir::Attribute">
-   : AttrDef<dialect, name, !listconcat([TritonGPU_AttrTrait], traits), baseCppClass> {
+ class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = []>
+   : AttrDef<TritonGPU_Dialect, name, traits> {

  let description = [{
    TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines
@@ -123,51 +161,17 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
        CTAOrder.push_back(i);
      return get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
    }
-     unsigned getRank() const {
-       return getCTAOrder().size();
-     }
+     unsigned getRank() const { return getCTAOrder().size(); }
  }];

  let genVerifyDecl = 1;
  let skipDefaultBuilders = 1;
}

-
- def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
-   let cppNamespace = "::mlir::triton::gpu";
-   let description = [{
-     Common trait for all TTGIR layouts.
-   }];
-   let methods = [
-     InterfaceMethod<"Get the shape of the CTAs per CGA.",
-       "SmallVector<unsigned>",
-       "getCTAsPerCGA">,
-     InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
-       "SmallVector<unsigned>",
-       "getCTAOrder">,
-     InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
-       "SmallVector<unsigned>",
-       "getCTASplitNum">,
-   ];
- }
-
//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//

- def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
-   let cppNamespace = "::mlir::triton::gpu";
-
-   let description = [{
-     Common trait describing shared memory.
-   }];
-   let methods = [
-     InterfaceMethod<"Return the default alignment for the layout.",
-       "int32_t",
-       "getAlignment">,
-   ];
- }
-
def SwizzledSharedEncodingAttr
    : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding",
                     [SharedEncodingTrait, LayoutEncodingTrait]> {
@@ -359,13 +363,6 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
    }]>,
  ];

-   let extraClassDeclaration = extraBaseClassDeclaration # [{
-     unsigned getRank() const { return getCTAOrder().size(); }
-     int32_t getAlignment() const;
-     SmallVector<unsigned> getCTAsPerCGA() const;
-     SmallVector<unsigned> getCTAOrder() const;
-     SmallVector<unsigned> getCTASplitNum() const;
-   }];
  let hasCustomAssemblyFormat = 1;
  let genVerifyDecl = 1;
}
@@ -433,27 +430,19 @@ attributes too, for example,
  ];

  let extraClassDeclaration = extraBaseClassDeclaration # [{
-     unsigned getRank() const { return getOrder().size(); }
-     int32_t getAlignment() const { return 16; }
-
    unsigned getMinInterval() const {
      return *llvm::min_element(getIntervals());
    }

    // Returns the total number of elements including padding given the input
    // tensor shape.
    int64_t getPaddedSize(ArrayRef<int64_t> shape) const;
-
-     SmallVector<unsigned> getCTAsPerCGA() const;
-     SmallVector<unsigned> getCTAOrder() const;
-     SmallVector<unsigned> getCTASplitNum() const;
  }];
  let hasCustomAssemblyFormat = 1;
  let genVerifyDecl = 1;
}

- def NVMMASharedEncodingAttr :
-     TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
+ def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [DeclareSharedEncodingMethods, LayoutEncodingTrait]> {
  let mnemonic = "nvmma_shared";

  let description = [{
@@ -513,11 +502,6 @@ def NVMMASharedEncodingAttr :
  ];

  let extraClassDeclaration = extraBaseClassDeclaration # [{
-     unsigned getRank() const { return getCTAOrder().size(); }
-     int32_t getAlignment() const;
-     SmallVector<unsigned> getCTAsPerCGA() const;
-     SmallVector<unsigned> getCTAOrder() const;
-     SmallVector<unsigned> getCTASplitNum() const;
    int getPerPhase() const;
    int getMaxPhase() const;
    int getVec() const;
@@ -619,20 +603,14 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1):
    "CTALayoutAttr":$CTALayout
  );

-   let extraClassDeclaration = extraBaseClassDeclaration # [{
-     unsigned getRank() const { return getCTAOrder().size(); }
-     int32_t getAlignment() const;
-     SmallVector<unsigned> getCTAsPerCGA() const;
-     SmallVector<unsigned> getCTAOrder() const;
-     SmallVector<unsigned> getCTASplitNum() const;
-   }];
  let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Distributed Layout Encoding
//===----------------------------------------------------------------------===//
+
def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> {
  let cppNamespace = "::mlir::triton::gpu";
@@ -681,9 +659,8 @@ We call each individual tile "rep".
  ];
}

- class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = [],
-                           Dialect dialect = TritonGPU_Dialect>
-   : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits), dialect> {
+ class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = []>
+   : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits)> {

  let description = [{
    Distributed encodings have a layout function L that is entirely characterized
@@ -719,12 +696,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
  }];

  code extraDistributedDeclaration = extraBaseClassDeclaration # [{
-     unsigned getRank() const { return getCTAOrder().size(); }
    // Implemented in subclasses
    SmallVector<unsigned> getRepOrder() const;
-     SmallVector<unsigned> getCTAsPerCGA() const;
-     SmallVector<unsigned> getCTAOrder() const;
-     SmallVector<unsigned> getCTASplitNum() const;

    LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
  }];
@@ -739,7 +712,7 @@ def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
  let cppAccessorType = "const LinearLayout &";
}
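As an aside, a hedged sketch (not taken from this change; the names are hypothetical): an AttrOrTypeParameter such as LinearLayoutParam is consumed in an attribute's parameter list, and cppAccessorType makes the generated getter return a const reference rather than a copy.

  // Hypothetical attribute storing a LinearLayout; thanks to cppAccessorType the
  // generated accessor has the shape `const LinearLayout &getLinear()`.
  def MyLinearBackedAttr : TritonGPU_Attr<"MyLinearBacked", "my_linear_backed"> {
    let parameters = (ins LinearLayoutParam:$linear);
  }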
- def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
+ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding", [DeclareLayoutEncodingMethods] > {
  let mnemonic = "linear";

  let description = [{
@@ -1376,7 +1349,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
  let hasCustomAssemblyFormat = 1;
}

- def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
+ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [DeclareLayoutEncodingMethods] > {
  let mnemonic = "slice";

  let description = [{
@@ -1419,9 +1392,10 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
  }];

  let hasCustomAssemblyFormat = 1;
+   let genVerifyDecl = 1;
}

- def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> {
+ def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding", [DeclareLayoutEncodingMethods] > {
  let mnemonic = "dot_op";

  let description = [{