
Commit b9b7dc0

[TMA] Split large dimensions up into chunks of 256 elements (#6776)
This works around the hardware limit that no dimension of the TMA descriptor's block (box) size may exceed 256 elements.
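As a rough, standalone sketch of the idea (not the in-tree implementation, which lives in getTMABlockShape and the TMA lowering below; the helper names here are hypothetical): every block dimension is clamped to 256, and the lowering then issues ceil(dim / 256) copies along each clamped dimension to cover the full tile.

#include <algorithm>
#include <cstdint>
#include <vector>

// Clamp each block dimension to the TMA hardware maximum of 256 elements.
std::vector<int64_t> clampBlockShape(const std::vector<int64_t> &shapePerCTA) {
  constexpr int64_t dimMax = 256;
  std::vector<int64_t> blockShape = shapePerCTA;
  for (auto &size : blockShape)
    size = std::min(size, dimMax);
  return blockShape;
}

// Number of TMA copies needed along one dimension after clamping.
int64_t copiesPerDim(int64_t dimSize, int64_t blockSize) {
  return (dimSize + blockSize - 1) / blockSize; // ceil(dimSize / blockSize)
}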
1 parent 62cf173 commit b9b7dc0

16 files changed (+255 / -157 lines)


include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 27 additions & 9 deletions
@@ -39,14 +39,30 @@ triton::gpu::SharedEncodingTrait
 getEncodingFromDescriptor(Operation *op, RankedTensorType tensorType,
                           Value desc);
 
-int64_t getTMAContigDim(Attribute encoding, ArrayRef<int64_t> shape);
+SmallVector<int64_t> getTMABlockShape(ArrayRef<int64_t> shapePerCTA,
+                                      int elementBitWidth, int swizzleBytes,
+                                      bool fp4Padded, bool transposed,
+                                      bool packedSize);
+
+inline SmallVector<int64_t> getTMABlockShape(Attribute encoding,
+                                             ArrayRef<int64_t> shapePerCTA,
+                                             bool packedSize) {
+  auto mmaEnc = cast<gpu::NVMMASharedEncodingAttr>(encoding);
+  return getTMABlockShape(shapePerCTA, mmaEnc.getElementBitWidth(),
+                          mmaEnc.getSwizzlingByteWidth(), mmaEnc.getFp4Padded(),
+                          mmaEnc.getTransposed(), packedSize);
+}
 
-inline int64_t getTMAContigDim(RankedTensorType tensorType) {
-  return getTMAContigDim(tensorType.getEncoding(), tensorType.getShape());
+inline SmallVector<int64_t> getTMABlockShape(RankedTensorType ty,
+                                             bool packedSize) {
+  auto shapePerCTA = gpu::getShapePerCTA(ty);
+  return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize);
 }
 
-inline int64_t getTMAContigDim(gpu::MemDescType memDescType) {
-  return getTMAContigDim(memDescType.getEncoding(), memDescType.getShape());
+inline SmallVector<int64_t> getTMABlockShape(triton::gpu::MemDescType ty,
+                                             bool packedSize) {
+  auto shapePerCTA = gpu::getShapePerCTA(ty);
+  return getTMABlockShape(ty.getEncoding(), shapePerCTA, packedSize);
 }
 
 std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty);
@@ -74,16 +90,18 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
 
   int paddingScale = fp4Padded ? 2 : 1;
   auto shapePerCTA = gpu::getShapePerCTA(encoding, op.getTensorShape());
-  int32_t contig_dim_size = getTMAContigDim(encoding, op.getTensorShape());
+  auto blockShape =
+      getTMABlockShape(encoding, shapePerCTA, /*packedSize=*/false);
+  auto contigDimSize = blockShape.back();
 
   llvm::SmallVector<Value> boxDim;
-  if (fp4Padded && contig_dim_size != 128) {
+  if (fp4Padded && contigDimSize != 128) {
     return op->emitError(
         "FP4 padded loads require 128 elements or more in the last dim");
   }
-  boxDim.push_back(mkI32Constant(contig_dim_size));
+  boxDim.push_back(mkI32Constant(contigDimSize));
   for (int k = shapePerCTA.size() - 2; k >= 0; --k)
-    boxDim.push_back(mkI32Constant(shapePerCTA[k]));
+    boxDim.push_back(mkI32Constant(blockShape[k]));
 
   unsigned swizzleBytes = mmaEncoding ? mmaEncoding.getSwizzlingByteWidth() : 0;
   if (!mmaEncoding) {
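A quick note on the boxDim construction above: the contiguous block dimension is pushed first and the remaining dimensions are walked in reverse, i.e. the box is described innermost-dimension-first. A minimal sketch of the resulting order for a hypothetical 3-D block shape (plain std::vector stands in for llvm::SmallVector):

#include <cstdint>
#include <vector>

int main() {
  // Hypothetical block shape [dim0, dim1, dim2] = [4, 64, 128],
  // where dim2 (128) is the contiguous dimension.
  std::vector<int64_t> blockShape = {4, 64, 128};
  std::vector<int64_t> boxDim;
  boxDim.push_back(blockShape.back());               // innermost first: 128
  for (int k = static_cast<int>(blockShape.size()) - 2; k >= 0; --k)
    boxDim.push_back(blockShape[k]);                 // then 64, then 4
  // boxDim == {128, 64, 4}
}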

include/triton/Tools/LayoutUtils.h

Lines changed: 11 additions & 0 deletions
@@ -83,6 +83,17 @@ LinearLayout ensureLayoutNotSmallerThan(
     const LinearLayout &layout,
     const llvm::SmallDenseMap<StringAttr, int64_t> &shape);
 
+inline LinearLayout
+ensureLayoutNotSmallerThan(const LinearLayout &layout,
+                           const llvm::ArrayRef<StringAttr> dimNames,
+                           const llvm::ArrayRef<int64_t> shape) {
+  llvm::SmallDenseMap<StringAttr, int64_t> namedDims;
+  for (auto [dimName, length] : llvm::zip_equal(dimNames, shape))
+    namedDims[dimName] = length;
+  assert(namedDims.size() == shape.size() && "duplicate dimension names given");
+  return ensureLayoutNotSmallerThan(layout, namedDims);
+}
+
 // Return a vector of the standard out dimension names for tensor layouts. These
 // are "dim0", "dim1", etc.
 SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
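For reference, a hedged sketch of how this new overload is invoked at the call sites added in LinearLayoutConversions.cpp (assuming an MLIRContext *ctx, a LinearLayout layout, and a SmallVector<int64_t> shapePerCTA are already in scope):

// Grow `layout` so each named out dimension covers at least the
// corresponding entry of `shapePerCTA`.
auto outDimNames = standardOutDimNames(ctx, shapePerCTA.size());
layout = ensureLayoutNotSmallerThan(layout, outDimNames, shapePerCTA);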

include/triton/Tools/LinearLayout.h

Lines changed: 12 additions & 3 deletions
@@ -325,23 +325,32 @@ class LinearLayout {
       bases;
 
   llvm::MapVector<StringAttr, int32_t /*size*/> outDims;
-  bool surjective;
+  bool surjective = true;
 
 public:
   using BasesT = decltype(bases);
 
+  LinearLayout() = default;
+
   // The 0-dimensional layout that maps everything to 0. This is useful as a
   // starting point when doing something like
   //
   //   LinearLayout ret = LinearLayout::empty();
   //   for (...) ret *= ...;
   //   return ret;
-  static LinearLayout empty() { return LinearLayout(BasesT{}, {}); }
+  static LinearLayout empty() { return {}; }
+
+  // Creates a 1D -> 1D layout that's the function L(x) = stride * x
+  // for x in [0, size).
+  static LinearLayout strided1D(int32_t size, int32_t stride, StringAttr inDim,
+                                StringAttr outDim);
 
   // Creates a 1D -> 1D layout that's the identity function, i.e. L(x) = x
   // for x in [0, size).
   static LinearLayout identity1D(int32_t size, StringAttr inDim,
-                                 StringAttr outDim);
+                                 StringAttr outDim) {
+    return strided1D(size, /*stride=*/1, inDim, outDim);
+  }
 
   // Creates a 1D -> 1D layout that maps every input value to 0, i.e. L(x) = 0
   // for x in [0, size). By default this creates a surjective layout where
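The new strided1D generalizes identity1D, which now simply forwards with stride 1. A small illustrative sketch (S(...) is assumed to be the usual shorthand for building a StringAttr, as used elsewhere in the layout code):

// L(x) = x for x in [0, 16)
auto id = LinearLayout::identity1D(16, S("offset"), S("dim0"));
// L(x) = 4 * x for x in [0, 16)
auto st = LinearLayout::strided1D(16, /*stride=*/4, S("offset"), S("dim0"));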

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 27 additions & 29 deletions
@@ -6,6 +6,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 #include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
@@ -241,7 +242,6 @@ LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
   int tileRows = 8;
   int tileCols = 8 * tileWidthBytes / elemBitWidth;
   bool isFp4Padded = shared.getFp4Padded();
-  int packingFactor = isFp4Padded ? 2 : 1;
 
   std::vector<std::vector<int>> bases2D;
   for (int col = 1; col < tileCols; col *= 2) {
@@ -269,11 +269,7 @@ LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
     }
   }
   auto outDimNames = standardOutDimNames(ctx, 2);
-  auto kRow = outDimNames[1];
-  auto kCol = outDimNames[0];
-  LinearLayout tileLayout =
-      LinearLayout({{S("offset"), bases2D}}, {kRow, kCol});
-  return tileLayout;
+  return LinearLayout({{S("offset"), bases2D}}, outDimNames);
 }
 
 } // namespace
@@ -285,63 +281,62 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
   int rank = shape.size();
   auto shapePerCTA = getShapePerCTA(shared, shape);
   auto kOffset = S("offset");
+  auto tmaShape = triton::nvidia_gpu::getTMABlockShape(shared, shapePerCTA,
+                                                       /*packedSize=*/true);
   if (shared.getSwizzlingByteWidth() == 0) {
     auto outDimNames = standardOutDimNames(ctx, rank);
-    LinearLayout layout = LinearLayout::identity1D(
-        shapePerCTA[rank - 1], kOffset, outDimNames[rank - 1]);
+    LinearLayout layout = LinearLayout::identity1D(tmaShape[rank - 1], kOffset,
+                                                   outDimNames[rank - 1]);
     for (int i = rank - 2; i >= 0; --i) {
-      layout *=
-          LinearLayout::identity1D(shapePerCTA[i], kOffset, outDimNames[i]);
+      layout *= LinearLayout::identity1D(tmaShape[i], kOffset, outDimNames[i]);
     }
+    layout = ensureLayoutNotSmallerThan(layout, outDimNames, shapePerCTA);
     return combineCtaCgaWithShape(layout, shared.getCTALayout(), shape);
   }
   assert(rank >= 2);
 
   // Collapse all the outer dim into one. We will then create a layout for this
   // shape and reshape it to the original shape.
-  std::array<int64_t, 2> collapsedShapePerCTA{1, shapePerCTA.back()};
+  std::array<int64_t, 2> collapsedTmaShape{1, tmaShape.back()};
   for (int i = 0; i + 1 < rank; i++)
-    collapsedShapePerCTA[0] *= shapePerCTA[i];
+    collapsedTmaShape[0] *= tmaShape[i];
   if (shared.getTransposed()) {
-    std::swap(collapsedShapePerCTA[0], collapsedShapePerCTA[1]);
+    std::swap(collapsedTmaShape[0], collapsedTmaShape[1]);
   }
 
   auto tileLayout = getCoreMatrixLinearLayout(shared, disableSwizzle);
   auto outDimNames = standardOutDimNames(ctx, 2);
-  auto kRow = outDimNames[1];
-  auto kCol = outDimNames[0];
+  auto kRow = outDimNames[0];
+  auto kCol = outDimNames[1];
   auto tileRows = tileLayout.getOutDimSize(kRow);
   auto tileCols = tileLayout.getOutDimSize(kCol);
 
   int packingFactor = shared.getFp4Padded() ? 2 : 1;
-  if (collapsedShapePerCTA[1] * packingFactor < tileCols ||
-      collapsedShapePerCTA[0] < tileRows) {
+  if (collapsedTmaShape[1] * packingFactor < tileCols ||
+      collapsedTmaShape[0] < tileRows) {
    llvm::errs() << "Illegal shared layout; expected collapsed shapePerCTA to "
                    "be at least ["
                 << tileRows << ", " << (tileCols / packingFactor)
-                << "], collapsedShapePerCTA: [" << collapsedShapePerCTA[0]
-                << ", " << collapsedShapePerCTA[1] << "]\n";
+                << "], collapsedTmaShape: [" << collapsedTmaShape[0] << ", "
+                << collapsedTmaShape[1] << "]\n";
    llvm::report_fatal_error("Illegal shared layout");
   }
 
   // Distribute the remaining rows and cols.
-  auto layout = tileLayout;
-  layout *= LinearLayout::identity1D(collapsedShapePerCTA[0] / tileRows,
-                                     kOffset, kRow);
-  layout *= LinearLayout::identity1D(collapsedShapePerCTA[1] / tileCols,
-                                     kOffset, kCol);
+  auto layout =
+      ensureLayoutNotSmallerThan(tileLayout, outDimNames, collapsedTmaShape);
 
   // Reshape the layout to the N-D pre-transposed shape per CTA.
-  SmallVector<int64_t> maybeTransposedShapePerCTA = shapePerCTA;
+  SmallVector<int64_t> maybeTransposedTmaShape = tmaShape;
   if (shared.getTransposed()) {
     // Move the outer dim to the inner position.
     // TODO: we should move back to using `order` instead of transposed to make
     // the order more explicit.
-    std::rotate(maybeTransposedShapePerCTA.begin(),
-                maybeTransposedShapePerCTA.begin() + 1,
-                maybeTransposedShapePerCTA.end());
+    std::rotate(maybeTransposedTmaShape.begin(),
+                maybeTransposedTmaShape.begin() + 1,
+                maybeTransposedTmaShape.end());
   }
-  auto reshapedLayout = reshapeLayout(ctx, layout, maybeTransposedShapePerCTA);
+  auto reshapedLayout = reshapeLayout(ctx, layout, maybeTransposedTmaShape);
 
   if (shared.getTransposed()) {
     SmallVector<int> order = {rank - 1};
@@ -351,6 +346,9 @@ LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
     reshapedLayout = transposeLinearLayout(reshapedLayout, order);
   }
 
+  reshapedLayout = ensureLayoutNotSmallerThan(
+      reshapedLayout, standardOutDimNames(ctx, shapePerCTA.size()),
+      shapePerCTA);
   return combineCtaCgaWithShape(reshapedLayout, shared.getCTALayout(), shape);
 }
 
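To summarize the resulting flow in nvmmaSharedToLinearLayout, here is a hedged reading of the diff (comments only, not authoritative documentation of the pass):

// 1. tmaShape = getTMABlockShape(shared, shapePerCTA, /*packedSize=*/true):
//    the per-CTA shape clamped to what a single TMA box may cover.
// 2. Build the (possibly swizzled) core-tile layout for tmaShape only.
// 3. ensureLayoutNotSmallerThan(..., collapsedTmaShape) grows that tile
//    layout to the collapsed TMA block shape.
// 4. Reshape/transpose back to N-D, then ensureLayoutNotSmallerThan(...,
//    shapePerCTA) grows the result so it still spans the full per-CTA shape
//    before combineCtaCgaWithShape adds the CTA/CGA dimensions.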

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 13 additions & 11 deletions
@@ -11,6 +11,7 @@
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/LinearLayout.h"
 #include <memory>
 
@@ -149,29 +150,36 @@ static Attribute inferSrcEncodingMemDescReshape(Attribute dstEncoding,
                                                 ArrayRef<int64_t> dstShape) {
   auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(dstEncoding);
   if (!mmaEncoding)
-    return Attribute();
+    return {};
   // TODO: supporting reshape of CTA layouts is non-trivial.
   if (getNumCTAs(mmaEncoding) > 1)
-    return Attribute();
+    return {};
   int innerDimDst =
       mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
   int innerDimSrc =
       mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
   // For now disallow reshape of the inner dimension.
   if (innerDimDst != innerDimSrc)
-    return Attribute();
+    return {};
 
   // CTALayout can be all 1's because we bailed on multi-CTA layouts above.
   auto CTALayout = CTALayoutAttr::get(
       dstEncoding.getContext(),
       /*CTAsPerCGA=*/SmallVector<unsigned>(srcShape.size(), 1),
       /*CTASplitNum=*/SmallVector<unsigned>(srcShape.size(), 1),
       /*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(srcShape.size())));
-  // Check that the second dim is big enough to contain a full swizzle.
-  return NVMMASharedEncodingAttr::get(
+  auto srcEncoding = NVMMASharedEncodingAttr::get(
       dstEncoding.getContext(), mmaEncoding.getSwizzlingByteWidth(),
       mmaEncoding.getTransposed(), mmaEncoding.getElementBitWidth(),
       mmaEncoding.getFp4Padded(), CTALayout);
+  // Big guns, check linear layouts are equivalent
+  auto srcLL = toLinearLayout(srcShape, srcEncoding);
+  auto dstLL = toLinearLayout(dstShape, dstEncoding);
+  auto ctx = dstEncoding.getContext();
+  if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
+    return {};
+  }
+  return srcEncoding;
 }
 
 // Rewrite
@@ -315,12 +323,6 @@ class UseShmemForScales
     if (!isInnermostContiguous(scaleType, 512))
      return false;
 
-    auto sharedEnc =
-        dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(scaleType.getEncoding());
-    if (!sharedEnc || sharedEnc.getTransposed() || sharedEnc.getFp4Padded() ||
-        sharedEnc.getSwizzlingByteWidth() != 0)
-      return false;
-
     if (usesTMAload) {
       return true;
     }

lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp

Lines changed: 0 additions & 7 deletions
@@ -57,13 +57,6 @@ Attribute findLoadEncodingFromUsers(Operation *op) {
   return {};
 }
 
-ttg::CTALayoutAttr getCtaLayoutFromEncoding(Attribute encoding) {
-  auto layout = cast<ttg::LayoutEncodingTrait>(encoding);
-  auto ctx = encoding.getContext();
-  return ttg::CTALayoutAttr::get(ctx, layout.getCTAsPerCGA(),
-                                 layout.getCTASplitNum(), layout.getCTAOrder());
-}
-
 SmallVector<int64_t> expandToRank(ArrayRef<int64_t> shape, int rank) {
   SmallVector<int64_t> result(rank, 1);
   assert(shape.size() <= rank);

lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp

Lines changed: 23 additions & 16 deletions
@@ -99,23 +99,30 @@ ttg::SharedEncodingTrait getEncodingFromDescriptor(Operation *op,
   return updateEncodingForShape(op, sharedEnc, tensorType);
 }
 
-int64_t getTMAContigDim(Attribute encoding, ArrayRef<int64_t> shape) {
-  assert(encoding);
-  auto mmaEncoding =
-      llvm::dyn_cast_or_null<ttg::NVMMASharedEncodingAttr>(encoding);
-
-  // The bounding box inner dimension must be less than or equal to the
-  // swizzle size.
-  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7
-  // We clamp the block size and the codegen will emit multiple copy
-  // operations.
-  if (mmaEncoding && mmaEncoding.getSwizzlingByteWidth() != 0) {
-    auto elemSize = mmaEncoding.getElementBitWidth() / 8;
-    return mmaEncoding.getSwizzlingByteWidth() / elemSize;
+SmallVector<int64_t> getTMABlockShape(ArrayRef<int64_t> shapePerCTA,
+                                      int elementBitWidth, int swizzleBytes,
+                                      bool fp4Padded, bool isTransposed,
+                                      bool packedSize) {
+  SmallVector<int64_t> blockShape(shapePerCTA);
+  int contigDim = isTransposed ? 0 : blockShape.size() - 1;
+  if (fp4Padded) {
+    blockShape[contigDim] *= 2;
   }
-
-  auto shapePerCTA = ttg::getShapePerCTA(encoding, shape);
-  return shapePerCTA.back();
+  // All dimensions must be at most 256
+  constexpr int64_t dimMax = 256;
+  for (auto &size : blockShape) {
+    size = std::min(size, dimMax);
+  }
+  // Last dim must equal the swizzle byte size
+  if (swizzleBytes != 0) {
+    auto contigDimSize = (8 * swizzleBytes) / elementBitWidth;
+    assert(blockShape[contigDim] >= contigDimSize);
+    blockShape[contigDim] = contigDimSize;
+  }
+  if (fp4Padded && packedSize) {
+    blockShape[contigDim] /= 2;
+  }
+  return blockShape;
 }
 
 std::optional<int> getTMASwizzleMode(Operation *op, TensorDescType ty);
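A hedged, worked example of the new clamping rules with hypothetical inputs (not taken from the tests):

// shapePerCTA = [1024, 512], elementBitWidth = 16, swizzleBytes = 128,
// fp4Padded = false, isTransposed = false, packedSize = false.
//
//   1. Clamp every dimension to dimMax = 256            -> [256, 256]
//   2. Force the contiguous dim to the swizzle width:
//      (8 * 128) / 16 = 64 elements                     -> [256, 64]
//
// getTMABlockShape returns [256, 64]; as the old getTMAContigDim comment
// noted, the codegen then emits multiple copy operations to cover the full
// [1024, 512] tile.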
