
Commit afe1e9d

Revert "[Dialect] Layout attr cleanup and tighten invariants (#7714)"
This reverts commit 96e53bb.
1 parent: 9587695

19 files changed (+494, -513 lines)


include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 82 additions & 56 deletions
@@ -6,61 +6,23 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
 include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
 
 //===----------------------------------------------------------------------===//
-// Traits and Interfaces
+// TritonGPU Attribute Definitions
 //===----------------------------------------------------------------------===//
-
-def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
-def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
-
-def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
-  let cppNamespace = "::mlir::triton::gpu";
-  let description = [{
-    Common trait for all TTGIR layouts.
-  }];
-  let methods = [
-    InterfaceMethod<"Get the shape of the CTAs per CGA.",
-                    "SmallVector<unsigned>",
-                    "getCTAsPerCGA", (ins), [{}], [{
-      return llvm::to_vector($_attr.getCTALayout().getCTAsPerCGA());
-    }]>,
-    InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
-                    "SmallVector<unsigned>",
-                    "getCTAOrder", (ins), [{}], [{
-      return llvm::to_vector($_attr.getCTALayout().getCTAOrder());
-    }]>,
-    InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
-                    "SmallVector<unsigned>",
-                    "getCTASplitNum", (ins), [{}], [{
-      return llvm::to_vector($_attr.getCTALayout().getCTASplitNum());
-    }]>,
-    InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", (ins), [{}], [{
-      return $_attr.getCTAOrder().size();
-    }]>
-  ];
-}
-def DeclareLayoutEncodingMethods : DeclareAttrInterfaceMethods<
-  LayoutEncodingTrait, ["getCTAsPerCGA", "getCTAOrder", "getCTASplitNum"]>;
-
-def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
+def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
   let cppNamespace = "::mlir::triton::gpu";
 
-  let description = [{
-    Common trait describing shared memory.
-  }];
   let methods = [
-    InterfaceMethod<"Return the default alignment for the layout.",
-                    "int32_t", "getAlignment", (ins), [{}], [{ return 16; }]>,
   ];
 }
-def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods<
-  SharedEncodingTrait, ["getAlignment"]>;
 
-//===----------------------------------------------------------------------===//
-// Base Attribute
-//===----------------------------------------------------------------------===//
+def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
+
+def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
 
-class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = []>
-    : AttrDef<TritonGPU_Dialect, name, traits> {
+class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
+                     Dialect dialect = TritonGPU_Dialect,
+                     string baseCppClass = "::mlir::Attribute">
+    : AttrDef<dialect, name, !listconcat([TritonGPU_AttrTrait], traits), baseCppClass> {
 
   let description = [{
 TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines
@@ -161,17 +123,51 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
         CTAOrder.push_back(i);
       return get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
     }
-    unsigned getRank() const { return getCTAOrder().size(); }
+    unsigned getRank() const {
+      return getCTAOrder().size();
+    }
   }];
 
   let genVerifyDecl = 1;
   let skipDefaultBuilders = 1;
 }
 
+
+def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
+  let cppNamespace = "::mlir::triton::gpu";
+  let description = [{
+    Common trait for all TTGIR layouts.
+  }];
+  let methods = [
+    InterfaceMethod<"Get the shape of the CTAs per CGA.",
+                    "SmallVector<unsigned>",
+                    "getCTAsPerCGA">,
+    InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
+                    "SmallVector<unsigned>",
+                    "getCTAOrder">,
+    InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
+                    "SmallVector<unsigned>",
+                    "getCTASplitNum">,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Shared Layout Encoding
 //===----------------------------------------------------------------------===//
 
+def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
+  let cppNamespace = "::mlir::triton::gpu";
+
+  let description = [{
+    Common trait describing shared memory.
+  }];
+  let methods = [
+    InterfaceMethod<"Return the default alignment for the layout.",
+                    "int32_t",
+                    "getAlignment">,
+  ];
+}
+
 def SwizzledSharedEncodingAttr
     : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding",
                      [SharedEncodingTrait, LayoutEncodingTrait]> {
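
Aside for readers of this revert: LayoutEncodingTrait and SharedEncodingTrait surface in C++ as ordinary MLIR attribute interfaces, so client code reaches their methods by casting a layout attribute to the interface type. The sketch below is not code from this commit; it assumes the headers generated from this .td file, and inspectLayout is an illustrative helper.

#include <cstdint>
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
namespace ttg = mlir::triton::gpu;

// Query the restored interfaces on an arbitrary layout attribute.
void inspectLayout(Attribute enc) {
  // Every TTGIR layout implements LayoutEncodingTrait.
  if (auto layout = dyn_cast<ttg::LayoutEncodingTrait>(enc)) {
    SmallVector<unsigned> ctasPerCGA = layout.getCTAsPerCGA();
    SmallVector<unsigned> ctaOrder = layout.getCTAOrder(); // fastest axis first
    SmallVector<unsigned> ctaSplit = layout.getCTASplitNum();
    (void)ctasPerCGA; (void)ctaOrder; (void)ctaSplit;
  }
  // Shared-memory layouts additionally report a default alignment.
  if (auto shared = dyn_cast<ttg::SharedEncodingTrait>(enc)) {
    int32_t align = shared.getAlignment(); // 16 for typical shared layouts
    (void)align;
  }
}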
@@ -363,6 +359,13 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
     }]>,
   ];
 
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    unsigned getRank() const { return getCTAOrder().size(); }
+    int32_t getAlignment() const;
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
+  }];
   let hasCustomAssemblyFormat = 1;
   let genVerifyDecl = 1;
 }
@@ -430,19 +433,27 @@ attributes too, for example,
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+    unsigned getRank() const { return getOrder().size(); }
+    int32_t getAlignment() const { return 16; }
+
     unsigned getMinInterval() const {
       return *llvm::min_element(getIntervals());
     }
 
     // Returns the total number of elements including padding given the input
     // tensor shape.
     int64_t getPaddedSize(ArrayRef<int64_t> shape) const;
+
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
   }];
   let hasCustomAssemblyFormat = 1;
   let genVerifyDecl = 1;
 }
 
-def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [DeclareSharedEncodingMethods, LayoutEncodingTrait]> {
+def NVMMASharedEncodingAttr :
+    TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
   let mnemonic = "nvmma_shared";
 
   let description = [{
@@ -502,6 +513,11 @@ def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_share
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+    unsigned getRank() const { return getCTAOrder().size(); }
+    int32_t getAlignment() const;
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
     int getPerPhase() const;
     int getMaxPhase() const;
     int getVec() const;
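
With the defaulted interface bodies removed, each encoding re-declares these accessors in its extraClassDeclaration and defines them out of line in the dialect's C++ sources. A plausible shape for one such definition, mirroring the deleted default body that delegated to the CTA layout; this is a sketch under that assumption, not the actual implementation from this commit.

#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace ttg = mlir::triton::gpu;

// Hypothetical out-of-line definition for a declaration added above.
llvm::SmallVector<unsigned> ttg::NVMMASharedEncodingAttr::getCTAsPerCGA() const {
  // Same behavior the removed interface default expressed as
  // $_attr.getCTALayout().getCTAsPerCGA().
  return llvm::to_vector(getCTALayout().getCTAsPerCGA());
}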
@@ -603,14 +619,20 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1):
     "CTALayoutAttr":$CTALayout
   );
 
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    unsigned getRank() const { return getCTAOrder().size(); }
+    int32_t getAlignment() const;
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
+  }];
   let hasCustomAssemblyFormat = 1;
 }
 
 
 //===----------------------------------------------------------------------===//
 // Distributed Layout Encoding
 //===----------------------------------------------------------------------===//
-
 def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> {
   let cppNamespace = "::mlir::triton::gpu";
 
@@ -659,8 +681,9 @@ We call each individual tile "rep".
   ];
 }
 
-class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = []>
-    : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits)> {
+class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = [],
+                          Dialect dialect = TritonGPU_Dialect>
+    : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits), dialect> {
 
   let description = [{
 Distributed encodings have a layout function L that is entirely characterized
@@ -696,8 +719,12 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
   }];
 
   code extraDistributedDeclaration = extraBaseClassDeclaration # [{
+    unsigned getRank() const { return getCTAOrder().size(); }
     // Implemented in subclasses
     SmallVector<unsigned> getRepOrder() const;
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
 
     LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
   }];
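
The one declaration shared by every distributed encoding, toLinearLayout, materializes the encoding's layout function for a concrete tensor shape. A usage sketch, assuming the LinearLayout utility header path and using BlockedEncodingAttr plus a 128x64 shape purely as illustration:

#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"

namespace ttg = mlir::triton::gpu;

// Evaluate a distributed encoding's layout function for one tensor shape.
mlir::triton::LinearLayout layoutFor(ttg::BlockedEncodingAttr enc) {
  // For a 128x64 tensor carrying this encoding, the result maps hardware
  // indices (register, lane, warp, block) to tensor coordinates.
  return enc.toLinearLayout({128, 64});
}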
@@ -712,7 +739,7 @@ def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
   let cppAccessorType = "const LinearLayout &";
 }
 
-def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding", [DeclareLayoutEncodingMethods]> {
+def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
   let mnemonic = "linear";
 
   let description = [{
@@ -1349,7 +1376,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
   let hasCustomAssemblyFormat = 1;
 }
 
-def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [DeclareLayoutEncodingMethods]> {
+def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
   let mnemonic = "slice";
 
   let description = [{
@@ -1392,10 +1419,9 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [
   }];
 
   let hasCustomAssemblyFormat = 1;
-  let genVerifyDecl = 1;
 }
 
-def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding", [DeclareLayoutEncodingMethods]> {
+def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> {
   let mnemonic = "dot_op";
 
   let description = [{

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 3 additions & 11 deletions
@@ -27,13 +27,11 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 
 // TritonNvidiaGPU depends on Triton
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
-#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc"
 
 namespace mlir::triton::nvidia_gpu::impl {
@@ -63,19 +61,13 @@ struct TMemAllocation {
 
 TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType);
 
-gpu::DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N,
-                                                      RankedTensorType oltType,
-                                                      unsigned numWarps);
-gpu::DistributedEncodingTrait
-getTmemLoadLayoutSplitLongM(RankedTensorType tensorType,
-                            gpu::MemDescType memType, int numWarps);
-SmallVector<gpu::DistributedEncodingTrait>
-getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType,
-                         gpu::MemDescType memType);
+Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
+                                  RankedTensorType oltType, unsigned numWarps);
 
 bool isDistributedLayoutTMemCompatible(Operation *op,
                                        RankedTensorType tensorType,
                                        gpu::MemDescType memType);
+
 bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
                                             gpu::MemDescType memType,
                                             int numWarps);
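
A consequence of the change above: getTmemCompatibleLayout now returns a plain Attribute rather than gpu::DistributedEncodingTrait, so call sites that relied on the stronger type have to recover it with a cast. The sketch below assumes these helpers live in mlir::triton::nvidia_gpu, as the surrounding header suggests; pickLayout is an illustrative wrapper, not part of this commit.

#include <cassert>
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

using namespace mlir;
namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;

// Recover the stronger interface type from the weakened return value.
Attribute pickLayout(unsigned M, unsigned N, RankedTensorType type,
                     unsigned numWarps) {
  Attribute enc = ttng::getTmemCompatibleLayout(M, N, type, numWarps);
  // The helper is still expected to produce a distributed layout; assert the
  // interface where downstream code depends on it.
  assert(isa<ttg::DistributedEncodingTrait>(enc) &&
         "expected a distributed encoding");
  return enc;
}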
