
Commit 5d8325f

plognjen and oplavsic authored

[AMD] Support MI350 ds_read_b64_tr_b8 instruction for int8 (#6018)

This commit adds support for the ds_read_b64_tr_b8 instruction. Currently, we only enable it for int8 element types. It should also work for fp8, which will be turned on later with more testing.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
1 parent 5c05106 commit 5d8325f
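
The commit message gates the new ds_read_b64_tr_b8 path to int8 element types for now, with fp8 expected to follow after more testing. A minimal sketch of such a type gate (hypothetical helper name and placement, not code from this commit; the actual check lives in the AMD backend lowering, which is not shown in this excerpt):

#include "mlir/IR/Types.h"

// Hypothetical illustration only -- not part of this diff.
// ds_read_b64_tr_b8 transposes 8-bit elements; this commit enables the
// path solely for int8. fp8 should also work and is planned to be enabled
// after more testing.
static bool useDsReadB64TrB8(mlir::Type elemTy) {
  return elemTy.isInteger(8);
}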

File tree

5 files changed: +412 -172 lines changed


include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 2 additions & 2 deletions

@@ -264,8 +264,8 @@ LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
 
 // The primary goal of this function is to efficiently load 2D tiles of a
 // tensor from shared memory using the `ds_read_tr` instruction for AMD GPUs.
-LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
-                                       int32_t elemBitWidth);
+LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
+                                     int32_t elemBitWidth);
 
 // Create LinearLayout for mxfp4 and mxfp8 operand in scaled mfma.
 // For mxfp4, we use dot layout directly. Mxfp8 is not covered by dot
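
The header change is a pure rename: chooseDsReadB64Tr16Layout becomes chooseDsReadB64TrLayout, since the helper now covers 8-bit as well as 16-bit elements. A hypothetical call-site sketch (the wrapper name and arguments are illustrative, not from this commit):

// Hypothetical wrapper showing the widened contract of the renamed helper.
// `enc` and `shape` come from the caller; `elemTy` may now be i8 as well
// as a 16-bit type.
LinearLayout chooseTransposedReadLayout(Attribute enc, ArrayRef<int64_t> shape,
                                        Type elemTy) {
  return chooseDsReadB64TrLayout(enc, shape, elemTy.getIntOrFloatBitWidth());
}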

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 81 additions & 63 deletions

@@ -393,12 +393,12 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }
 
-LinearLayout chooseDotDsReadB64Tr16Layout(DotOperandEncodingAttr dotMfmaLayout,
-                                          ArrayRef<int64_t> shape,
-                                          int32_t elemBitWidth) {
+LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
+                                        ArrayRef<int64_t> shape,
+                                        int32_t elemBitWidth) {
   auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
   assert(mfmaLayout.getMDim() == 16 || mfmaLayout.getNDim() == 32);
-  assert(elemBitWidth == 16);
+  assert(elemBitWidth == 16 || elemBitWidth == 8);
 
   auto rank = shape.size();
   bool hasBatchDim = rank == 3;
@@ -407,6 +407,7 @@ LinearLayout chooseDotDsReadB64Tr16Layout(DotOperandEncodingAttr dotMfmaLayout,
   // loads for most element sizes (16b, 8b, 4b).
   const int32_t ldsReadWidth = 64;
   int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
+  const int elemByteWidth = elemBitWidth / 8;
   auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
 
   int32_t kSize = shape[kDim];
@@ -427,72 +428,92 @@ LinearLayout chooseDotDsReadB64Tr16Layout(DotOperandEncodingAttr dotMfmaLayout,
   SmallVector<unsigned> order = dotMfmaLayout.getDefaultOrder();
   std::swap(order[0], order[1]);
 
-  // In the LDS transpose logic, each thread accesses 64 bits (8 bytes) of data.
-  // The smallest unit for transposing is a 4x4 sub-tile of threads, where each
-  // thread reads 4 16-bit elements along the non-K dimension, resulting in a
-  // [non-K, K] = {16, 4} sub-tile of elements. Because of transposing
-  // mechanism, thread ends up with 4 16-bit elements along K dim.
+  // For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes)
+  // of data. The smallest unit for transposition is a
+  // [non-K, K] = {16, kWidthTransRead} sub-tile of elements,
+  // where each thread reads kWidthTransRead elements along the non-K dimension.
+  // Due to the transposition mechanism, each thread ends up with
+  // kWidthTransRead elements along the K dimension.
   //
   // The MFMA selection logic prioritizes double-rate MFMA instructions whenever
-  // possible. Specifically:
-  // - For MFMA operations that are non-K = 16, when blockK > 16, mfma16x16x32
-  //   is selected; otherwise (blockK ≤ 16), mfma16x16x16 remains the choice.
-  // - For MFMA operations that are non-K = 32, when blockK > 8, mfma32x32x16 is
-  //   selected; otherwise (blockK ≤ 8), mfma32x32x8 is used.
+  // possible:
   //
-  // In double-rate MFMA instructions, each thread holds 8 elements along the K
-  // dimension.
-  // - The first 4 elements belong to the first sub-tile.
-  // - The next 4 elements belong to the second sub-tile.
+  // - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k
+  //   is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice.
   //
-  // We then group these into larger tiles, each consisting of 8 of these 16x4
-  // sub-tiles. These tiles correspond to data for one mfma instruction. The
-  // shapes of these tiles depend on the MFMA instruction used:
-  // 1. For mfma32x32x16, the tile shape is [non-K, K] = {32, 16}.
-  // 2. For mfma16x16x32, the tile shape is [non-K, K] = {16, 32}.
+  // - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is
+  //   selected; otherwise (blockK ≤ k), mfma32x32xk is used.
   //
-  // For single-rate mfma instructions, each thread holds 4 elements along K
-  // dimension. This means larger tile (that corresponds to one mfma
-  // instruction) consists of 4 16x4 sub-tiles.
-  std::vector<std::vector<int32_t>> registerBase = {{1, 0},
-                                                    {2, 0}}; // first sub-tile
-  std::vector<std::vector<int32_t>> laneBase = {{kWidthTransRead, 0},
-                                                {2 * kWidthTransRead, 0},
-                                                {0, 1},
-                                                {0, 2}}; // first sub-tile
-
-  // Extend register base for multiple tiles in K dimension (corresponding to
-  // multiple mfma instructions accross k dim).
-  auto populateRegisterBase = [&](int kTileSize) {
-    const int regsPerTile = 8;
-    int numRegs = (kSize / kTileSize) * regsPerTile;
-    for (int reg = regsPerTile; reg < numRegs; reg *= 2) {
+  // NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA
+  // instructions are used.
+  //
+  // In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead
+  // elements along the K dimension:
+  // - The first kWidthTransRead elements belong to the first sub-tile.
+  // - The next kWidthTransRead elements belong to the second sub-tile.
+  //
+  // These elements are then grouped into larger tiles, each consisting of
+  // 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data
+  // for one MFMA instruction. The shape of these tiles depends on the MFMA
+  // instruction used.
+  //
+  // For single-rate MFMA instructions, each thread holds kWidthTransRead
+  // elements along the K dimension. This means that the larger tile
+  // (corresponding to one MFMA instruction) consists of 4 {16, kWidthTransRead}
+  // sub-tiles.
+  std::vector<std::vector<int32_t>> registerBase;
+  std::vector<std::vector<int32_t>> laneBase;
+
+  // Populate register base for first subtile
+  for (int i = 1; i < kWidthTransRead; i *= 2) {
+    registerBase.push_back({i, 0});
+  }
+
+  const int threadsPerSubtileNonK = 16 / kWidthTransRead;
+  const int threadsPerSubtileK = kWidthTransRead;
+
+  // Populate lane base for first subtile
+  for (int i = 1; i < threadsPerSubtileNonK; i *= 2) {
+    laneBase.push_back({i * kWidthTransRead, 0});
+  }
+  for (int i = 1; i < threadsPerSubtileK; i *= 2) {
+    laneBase.push_back({0, i});
+  }
+
+  // Function to extend register base for multiple tiles K dim.
+  auto extendRegisterBaseForKDim = [&](int kTileSize) {
+    const int regsPerTile = kWidthTransRead * 2; // Two subtiles per tile
+    int totalRegs = (kSize / kTileSize) * regsPerTile;
+
+    for (int reg = regsPerTile; reg < totalRegs; reg *= 2) {
       registerBase.push_back({0, (reg / regsPerTile) * kTileSize});
     }
   };
 
   const bool isMfma32 = (mfmaLayout.getMDim() == 32);
   const bool isMfma16 = (mfmaLayout.getMDim() == 16);
-  const int kTileSize = isMfma32 ? 16 : 32;
-
-  if (kSize >= kTileSize) {
-    // Handles mfma32x32x16 and mfma16x16x32 cases
-    assert(kWidthDot == 8);
-    registerBase.push_back({0, 4}); // second sub-tile
-    populateRegisterBase(kTileSize);
-    auto laneBaseExt = isMfma32
-                           ? std::vector<std::vector<int32_t>>{{16, 0}, {0, 8}}
-                           : std::vector<std::vector<int32_t>>{{0, 8}, {0, 16}};
-    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+  const int kTileSize = isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth;
+  const bool largeKSize = kSize >= kTileSize;
+
+  // Extend register base for large K sizes.
+  if (largeKSize) {
+    registerBase.push_back({0, threadsPerSubtileK}); // Second subtile
+    extendRegisterBaseForKDim(kTileSize);
+  }
+
+  // Extend lane base based on MFMA size.
+  const int numSubtilesPerTile = largeKSize ? 2 : 1;
+  std::vector<std::vector<int32_t>> laneBaseExt;
+
+  if (isMfma32) {
+    laneBaseExt = {{16, 0}, {0, numSubtilesPerTile * threadsPerSubtileK}};
   } else {
-    // Handles mfma32x32x8 and mfma16x16x16 cases
-    assert(kWidthDot == 4);
-    auto laneBaseExt = isMfma32
-                           ? std::vector<std::vector<int32_t>>{{16, 0}, {0, 4}}
-                           : std::vector<std::vector<int32_t>>{{0, 4}, {0, 8}};
-    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+    laneBaseExt = {{0, numSubtilesPerTile * threadsPerSubtileK},
+                   {0, 2 * numSubtilesPerTile * threadsPerSubtileK}};
   }
 
+  laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+
   // Base vectors above are defined in a fixed order [non-k-dim, k-dim].
   // To assign them to actual matrix dimensions `order` array is used.
   // For operand A: non-k-dim -> dim0, k-dim -> dim1
@@ -516,10 +537,7 @@ LinearLayout chooseDotDsReadB64Tr16Layout(DotOperandEncodingAttr dotMfmaLayout,
 
   LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
                            warpLayout.transposeOuts(outDimNames);
-  auto finalLayout =
-      combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
-
-  return finalLayout;
+  return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
 }
 
 LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
@@ -1334,10 +1352,10 @@ LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
   return chooseDotLdMatrixLayout(dot, shape, needTrans, elemBitWidth);
 }
 
-LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
-                                       int32_t elemBitWidth) {
+LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
+                                     int32_t elemBitWidth) {
   auto dot = cast<DotOperandEncodingAttr>(enc);
-  return chooseDotDsReadB64Tr16Layout(dot, shape, elemBitWidth);
+  return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
 }
 
 LinearLayout
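
To make the generalization concrete, the following standalone sketch replays the parameter arithmetic from chooseDotDsReadB64TrLayout for 16-bit and 8-bit elements (the constants are copied from the diff above; the program itself is illustrative only):

#include <cstdio>

int main() {
  const int ldsReadWidth = 64; // bits read per lane by ds_read_b64_tr_*
  for (int elemBitWidth : {16, 8}) {
    const int kWidthTransRead = ldsReadWidth / elemBitWidth; // elems per lane
    const int elemByteWidth = elemBitWidth / 8;
    const int threadsPerSubtileNonK = 16 / kWidthTransRead;
    const int threadsPerSubtileK = kWidthTransRead;
    // kTileSize as computed in the diff: isMfma32 ? 32/B : 64/B.
    const int kTileMfma32 = 32 / elemByteWidth;
    const int kTileMfma16 = 64 / elemByteWidth;
    std::printf("b%-2d: kWidthTransRead=%d subtile=[16 x %d] lanes=%dx%d "
                "kTile(mfma32)=%d kTile(mfma16)=%d\n",
                elemBitWidth, kWidthTransRead, kWidthTransRead,
                threadsPerSubtileNonK, threadsPerSubtileK,
                kTileMfma32, kTileMfma16);
  }
  return 0;
}

For 16-bit elements this reproduces the previously hard-coded tile sizes (kTileSize 16 for mfma32, 32 for mfma16); for 8-bit elements the K tiles double to 32 and 64, which matches the kWidth = 16 (2 * kWidthTransRead) used by the new i8 tests below.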

test/Conversion/amd/ds_transpose.mlir

Lines changed: 68 additions & 0 deletions

@@ -74,4 +74,72 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>>
     tt.return
   }
+
+  // CHECK-LABEL: ds_transpose_n_t_i8_mfma_16
+  tt.func @ds_transpose_n_t_i8_mfma_16(%arg0: !ttg.memdesc<128x64xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-16: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 16}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared1, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 16}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_t_i8_mfma_16
+  tt.func @ds_transpose_t_t_i8_mfma_16(%arg0: !ttg.memdesc<128x64xi8, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared1, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 16}>>
+    // CHECK-COUNT-8: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared1, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 16}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_n_n_i8_mfma_16
+  tt.func @ds_transpose_n_n_i8_mfma_16(%arg0: !ttg.memdesc<128x64xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared, #smem, mutable>) {
+    // CHECK-COUNT-8: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 16}>>
+    // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 16}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_n_i8_mfma_16
+  tt.func @ds_transpose_t_n_i8_mfma_16(%arg0: !ttg.memdesc<128x64xi8, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared, #smem, mutable>) {
+    // CHECK-NOT: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared1, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 16}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 16}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_n_t_i8_mfma32
+  tt.func @ds_transpose_n_t_i8_mfma32(%arg0: !ttg.memdesc<128x64xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-16: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared1, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_t_i8_mfma32
+  tt.func @ds_transpose_t_t_i8_mfma32(%arg0: !ttg.memdesc<128x64xi8, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared1, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>>
+    // CHECK-COUNT-6: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared1, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_n_n_i8_mfma32
+  tt.func @ds_transpose_n_n_i8_mfma32(%arg0: !ttg.memdesc<128x64xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared, #smem, mutable>) {
+    // CHECK-COUNT-8: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>>
+    // CHECK-COUNT-4: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_n_i8_mfma32
+  tt.func @ds_transpose_t_n_i8_mfma32(%arg0: !ttg.memdesc<128x64xi8, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xi8, #shared, #smem, mutable>) {
+    // CHECK-NOT: rocdl.ds.read.tr8.b64 %{{.*}} : <3> -> vector<2xi32>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xi8, #shared1, #smem, mutable> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xi8, #shared, #smem, mutable> -> tensor<64x128xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>>
+    tt.return
+  }
 }
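
A quick sanity check on the CHECK patterns above: each rocdl.ds.read.tr8.b64 returns vector<2xi32>, i.e. 8 bytes (eight i8 elements) per lane, while the fall-back llvm.load returns vector<16xi8>, i.e. 16 bytes per lane. With kWidth = 16, an operand needs 16 i8 elements per lane per MFMA tile, so (assuming no other replication effects) the two paths relate as:

  16 elements/lane = 2 × (8 elements per rocdl.ds.read.tr8.b64)
                   = 1 × (16 elements per llvm.load of vector<16xi8>)

which is why the transpose path shows roughly twice as many reads as the vector-load path for the same amount of data (e.g. 8 tr8 reads vs 4 vector loads in ds_transpose_n_n_i8_mfma_16).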
