
Commit 00056fa

Merge OpenAI Triton commit 2778526 (#3332)
This PR changes the Triton base from a637eb2 to 2778526 (Feb 1). Pass rate: 98.19%. Please do not squash and merge this PR.
2 parents: 94efbc1 + b631910 · commit 00056fa

114 files changed: +1030 −1217 lines changed


bin/triton-tensor-layout.cpp

Lines changed: 2 additions & 2 deletions
@@ -80,11 +80,11 @@ static cl::opt<std::string> TensorStr(
 //===--------------------------------------------------------------------===//
 
 LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
-  // DistributedEncodingTrait and SharedEncodingAttr implements the
+  // DistributedEncodingTrait and SharedEncodingTrait implements the
   // toLinearLayout interface.
   mlir::Attribute layout = tensorType.getEncoding();
   if (isa<mlir::triton::gpu::DistributedEncodingTrait,
-          mlir::triton::gpu::SharedEncodingAttr>(layout)) {
+          mlir::triton::gpu::SharedEncodingTrait>(layout)) {
     os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
     return success();
   }

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 2 deletions
@@ -1288,9 +1288,12 @@ inline Value packLLVector(Location loc, ValueRange vals,
 inline bool
 isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
                            ArrayRef<int64_t> allocShape,
-                           triton::gpu::SharedEncodingAttr sharedEnc) {
+                           triton::gpu::SharedEncodingTrait sharedEnc) {
   auto rank = shape.size();
-  return /*no swizzling*/ sharedEnc.getMaxPhase() == 1 ||
+  auto swizzledLayout =
+      dyn_cast<triton::gpu::SwizzledSharedEncodingAttr>(sharedEnc);
+  bool noSwizzling = swizzledLayout && swizzledLayout.getMaxPhase() == 1;
+  return /*no swizzling*/ noSwizzling ||
          /*swizzling but same shape*/ shape == allocShape ||
          /*swizzling and rank-reduced and rank >= 2*/
          (shape == allocShape.take_back(rank) && rank >= 2);
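The updated predicate only reads swizzle parameters when the encoding really is a SwizzledSharedEncodingAttr; any other SharedEncodingTrait implementation falls through to the shape-based cases. A minimal standalone sketch of the same decision, using plain vectors in place of the MLIR types (hypothetical helper names, not the Triton API):

#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Hypothetical mirror of isSimpleSharedMemoryAccess: maxPhase is
// std::nullopt when the encoding is not a swizzled-shared attribute.
bool isSimpleAccess(const std::vector<int64_t> &shape,
                    const std::vector<int64_t> &allocShape,
                    std::optional<unsigned> maxPhase) {
  auto rank = shape.size();
  // No swizzling: a swizzled-shared encoding whose maxPhase is 1.
  bool noSwizzling = maxPhase.has_value() && *maxPhase == 1;
  // Swizzling, but the view covers the whole allocation.
  bool sameShape = shape == allocShape;
  // Swizzling on a rank-reduced view of the trailing dimensions.
  bool rankReduced =
      rank >= 2 && allocShape.size() >= rank &&
      std::vector<int64_t>(allocShape.end() - rank, allocShape.end()) == shape;
  return noSwizzling || sameShape || rankReduced;
}

int main() {
  // A 64x32 view of a 3D 4x64x32 allocation with maxPhase > 1:
  // rank-reduced trailing-dimension match, so the access is still "simple".
  printf("%d\n", isSimpleAccess({64, 32}, {4, 64, 32}, 8)); // 1
  return 0;
}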

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 7 additions & 5 deletions
@@ -9,7 +9,8 @@
 #include "triton/Tools/LinearLayout.h"
 
 namespace mlir::triton::gpu {
-class SharedEncodingAttr;
+class SwizzledSharedEncodingAttr;
+class NVMMASharedEncodingAttr;
 
 // - BlockedEncodingAttrs have the following input dimensions.
 //
@@ -18,7 +19,8 @@ class SharedEncodingAttr;
 //   "warp": warps in a block/CTA
 //   "block": blocks in a cluster
 //
-// - An n-dimensional SharedEncodingAttr has the following input dimensions.
+// - An n-dimensional SwizzledSharedEncodingAttr has the following input
+//   dimensions.
 //
 //   "offset": the n'th element in the allocation, within a particular thread
 //     block (i.e. within a CTA). The offset is measured in elements, not
@@ -36,19 +38,19 @@ class SharedEncodingAttr;
 //
 // elemBitWidth is the bit width of one element in the layout. This is required
 // to compute the linear layout for MMAv3 (i.e. Hopper) shared layouts (i.e.
-// shared layouts with hasLeadingOffset == true) but is otherwise unused.
+// shared layouts with nvmma_shared layout) but is otherwise unused.
 //
 // Returns std::nullopt if the given layout can't be converted to an LL.
 LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                             std::optional<int32_t> elemBitWidth = std::nullopt);
 
-// Convert the shared encoding of a tensor with `hasLeadingOffset=true` to a
+// Convert the shared encoding of a tensor with `nvmma_shared` layout to a
 // LinearLayout that maps from a linear shared memory offset to tensor index.
 //
 // If `disableSwizzle` is set, then the resulting layout does not include
 // swizzling.
 LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
-                                               SharedEncodingAttr shared,
+                                               NVMMASharedEncodingAttr shared,
                                                int32_t elemBitWidth,
                                                bool disableSwizzle = false);
 
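Both entry points keep their roles after the rename: toLinearLayout handles any layout attribute, while sharedToLinearLayoutLeadingOffset now takes an NVMMASharedEncodingAttr directly. A hedged usage sketch, assuming a RankedTensorType `tensorTy` whose encoding is an NVMMASharedEncodingAttr and the usual MLIR/TritonGPU includes; variable names and surrounding pass context are illustrative:

// Sketch only: `tensorTy` is assumed to be a RankedTensorType carrying an
// NVMMASharedEncodingAttr; this is not a complete, compilable translation unit.
ArrayRef<int64_t> shape = tensorTy.getShape();
auto nvmma =
    cast<triton::gpu::NVMMASharedEncodingAttr>(tensorTy.getEncoding());
int32_t elemBitWidth = tensorTy.getElementTypeBitWidth();

// Generic path: any layout that lowers to a linear layout.
LinearLayout ll = triton::gpu::toLinearLayout(shape, nvmma, elemBitWidth);

// MMAv3/Hopper path: shared-memory offset -> tensor index, with or without
// the swizzle pattern applied.
LinearLayout offToIdx = triton::gpu::sharedToLinearLayoutLeadingOffset(
    shape, nvmma, elemBitWidth, /*disableSwizzle=*/false);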

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 89 additions & 67 deletions
@@ -12,17 +12,6 @@ def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
   let cppNamespace = "::mlir::triton::gpu";
 
   let methods = [
-    InterfaceMethod<"Return total element size per thread.",
-                    "unsigned",
-                    "getTotalElemsPerThread",
-                    (ins "ArrayRef<int64_t>":$tensorShape,
-                         "Type":$eltTy)>,
-
-    InterfaceMethod<"Return element size per thread in each dimension.",
-                    "SmallVector<unsigned>",
-                    "getElemsPerThread",
-                    (ins "ArrayRef<int64_t>":$tensorShape,
-                         "Type":$eltTy)>,
   ];
 }
 
@@ -54,8 +43,6 @@ Right now, Triton implements two main classes of layouts: shared, and distribute
   let attrName = "triton.gpu." # attrMnemonic;
 
   code extraBaseClassDeclaration = [{
-    unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
-    SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
   }];
 }
 
@@ -124,15 +111,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
   ];
 
   let extraClassDeclaration = [{
-    SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
-      llvm::report_fatal_error(
-          "Unsupported getElemsPerThread in CTALayoutAttr.");
-    }
-    unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
-      llvm::report_fatal_error(
-          "Unsupported getTotalElemsPerThread in CTALayoutAttr.");
-    }
-
     static CTALayoutAttr getDefault(MLIRContext *context, int rank) {
       SmallVector<unsigned> CTAsPerCGA(rank, 1);
       SmallVector<unsigned> CTASplitNum(rank, 1);
@@ -146,12 +124,46 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
   let genVerifyDecl = 1;
   let skipDefaultBuilders = 1;
 }
+
+
+def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
+  let cppNamespace = "::mlir::triton::gpu";
+  let description = [{
+    Common trait for all TTGIR layouts.
+  }];
+  let methods = [
+    InterfaceMethod<"Get the shape of the CTAs per CGA.",
+                    "SmallVector<unsigned>",
+                    "getCTAsPerCGA">,
+    InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
+                    "SmallVector<unsigned>",
+                    "getCTAOrder">,
+    InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
+                    "SmallVector<unsigned>",
+                    "getCTASplitNum">,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Shared Layout Encoding
 //===----------------------------------------------------------------------===//
 
-def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding", "shared_encoding"> {
-  let mnemonic = "shared";
+def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
+  let cppNamespace = "::mlir::triton::gpu";
+
+  let description = [{
+    Common trait describing shared memory.
+  }];
+  let methods = [
+    InterfaceMethod<"Return the default alignment for the layout.",
+                    "int32_t",
+                    "getAlignment">,
+  ];
+}
+
+def SwizzledSharedEncodingAttr :
+  TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
+  let mnemonic = "swizzled_shared";
 
   let description = [{
     An encoding for tensors whose elements may be simultaneously accessed by
@@ -226,13 +238,6 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
 (r,c) has value
 
 ((c / 2) ^ r) * 2 + (c % 2).
-
-For MMAv3 eg Hopper GMMA, hasLeadingOffset should be true. In this case,
-when the matrix is stored in shared memory, there will be an offset not
-only in the stride dimension, but also in the leading dimension. For example,
-a matrix of size 16x128 and data type I8 is stored in the shared memory with
-64B-swizzle mode. The offset of the element with index (0, 64) will be 16*64,
-compared to 1*64 when the hasLeadingOffset is false.
 }];
 
   // swizzle info: vec, perPhase, maxPhase
@@ -243,20 +248,10 @@ compared to 1*64 when the hasLeadingOffset is false.
     "unsigned":$perPhase,
     "unsigned":$maxPhase,
     ArrayRefParameter<"unsigned">:$order,
-    "CTALayoutAttr":$CTALayout,
-    "bool":$hasLeadingOffset
+    "CTALayoutAttr":$CTALayout
   );
 
   let builders = [
-    AttrBuilder<(ins "unsigned":$vec,
-                     "unsigned":$perPhase,
-                     "unsigned":$maxPhase,
-                     "ArrayRef<unsigned>":$order,
-                     "CTALayoutAttr":$CTALayout), [{
-      bool hasLeadingOffset = false; // default value
-      return $_get(context, vec, perPhase, maxPhase, order, CTALayout, hasLeadingOffset);
-    }]>,
-
     AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
                      "ArrayRef<int64_t>":$shape,
                      "ArrayRef<unsigned>":$order,
@@ -267,7 +262,7 @@ compared to 1*64 when the hasLeadingOffset is false.
     }]>,
 
     // TODO(jlebar): This should not be an overload of
-    // SharedEncodingAttr::get(). It's misleading, because it does a bunch of
+    // SwizzledSharedEncodingAttr::get(). It's misleading, because it does a bunch of
     // nontrivial work based on the given dotOpEnc.
     AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
                      "ArrayRef<int64_t>":$shape,
@@ -402,38 +397,66 @@ compared to 1*64 when the hasLeadingOffset is false.
       unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
       return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans);
     }]>,
+  ];
 
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    int32_t getAlignment() const;
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
+  }];
+  let hasCustomAssemblyFormat = 1;
+}
+
+def NVMMASharedEncodingAttr :
+  TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
+  let mnemonic = "nvmma_shared";
+
+  let description = [{
+    Represent blocked shared memory matching MMAv3/MMAv5 shared memory input.
+    This is meant to represent 2d tiled blocked layout.
+    The full layout representation is described here:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-shared-memory-layout
+  }];
+
+  let parameters = (
+    ins
+    "unsigned":$swizzlingByteWidth,
+    "bool":$transposed,
+    "CTALayoutAttr":$CTALayout
+  );
+
+  let builders = [
     AttrBuilder<(ins "ArrayRef<int64_t>":$shape,
                      "ArrayRef<unsigned>":$order,
                      "CTALayoutAttr":$CTALayout,
                      "Type":$eltTy), [{
       auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
-
+      int32_t swizzlingByteWidth = 0;
       int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth();
-      int32_t vec = 128 / eleBitWidth, perPhase = 1, maxPhase = 1;
 
       // get proper shared memory swizzling mode from the contiguous dimension
      // size of the origin blocked layout.
       auto contigDimSizeInByte = shapePerCTA[order[0]] * eleBitWidth / 8;
       if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0) {
-        perPhase = 1;
-        maxPhase = 8;
+        swizzlingByteWidth = 128;
       } else if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0) {
-        perPhase = 2;
-        maxPhase = 4;
+        swizzlingByteWidth = 64;
       } else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) {
-        perPhase = 4;
-        maxPhase = 2;
+        swizzlingByteWidth = 32;
       } else {
         llvm_unreachable("unsupported shared memory layout for MMAv3");
       }
-
-      return $_get(context, vec, perPhase, maxPhase, order, CTALayout, true);
+      bool transposed = order[0] == 0;
+      return $_get(context, swizzlingByteWidth, transposed, CTALayout);
     }]>
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
     int32_t getAlignment() const;
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTAOrder() const;
+    SmallVector<unsigned> getCTASplitNum() const;
   }];
   let hasCustomAssemblyFormat = 1;
 }
@@ -468,16 +491,17 @@ We call each individual tile "rep".
     InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
                     "SmallVector<unsigned>",
                     "getRepOrder">,
-
-    // Interface for the meta information about the multiple thread hierarchy.
-    InterfaceMethod<"Get the shape of the CTAs per CGA.",
-                    "SmallVector<unsigned>",
-                    "getCTAsPerCGA">,
-
-    InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
+    InterfaceMethod<"Return total element size per thread.",
+                    "unsigned",
+                    "getTotalElemsPerThread",
+                    (ins "ArrayRef<int64_t>":$tensorShape,
+                         "Type":$eltTy)>,
+    InterfaceMethod<"Return element size per thread in each dimension.",
                     "SmallVector<unsigned>",
-                    "getCTAOrder">,
-
+                    "getElemsPerThread",
+                    (ins "ArrayRef<int64_t>":$tensorShape,
+                         "Type":$eltTy)>,
+    // Interface for the meta information about the multiple thread hierarchy.
     InterfaceMethod<"Get the shape of the warps per CTA.",
                     "SmallVector<unsigned>",
                     "getWarpsPerCTA">,
@@ -498,10 +522,6 @@ We call each individual tile "rep".
                     "SmallVector<unsigned>",
                     "getSizePerThread">,
 
-    InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
-                    "SmallVector<unsigned>",
-                    "getCTASplitNum">,
-
     InterfaceMethod<"Gets the number of contiguous elements per thread.",
                     "SmallVector<unsigned>",
                     "getContigPerThread">,
@@ -514,7 +534,7 @@ We call each individual tile "rep".
 
 class DistributedEncoding<string name, string attrMnemonic, list<Trait> traits = [],
                           Dialect dialect = TritonGPU_Dialect>
-  : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait], traits), dialect> {
+  : TritonGPU_Attr<name, attrMnemonic, !listconcat([DistributedEncodingTrait, LayoutEncodingTrait], traits), dialect> {
 
   let description = [{
     Distributed encodings have a layout function L that is entirely characterized
@@ -550,6 +570,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
   }];
 
   code extraDistributedDeclaration = extraBaseClassDeclaration # [{
+    unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
+    SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
     SmallVector<unsigned> getRepOrder() const;
     SmallVector<unsigned> getCTAsPerCGA() const;
     SmallVector<unsigned> getCTAOrder() const;
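The new NVMMASharedEncodingAttr builder collapses the old (vec, perPhase, maxPhase, hasLeadingOffset) bookkeeping into a single swizzlingByteWidth picked from the byte size of the contiguous dimension per CTA. A standalone sketch of that selection, with a worked example (a 16x128 i8 tile: 128 elements x 8 bits / 8 = 128 bytes, so the 128-byte swizzle mode is chosen); function names are illustrative, not the Triton API:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical mirror of the NVMMASharedEncodingAttr builder logic:
// pick the MMAv3 swizzling byte width from the contiguous-dimension size.
int32_t pickSwizzlingByteWidth(int64_t contigDimSize, int32_t elemBitWidth) {
  int64_t contigDimSizeInByte = contigDimSize * elemBitWidth / 8;
  if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0)
    return 128;
  if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0)
    return 64;
  if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0)
    return 32;
  assert(false && "unsupported shared memory layout for MMAv3");
  return 0;
}

int main() {
  // 16x128 i8 tile, row-major: the contiguous dimension holds 128 elements,
  // i.e. 128 bytes, so the 128-byte swizzle mode is selected.
  printf("%d\n", pickSwizzlingByteWidth(/*contigDimSize=*/128,
                                        /*elemBitWidth=*/8)); // 128
  // 16x32 f16 tile: 32 * 16 / 8 = 64 bytes -> 64-byte swizzle.
  printf("%d\n", pickSwizzlingByteWidth(32, 16)); // 64
  return 0;
}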

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@ class LoadOp;
 class StoreOp;
 class FuncOp;
 namespace gpu {
-class SharedEncodingAttr;
+class SwizzledSharedEncodingAttr;
 }
 } // namespace triton
 
@@ -197,7 +197,7 @@ int getNVIDIAComputeCapability(Operation *module);
 // Read the amd target from the module attributes
 StringRef getAMDArch(Operation *module);
 
-std::optional<mlir::triton::gpu::SharedEncodingAttr>
+std::optional<mlir::triton::gpu::SwizzledSharedEncodingAttr>
 getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);
 
 enum class MMALoadType {

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -297,7 +297,7 @@ def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter", [DeclareOpInterfaceMet
     The `ttng.async_tma_scatter` operation scatters multiple separately-indexed
     rows of data from local memory into global memory asynchronously. The
     operation scatters a 2D tensor in shared memory, laid out by core tensor
-    tiles (`hasLeadingOffset=true`) into separately indexed rows in global
+    tiles nvmma_shared layout into separately indexed rows in global
     memory at a given `y` offset.
   }];
 

lib/Analysis/Allocation.cpp

Lines changed: 2 additions & 2 deletions
@@ -165,8 +165,8 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
     auto dstTy = cvtLayout.getType();
     auto srcEncoding = srcTy.getEncoding();
     auto dstEncoding = dstTy.getEncoding();
-    if (isa<gpu::SharedEncodingAttr>(srcEncoding) ||
-        isa<gpu::SharedEncodingAttr>(dstEncoding)) {
+    if (mlir::isa<gpu::SharedEncodingTrait>(srcEncoding) ||
+        mlir::isa<gpu::SharedEncodingTrait>(dstEncoding)) {
       // Conversions from/to shared memory do not need scratch memory.
       return 0;
     }
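The check now keys off the SharedEncodingTrait interface, so both swizzled and NVMMA shared encodings take the zero-scratch path. A minimal sketch of that rule, with a hypothetical enum standing in for the MLIR encoding attributes:

#include <cstdio>

// Hypothetical stand-in for the encoding kinds involved in a layout conversion.
enum class EncodingKind { Blocked, Mma, SwizzledShared, NVMMAShared };

// True for any encoding that would implement the shared-memory trait.
bool isSharedEncoding(EncodingKind kind) {
  return kind == EncodingKind::SwizzledShared ||
         kind == EncodingKind::NVMMAShared;
}

// Mirror of the scratch-size rule: conversions touching shared memory on
// either side need no extra scratch buffer; register-to-register
// conversions fall through to the usual scratch computation (fallbackBytes).
unsigned convertLayoutScratchBytes(EncodingKind src, EncodingKind dst,
                                   unsigned fallbackBytes) {
  if (isSharedEncoding(src) || isSharedEncoding(dst))
    return 0;
  return fallbackBytes;
}

int main() {
  printf("%u\n", convertLayoutScratchBytes(EncodingKind::Blocked,
                                           EncodingKind::NVMMAShared,
                                           4096)); // 0
  return 0;
}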
