@@ -12,17 +12,6 @@ def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
   let cppNamespace = "::mlir::triton::gpu";
 
   let methods = [
-    InterfaceMethod<"Return total element size per thread.",
-                    "unsigned",
-                    "getTotalElemsPerThread",
-                    (ins "ArrayRef<int64_t>":$tensorShape,
-                         "Type":$eltTy)>,
-
-    InterfaceMethod<"Return element size per thread in each dimension.",
-                    "SmallVector<unsigned>",
-                    "getElemsPerThread",
-                    (ins "ArrayRef<int64_t>":$tensorShape,
-                         "Type":$eltTy)>,
   ];
 }
 
@@ -54,8 +43,6 @@ Right now, Triton implements two main classes of layouts: shared, and distributed
   let attrName = "triton.gpu." # attrMnemonic;
 
   code extraBaseClassDeclaration = [{
-    unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
-    SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
   }];
 }
 
@@ -124,15 +111,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
   ];
 
   let extraClassDeclaration = [{
-    SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
-      llvm::report_fatal_error(
-          "Unsupported getElemsPerThread in CTALayoutAttr.");
-    }
-    unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
-      llvm::report_fatal_error(
-          "Unsupported getTotalElemsPerThread in CTALayoutAttr.");
-    }
-
     static CTALayoutAttr getDefault(MLIRContext *context, int rank) {
       SmallVector<unsigned> CTAsPerCGA(rank, 1);
       SmallVector<unsigned> CTASplitNum(rank, 1);
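
A side note on the surviving code: `getDefault` builds the trivial single-CTA layout, filling every dimension with 1. A minimal C++ sketch of a call site (the `ctx` variable is an assumed `MLIRContext *`):

    // Hypothetical caller; per the builder body above, rank-sized vectors of
    // ones are used for both CTAsPerCGA and CTASplitNum.
    CTALayoutAttr cta = CTALayoutAttr::getDefault(ctx, /*rank=*/2);
    // cta.getCTAsPerCGA()  == {1, 1}
    // cta.getCTASplitNum() == {1, 1}
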
@@ -146,12 +124,46 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
   let genVerifyDecl = 1;
   let skipDefaultBuilders = 1;
 }
+
+
+def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
+  let cppNamespace = "::mlir::triton::gpu";
+  let description = [{
+    Common trait for all TTGIR layouts.
+  }];
+  let methods = [
+    InterfaceMethod<"Get the shape of the CTAs per CGA.",
+                    "SmallVector<unsigned>",
+                    "getCTAsPerCGA">,
+    InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
+                    "SmallVector<unsigned>",
+                    "getCTAOrder">,
+    InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
+                    "SmallVector<unsigned>",
+                    "getCTASplitNum">,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Shared Layout Encoding
 //===----------------------------------------------------------------------===//
 
-def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding", "shared_encoding"> {
-  let mnemonic = "shared";
+def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
+  let cppNamespace = "::mlir::triton::gpu";
+
+  let description = [{
+    Common trait describing shared memory.
+  }];
+  let methods = [
+    InterfaceMethod<"Return the default alignment for the layout.",
+                    "int32_t",
+                    "getAlignment">,
+  ];
+}
+
+def SwizzledSharedEncodingAttr :
+    TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
+  let mnemonic = "swizzled_shared";
 
   let description = [{
 An encoding for tensors whose elements may be simultaneously accessed by
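
Together, the two interfaces added above let passes query layouts generically instead of casting to concrete attribute classes. A hedged C++ sketch of a consumer (variable names are illustrative, not from this patch):

    // `enc` is an mlir::Attribute taken from a tensor type's encoding.
    if (auto layout = mlir::dyn_cast<triton::gpu::LayoutEncodingTrait>(enc)) {
      SmallVector<unsigned> ctasPerCGA = layout.getCTAsPerCGA();  // CGA shape
      SmallVector<unsigned> splitNum = layout.getCTASplitNum();   // tensor split
    }
    if (auto shared = mlir::dyn_cast<triton::gpu::SharedEncodingTrait>(enc)) {
      int32_t alignment = shared.getAlignment();  // layout's default alignment
    }
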
@@ -226,13 +238,6 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
 (r,c) has value
 
    ((c / 2) ^ r) * 2 + (c % 2).
-
-For MMAv3 eg Hopper GMMA, hasLeadingOffset should be true. In this case,
-when the matrix is stored in shared memory, there will be an offset not
-only in the stride dimension, but also in the leading dimension. For example,
-a matrix of size 16x128 and data type I8 is stored in the shared memory with
-64B-swizzle mode. The offset of the element with index (0, 64) will be 16*64,
-compared to 1*64 when the hasLeadingOffset is false.
 }];
 
   // swizzle info: vec, perPhase, maxPhase
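
The `((c / 2) ^ r) * 2 + (c % 2)` formula above is the vec=2 special case. In terms of the vec/perPhase/maxPhase parameters declared just below, the swizzled column is conventionally computed as in this sketch (my reading of the parameters, not code from the patch):

    // Element (r, c): vec elements swizzle as a unit, perPhase consecutive
    // rows share one phase, and phases repeat after maxPhase.
    unsigned swizzledCol(unsigned r, unsigned c, unsigned vec,
                         unsigned perPhase, unsigned maxPhase) {
      unsigned phase = (r / perPhase) % maxPhase;
      return ((c / vec) ^ phase) * vec + (c % vec);
    }
    // With vec=2, perPhase=1, and maxPhase >= the row count, this reduces to
    // ((c / 2) ^ r) * 2 + (c % 2), matching the description.
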
@@ -243,20 +248,10 @@ compared to 1*64 when the hasLeadingOffset is false.
243248 "unsigned":$perPhase,
244249 "unsigned":$maxPhase,
245250 ArrayRefParameter<"unsigned">:$order,
246- "CTALayoutAttr":$CTALayout,
247- "bool":$hasLeadingOffset
251+ "CTALayoutAttr":$CTALayout
248252 );
249253
250254 let builders = [
251- AttrBuilder<(ins "unsigned":$vec,
252- "unsigned":$perPhase,
253- "unsigned":$maxPhase,
254- "ArrayRef<unsigned>":$order,
255- "CTALayoutAttr":$CTALayout), [{
256- bool hasLeadingOffset = false; // default value
257- return $_get(context, vec, perPhase, maxPhase, order, CTALayout, hasLeadingOffset);
258- }]>,
259-
260255 AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
261256 "ArrayRef<int64_t>":$shape,
262257 "ArrayRef<unsigned>":$order,
@@ -267,7 +262,7 @@ compared to 1*64 when the hasLeadingOffset is false.
     }]>,
 
     // TODO(jlebar): This should not be an overload of
-    // SharedEncodingAttr::get(). It's misleading, because it does a bunch of
+    // SwizzledSharedEncodingAttr::get(). It's misleading, because it does a bunch of
     // nontrivial work based on the given dotOpEnc.
     AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
                      "ArrayRef<int64_t>":$shape,
@@ -402,38 +397,66 @@ compared to 1*64 when the hasLeadingOffset is false.
       unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
       return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans);
     }]>,
+  ];
 
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    int32_t getAlignment() const;
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
+  }];
+  let hasCustomAssemblyFormat = 1;
+}
+
+def NVMMASharedEncodingAttr :
+    TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
+  let mnemonic = "nvmma_shared";
+
+  let description = [{
+    Represent blocked shared memory matching MMAv3/MMAv5 shared memory input.
+    This is meant to represent 2d tiled blocked layout.
+    The full layout representation is described here:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-shared-memory-layout
+  }];
+
+  let parameters = (
+    ins
+    "unsigned":$swizzlingByteWidth,
+    "bool":$transposed,
+    "CTALayoutAttr":$CTALayout
+  );
+
+  let builders = [
     AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
                      "ArrayRef<unsigned>":$order,
                      "CTALayoutAttr":$CTALayout,
                      "Type":$eltTy), [{
       auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
-
+      int32_t swizzlingByteWidth = 0;
       int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth();
-      int32_t vec = 128 / eleBitWidth, perPhase = 1, maxPhase = 1;
 
       // get proper shared memory swizzling mode from the contiguous dimension
       // size of the origin blocked layout.
       auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8;
       if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) {
-        perPhase = 1;
-        maxPhase = 8;
+        swizzlingByteWidth = 128;
       } else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) {
-        perPhase = 2;
-        maxPhase = 4;
+        swizzlingByteWidth = 64;
       } else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) {
-        perPhase = 4;
-        maxPhase = 2;
+        swizzlingByteWidth = 32;
       } else {
         llvm_unreachable("unsupported shared memory layout for MMAv3");
       }
-
-      return $_get(context, vec, perPhase, maxPhase, order, CTALayout, true);
+      bool transposed = order[0] == 0;
+      return $_get(context, swizzlingByteWidth, transposed, CTALayout);
     }]>
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
     int32_t getAlignment() const;
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
   }];
   let hasCustomAssemblyFormat = 1;
 }
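
To make the builder's selection rule concrete, here is a worked example under its stated thresholds (shapes are illustrative):

    // 16x128 i8 tile, order = {1, 0} (the 128-wide dim is contiguous):
    //   contigDimSizeInByte = 128 * 8 / 8 = 128  ->  swizzlingByteWidth = 128
    // 16x24 i8 tile: 24 bytes matches no branch -> llvm_unreachable.
    // transposed is true only when order[0] == 0 (column-major contiguity).
    auto nvmma = NVMMASharedEncodingAttr::get(ctx, shape, order, ctaLayout, eltTy);
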
@@ -468,16 +491,17 @@ We call each individual tile "rep".
     InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
                     "SmallVector<unsigned>",
                     "getRepOrder">,
-
-    // Interface for the meta information about the multiple thread hierarchy.
-    InterfaceMethod<"Get the shape of the CTAs per CGA.",
-                    "SmallVector<unsigned>",
-                    "getCTAsPerCGA">,
-
-    InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
+    InterfaceMethod<"Return total element size per thread.",
+                    "unsigned",
+                    "getTotalElemsPerThread",
+                    (ins "ArrayRef<int64_t>":$tensorShape,
+                         "Type":$eltTy)>,
+    InterfaceMethod<"Return element size per thread in each dimension.",
                     "SmallVector<unsigned>",
-                    "getCTAOrder">,
-
+                    "getElemsPerThread",
+                    (ins "ArrayRef<int64_t>":$tensorShape,
+                         "Type":$eltTy)>,
+    // Interface for the meta information about the multiple thread hierarchy.
     InterfaceMethod<"Get the shape of the warps per CTA.",
                     "SmallVector<unsigned>",
                     "getWarpsPerCTA">,
@@ -498,10 +522,6 @@ We call each individual tile "rep".
498522 "SmallVector<unsigned>",
499523 "getSizePerThread">,
500524
501- InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
502- "SmallVector<unsigned>",
503- "getCTASplitNum">,
504-
505525 InterfaceMethod<"Gets the number of contiguous elements per thread.",
506526 "SmallVector<unsigned>",
507527 "getContigPerThread">,
@@ -514,7 +534,7 @@ We call each individual tile "rep".
 
 class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = [],
                           Dialect dialect = TritonGPU_Dialect>
-    : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait], traits), dialect> {
+    : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits), dialect> {
 
   let description = [{
 Distributed encodings have a layout function L that is entirely characterized
@@ -550,6 +570,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
 }];
 
   code extraDistributedDeclaration = extraBaseClassDeclaration # [{
+    unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
+    SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
     SmallVector<unsigned> getRepOrder() const;
     SmallVector<unsigned> getCTAsPerCGA() const;
     SmallVector<unsigned> getCTAOrder() const;
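
After this hunk, per-thread element counts are a distributed-layout concept only: shared encodings no longer carry the methods, and CTALayoutAttr's fatal-error stubs (removed earlier in this diff) become unnecessary. A hedged sketch of the resulting call pattern:

    // `enc` comes from a RankedTensorType's encoding attribute.
    if (auto dist = mlir::dyn_cast<triton::gpu::DistributedEncodingTrait>(enc)) {
      unsigned total = dist.getTotalElemsPerThread(tensorShape, eltTy);
      SmallVector<unsigned> perDim = dist.getElemsPerThread(tensorShape, eltTy);
    }
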