
Commit 473e293

chengjunlu authored and whitneywhtsang committed
Revert "Revert "[Dialect] Layout attr cleanup and tighten invariants (#7714)""
This reverts commit afe1e9d.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent b0acd14 commit 473e293

File tree: 19 files changed, +512 −495 lines changed


include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 56 additions & 82 deletions
@@ -6,23 +6,61 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
 include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
 
 //===----------------------------------------------------------------------===//
-// TritonGPU Attribute Definitions
+// Traits and Interfaces
 //===----------------------------------------------------------------------===//
-def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
-  let cppNamespace = "::mlir::triton::gpu";
 
+def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
+def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
+
+def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
+  let cppNamespace = "::mlir::triton::gpu";
+  let description = [{
+    Common trait for all TTGIR layouts.
+  }];
   let methods = [
+    InterfaceMethod<"Get the shape of the CTAs per CGA.",
+      "SmallVector<unsigned>",
+      "getCTAsPerCGA", (ins), [{}], [{
+        return llvm::to_vector($_attr.getCTALayout().getCTAsPerCGA());
+      }]>,
+    InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
+      "SmallVector<unsigned>",
+      "getCTAOrder", (ins), [{}], [{
+        return llvm::to_vector($_attr.getCTALayout().getCTAOrder());
+      }]>,
+    InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
+      "SmallVector<unsigned>",
+      "getCTASplitNum", (ins), [{}], [{
+        return llvm::to_vector($_attr.getCTALayout().getCTASplitNum());
+      }]>,
+    InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", (ins), [{}], [{
+      return $_attr.getCTAOrder().size();
+    }]>
   ];
 }
+def DeclareLayoutEncodingMethods : DeclareAttrInterfaceMethods<
+  LayoutEncodingTrait, ["getCTAsPerCGA", "getCTAOrder", "getCTASplitNum"]>;
 
-def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
+def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
+  let cppNamespace = "::mlir::triton::gpu";
 
-def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
+  let description = [{
+    Common trait describing shared memory.
+  }];
+  let methods = [
+    InterfaceMethod<"Return the default alignment for the layout.",
+      "int32_t", "getAlignment", (ins), [{}], [{ return 16; }]>,
+  ];
+}
+def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods<
+  SharedEncodingTrait, ["getAlignment"]>;
+
+//===----------------------------------------------------------------------===//
+// Base Attribute
+//===----------------------------------------------------------------------===//
 
-class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
-                     Dialect dialect = TritonGPU_Dialect,
-                     string baseCppClass = "::mlir::Attribute">
-    : AttrDef<dialect, name, !listconcat([TritonGPU_AttrTrait], traits), baseCppClass> {
+class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = []>
+    : AttrDef<TritonGPU_Dialect, name, traits> {
 
   let description = [{
     TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines
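
Note: the reworked `LayoutEncodingTrait` now ships default implementations that derive every query from the attribute's `CTALayout`, which is why the hunks below can delete the per-encoding redeclarations of `getCTAsPerCGA`/`getCTAOrder`/`getCTASplitNum`. A minimal C++ sketch of how client code can rely on the interface; `describeLayout` is a hypothetical helper, not part of this commit:

  #include "llvm/Support/raw_ostream.h"
  #include "triton/Dialect/TritonGPU/IR/Dialect.h"

  using namespace mlir;
  using namespace mlir::triton::gpu;

  // Hypothetical helper: query any TTGIR layout uniformly through the
  // interface instead of dyn_cast'ing to each concrete encoding class.
  static void describeLayout(Attribute encoding) {
    if (auto layout = dyn_cast<LayoutEncodingTrait>(encoding)) {
      SmallVector<unsigned> ctasPerCGA = layout.getCTAsPerCGA();
      SmallVector<unsigned> ctaOrder = layout.getCTAOrder(); // fastest axis first
      unsigned rank = layout.getRank(); // default body: getCTAOrder().size()
      llvm::errs() << "rank=" << rank
                   << ", CTA dims=" << ctasPerCGA.size() << "\n";
      (void)ctaOrder;
    }
  }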
@@ -123,51 +161,17 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
        CTAOrder.push_back(i);
      return get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
    }
-    unsigned getRank() const {
-      return getCTAOrder().size();
-    }
+    unsigned getRank() const { return getCTAOrder().size(); }
  }];
 
  let genVerifyDecl = 1;
  let skipDefaultBuilders = 1;
 }
 
-
-def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
-  let cppNamespace = "::mlir::triton::gpu";
-  let description = [{
-    Common trait for all TTGIR layouts.
-  }];
-  let methods = [
-    InterfaceMethod<"Get the shape of the CTAs per CGA.",
-      "SmallVector<unsigned>",
-      "getCTAsPerCGA">,
-    InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
-      "SmallVector<unsigned>",
-      "getCTAOrder">,
-    InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
-      "SmallVector<unsigned>",
-      "getCTASplitNum">,
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // Shared Layout Encoding
 //===----------------------------------------------------------------------===//
 
-def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
-  let cppNamespace = "::mlir::triton::gpu";
-
-  let description = [{
-    Common trait describing shared memory.
-  }];
-  let methods = [
-    InterfaceMethod<"Return the default alignment for the layout.",
-      "int32_t",
-      "getAlignment">,
-  ];
-}
-
 def SwizzledSharedEncodingAttr
     : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding",
                      [SharedEncodingTrait, LayoutEncodingTrait]> {
@@ -359,13 +363,6 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
     }]>,
   ];
 
-  let extraClassDeclaration = extraBaseClassDeclaration # [{
-    unsigned getRank() const { return getCTAOrder().size(); }
-    int32_t getAlignment() const;
-    SmallVector<unsigned> getCTAsPerCGA() const;
-    SmallVector<unsigned> getCTAOrder() const;
-    SmallVector<unsigned> getCTASplitNum() const;
-  }];
   let hasCustomAssemblyFormat = 1;
   let genVerifyDecl = 1;
 }
@@ -433,27 +430,19 @@ attributes too, for example,
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
-    unsigned getRank() const { return getOrder().size(); }
-    int32_t getAlignment() const { return 16; }
-
     unsigned getMinInterval() const {
       return *llvm::min_element(getIntervals());
     }
 
     // Returns the total number of elements including padding given the input
     // tensor shape.
     int64_t getPaddedSize(ArrayRef<int64_t> shape) const;
-
-    SmallVector<unsigned> getCTAsPerCGA() const;
-    SmallVector<unsigned> getCTAOrder() const;
-    SmallVector<unsigned> getCTASplitNum() const;
   }];
   let hasCustomAssemblyFormat = 1;
   let genVerifyDecl = 1;
 }
 
-def NVMMASharedEncodingAttr :
-    TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
+def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [DeclareSharedEncodingMethods, LayoutEncodingTrait]> {
   let mnemonic = "nvmma_shared";
 
   let description = [{
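
Note: `NVMMASharedEncodingAttr` opts into `DeclareSharedEncodingMethods` rather than the bare trait because it overrides the interface's default `getAlignment()` of 16. A hedged sketch of what that out-of-line override looks like; the value shown is an assumption, not taken from this diff:

  // In the dialect's .cpp file. DeclareSharedEncodingMethods makes TableGen
  // re-declare getAlignment() on the attribute, shadowing the interface
  // default body.
  int32_t NVMMASharedEncodingAttr::getAlignment() const {
    return 128; // illustrative value only; the real policy lives elsewhere
  }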
@@ -513,11 +502,6 @@ def NVMMASharedEncodingAttr :
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
-    unsigned getRank() const { return getCTAOrder().size(); }
-    int32_t getAlignment() const;
-    SmallVector<unsigned> getCTAsPerCGA() const;
-    SmallVector<unsigned> getCTAOrder() const;
-    SmallVector<unsigned> getCTASplitNum() const;
     int getPerPhase() const;
     int getMaxPhase() const;
     int getVec() const;
@@ -619,20 +603,14 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1):
     "CTALayoutAttr":$CTALayout
   );
 
-  let extraClassDeclaration = extraBaseClassDeclaration # [{
-    unsigned getRank() const { return getCTAOrder().size(); }
-    int32_t getAlignment() const;
-    SmallVector<unsigned> getCTAsPerCGA() const;
-    SmallVector<unsigned> getCTAOrder() const;
-    SmallVector<unsigned> getCTASplitNum() const;
-  }];
   let hasCustomAssemblyFormat = 1;
 }
 
 
 //===----------------------------------------------------------------------===//
 // Distributed Layout Encoding
 //===----------------------------------------------------------------------===//
+
 def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> {
   let cppNamespace = "::mlir::triton::gpu";
 
@@ -681,9 +659,8 @@ We call each individual tile "rep".
   ];
 }
 
-class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = [],
-                          Dialect dialect = TritonGPU_Dialect>
-    : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits), dialect> {
+class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = []>
+    : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits)> {
 
   let description = [{
     Distributed encodings have a layout function L that is entirely characterized
@@ -719,12 +696,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
   }];
 
   code extraDistributedDeclaration = extraBaseClassDeclaration # [{
-    unsigned getRank() const { return getCTAOrder().size(); }
     // Implemented in subclasses
     SmallVector<unsigned> getRepOrder() const;
-    SmallVector<unsigned> getCTAsPerCGA() const;
-    SmallVector<unsigned> getCTAOrder() const;
-    SmallVector<unsigned> getCTASplitNum() const;
 
     LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
   }];
@@ -739,7 +712,7 @@ def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
   let cppAccessorType = "const LinearLayout &";
 }
 
-def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
+def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding", [DeclareLayoutEncodingMethods]> {
   let mnemonic = "linear";
 
   let description = [{
@@ -1376,7 +1349,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
   let hasCustomAssemblyFormat = 1;
 }
 
-def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
+def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [DeclareLayoutEncodingMethods]> {
   let mnemonic = "slice";
 
   let description = [{
@@ -1419,9 +1392,10 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
   }];
 
   let hasCustomAssemblyFormat = 1;
+  let genVerifyDecl = 1;
 }
 
-def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> {
+def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding", [DeclareLayoutEncodingMethods]> {
   let mnemonic = "dot_op";
 
   let description = [{
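
Note: `SliceEncodingAttr` additionally sets `genVerifyDecl = 1`, which is the "tighten invariants" half of this change: TableGen now declares a verify hook that the dialect implements by hand, so malformed layouts are rejected at attribute construction time via `getChecked`. A sketch of the shape of that hook; the parameter list mirrors the attribute's parameters and is illustrative here:

  // Declared on SliceEncodingAttr by genVerifyDecl = 1; implemented in the
  // dialect library and invoked before the attribute is uniqued.
  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                              unsigned dim, DistributedEncodingTrait parent);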

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 11 additions & 3 deletions
@@ -27,11 +27,13 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 
 // TritonNvidiaGPU depends on Triton
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc"
 
 namespace mlir::triton::nvidia_gpu::impl {
@@ -61,13 +63,19 @@ struct TMemAllocation {
 
 TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType);
 
-Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
-                                  RankedTensorType oltType, unsigned numWarps);
+gpu::DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N,
+                                                      RankedTensorType oltType,
+                                                      unsigned numWarps);
+gpu::DistributedEncodingTrait
+getTmemLoadLayoutSplitLongM(RankedTensorType tensorType,
+                            gpu::MemDescType memType, int numWarps);
+SmallVector<gpu::DistributedEncodingTrait>
+getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType,
+                         gpu::MemDescType memType);
 
 bool isDistributedLayoutTMemCompatible(Operation *op,
                                        RankedTensorType tensorType,
                                        gpu::MemDescType memType);
-
 bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
                                             gpu::MemDescType memType,
                                             int numWarps);
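
Note: the header now commits to `gpu::DistributedEncodingTrait` instead of a bare `Attribute` for TMEM-compatible layouts, so callers receive a statically typed encoding. A minimal usage sketch, assuming these declarations sit in the `mlir::triton::nvidia_gpu` namespace as the surrounding header suggests; `withTmemLayout` is a hypothetical helper:

  #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

  using namespace mlir;
  namespace ttng = mlir::triton::nvidia_gpu;
  namespace ttg = mlir::triton::gpu;

  // Rebuild a tensor type with a TMEM-compatible encoding. With the tightened
  // signature, no dyn_cast from a generic Attribute is needed: the result is
  // already known to be a distributed encoding.
  static RankedTensorType withTmemLayout(RankedTensorType oldType, unsigned M,
                                         unsigned N, unsigned numWarps) {
    ttg::DistributedEncodingTrait enc =
        ttng::getTmemCompatibleLayout(M, N, oldType, numWarps);
    return RankedTensorType::get(oldType.getShape(), oldType.getElementType(),
                                 enc);
  }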
