Commit 37b03ae

Reland "[Dialect] Layout attr cleanup and tighten invariants (#7714)" (#4919)
2 parents: b0acd14 + c9466d4

21 files changed: 512 additions, 506 deletions

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 54 additions & 79 deletions
@@ -6,23 +6,61 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
 include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
 
 //===----------------------------------------------------------------------===//
-// TritonGPU Attribute Definitions
+// Traits and Interfaces
 //===----------------------------------------------------------------------===//
-def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
-  let cppNamespace = "::mlir::triton::gpu";
 
+def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
+def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
+
+def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
+  let cppNamespace = "::mlir::triton::gpu";
+  let description = [{
+    Common trait for all TTGIR layouts.
+  }];
   let methods = [
+    InterfaceMethod<"Get the shape of the CTAs per CGA.",
+      "SmallVector<unsigned>",
+      "getCTAsPerCGA", (ins), [{}], [{
+        return llvm::to_vector($_attr.getCTALayout().getCTAsPerCGA());
+      }]>,
+    InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
+      "SmallVector<unsigned>",
+      "getCTAOrder", (ins), [{}], [{
+        return llvm::to_vector($_attr.getCTALayout().getCTAOrder());
+      }]>,
+    InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
+      "SmallVector<unsigned>",
+      "getCTASplitNum", (ins), [{}], [{
+        return llvm::to_vector($_attr.getCTALayout().getCTASplitNum());
+      }]>,
+    InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", (ins), [{}], [{
+      return $_attr.getCTAOrder().size();
+    }]>
   ];
 }
+def DeclareLayoutEncodingMethods : DeclareAttrInterfaceMethods<
+  LayoutEncodingTrait, ["getCTAsPerCGA", "getCTAOrder", "getCTASplitNum"]>;
 
-def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
+def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
+  let cppNamespace = "::mlir::triton::gpu";
 
-def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
+  let description = [{
+    Common trait describing shared memory.
+  }];
+  let methods = [
+    InterfaceMethod<"Return the default alignment for the layout.",
+      "int32_t", "getAlignment", (ins), [{}], [{ return 16; }]>,
+  ];
+}
+def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods<
+  SharedEncodingTrait, ["getAlignment"]>;
+
+//===----------------------------------------------------------------------===//
+// Base Attribute
+//===----------------------------------------------------------------------===//
 
-class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
-                     Dialect dialect = TritonGPU_Dialect,
-                     string baseCppClass = "::mlir::Attribute">
-  : AttrDef<dialect, name, !listconcat([TritonGPU_AttrTrait], traits), baseCppClass> {
+class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [], Dialect dialect = TritonGPU_Dialect>
+  : AttrDef<dialect, name, traits> {
 
   let description = [{
   TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines
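
The interface methods in the hunk above now carry default bodies that forward to the attribute's getCTALayout(), so an encoding only overrides them (by opting in through DeclareLayoutEncodingMethods) when it has no stored CTALayoutAttr. A minimal C++ sketch of a call site that programs against the interface rather than a concrete encoding; the helper numCTAsInCGA is hypothetical, not part of this commit:

#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir::triton::gpu;

// Hypothetical helper: counts the CTAs in a CGA for any layout attribute
// implementing LayoutEncodingTrait. It behaves the same whether the encoding
// uses the interface defaults or provides its own overrides.
static unsigned numCTAsInCGA(LayoutEncodingTrait layout) {
  unsigned n = 1;
  for (unsigned c : layout.getCTAsPerCGA()) // default forwards to getCTALayout()
    n *= c;
  return n;
}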
@@ -123,51 +161,17 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
       CTAOrder.push_back(i);
     return get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
   }
-  unsigned getRank() const {
-    return getCTAOrder().size();
-  }
+  unsigned getRank() const { return getCTAOrder().size(); }
   }];
 
   let genVerifyDecl = 1;
   let skipDefaultBuilders = 1;
 }
 
-
-def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
-  let cppNamespace = "::mlir::triton::gpu";
-  let description = [{
-    Common trait for all TTGIR layouts.
-  }];
-  let methods = [
-    InterfaceMethod<"Get the shape of the CTAs per CGA.",
-      "SmallVector<unsigned>",
-      "getCTAsPerCGA">,
-    InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
-      "SmallVector<unsigned>",
-      "getCTAOrder">,
-    InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
-      "SmallVector<unsigned>",
-      "getCTASplitNum">,
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // Shared Layout Encoding
 //===----------------------------------------------------------------------===//
 
-def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
-  let cppNamespace = "::mlir::triton::gpu";
-
-  let description = [{
-    Common trait describing shared memory.
-  }];
-  let methods = [
-    InterfaceMethod<"Return the default alignment for the layout.",
-      "int32_t",
-      "getAlignment">,
-  ];
-}
-
 def SwizzledSharedEncodingAttr
   : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding",
                    [SharedEncodingTrait, LayoutEncodingTrait]> {
@@ -359,13 +363,6 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
     }]>,
   ];
 
-  let extraClassDeclaration = extraBaseClassDeclaration # [{
-    unsigned getRank() const { return getCTAOrder().size(); }
-    int32_t getAlignment() const;
-    SmallVector<unsigned> getCTAsPerCGA() const;
-    SmallVector<unsigned> getCTAOrder() const;
-    SmallVector<unsigned> getCTASplitNum() const;
-  }];
   let hasCustomAssemblyFormat = 1;
   let genVerifyDecl = 1;
 }
@@ -433,27 +430,19 @@ attributes too, for example,
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
-    unsigned getRank() const { return getOrder().size(); }
-    int32_t getAlignment() const { return 16; }
-
     unsigned getMinInterval() const {
       return *llvm::min_element(getIntervals());
     }
 
     // Returns the total number of elements including padding given the input
     // tensor shape.
    int64_t getPaddedSize(ArrayRef<int64_t> shape) const;
-
-    SmallVector<unsigned> getCTAsPerCGA() const;
-    SmallVector<unsigned> getCTAOrder() const;
-    SmallVector<unsigned> getCTASplitNum() const;
   }];
   let hasCustomAssemblyFormat = 1;
   let genVerifyDecl = 1;
 }
 
-def NVMMASharedEncodingAttr :
-  TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
+def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [DeclareSharedEncodingMethods, LayoutEncodingTrait]> {
   let mnemonic = "nvmma_shared";
 
   let description = [{
@@ -513,11 +502,6 @@ def NVMMASharedEncodingAttr :
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
-    unsigned getRank() const { return getCTAOrder().size(); }
-    int32_t getAlignment() const;
-    SmallVector<unsigned> getCTAsPerCGA() const;
-    SmallVector<unsigned> getCTAOrder() const;
-    SmallVector<unsigned> getCTASplitNum() const;
     int getPerPhase() const;
     int getMaxPhase() const;
     int getVec() const;
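
With the default body `return 16;` on getAlignment, only encodings with different alignment requirements (here NVMMASharedEncodingAttr, via DeclareSharedEncodingMethods) keep an override. A hedged sketch of a consumer dispatching through the interface; sharedAllocAlignment is an illustrative helper, not code from this commit:

#include "mlir/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Illustrative helper: choose an alignment for a shared-memory allocation by
// dispatching through SharedEncodingTrait.
static int32_t sharedAllocAlignment(mlir::Attribute enc) {
  // Encodings that keep the interface default report 16; NVMMA shared
  // encodings supply their own getAlignment.
  if (auto shared = mlir::dyn_cast<mlir::triton::gpu::SharedEncodingTrait>(enc))
    return shared.getAlignment();
  return 16; // conservative fallback for non-shared encodings (assumption)
}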
@@ -619,20 +603,14 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1):
     "CTALayoutAttr":$CTALayout
   );
 
-  let extraClassDeclaration = extraBaseClassDeclaration # [{
-    unsigned getRank() const { return getCTAOrder().size(); }
-    int32_t getAlignment() const;
-    SmallVector<unsigned> getCTAsPerCGA() const;
-    SmallVector<unsigned> getCTAOrder() const;
-    SmallVector<unsigned> getCTASplitNum() const;
-  }];
   let hasCustomAssemblyFormat = 1;
 }
 
 
 //===----------------------------------------------------------------------===//
 // Distributed Layout Encoding
 //===----------------------------------------------------------------------===//
+
 def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> {
   let cppNamespace = "::mlir::triton::gpu";
 
@@ -719,12 +697,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
   }];
 
   code extraDistributedDeclaration = extraBaseClassDeclaration # [{
-    unsigned getRank() const { return getCTAOrder().size(); }
     // Implemented in subclasses
     SmallVector<unsigned> getRepOrder() const;
-    SmallVector<unsigned> getCTAsPerCGA() const;
-    SmallVector<unsigned> getCTAOrder() const;
-    SmallVector<unsigned> getCTASplitNum() const;
 
     LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
   }];
@@ -739,7 +713,7 @@ def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
   let cppAccessorType = "const LinearLayout &";
 }
 
-def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
+def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding", [DeclareLayoutEncodingMethods]> {
   let mnemonic = "linear";
 
   let description = [{
@@ -1376,7 +1350,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
   let hasCustomAssemblyFormat = 1;
 }
 
-def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
+def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [DeclareLayoutEncodingMethods]> {
   let mnemonic = "slice";
 
   let description = [{
@@ -1419,9 +1393,10 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
   }];
 
   let hasCustomAssemblyFormat = 1;
+  let genVerifyDecl = 1;
 }
 
-def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> {
+def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding", [DeclareLayoutEncodingMethods]> {
   let mnemonic = "dot_op";
 
   let description = [{

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 11 additions & 3 deletions
@@ -27,11 +27,13 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 
 // TritonNvidiaGPU depends on Triton
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc"
 
 namespace mlir::triton::nvidia_gpu::impl {
@@ -61,13 +63,19 @@ struct TMemAllocation {
 
 TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType);
 
-Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
-                                  RankedTensorType oltType, unsigned numWarps);
+gpu::DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N,
+                                                      RankedTensorType oltType,
+                                                      unsigned numWarps);
+gpu::DistributedEncodingTrait
+getTmemLoadLayoutSplitLongM(RankedTensorType tensorType,
+                            gpu::MemDescType memType, int numWarps);
+SmallVector<gpu::DistributedEncodingTrait>
+getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType,
+                         gpu::MemDescType memType);
 
 bool isDistributedLayoutTMemCompatible(Operation *op,
                                        RankedTensorType tensorType,
                                        gpu::MemDescType memType);
-
 bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
                                             gpu::MemDescType memType,
                                             int numWarps);
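
Tightening these signatures from Attribute to gpu::DistributedEncodingTrait moves the invariant "the result is a distributed layout" into the type system, so callers no longer need a cast. A sketch of a call site under that assumption; the wrapper function, its arguments, and the namespace qualification are illustrative, not code from this commit:

#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton;

// Illustrative call site, assuming M, N, and numWarps are computed elsewhere.
static RankedTensorType withTmemLayout(RankedTensorType tensorType, unsigned M,
                                       unsigned N, unsigned numWarps) {
  // No dyn_cast from Attribute: the typed interface comes straight back.
  gpu::DistributedEncodingTrait layout =
      nvidia_gpu::getTmemCompatibleLayout(M, N, tensorType, numWarps);
  // Attribute interfaces convert implicitly to Attribute, so the layout can
  // still be used anywhere a plain encoding attribute was expected.
  return RankedTensorType::get(tensorType.getShape(),
                               tensorType.getElementType(), layout);
}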
