
Commit 55db9b8

[AMD] Introduce LocalLoadPackedTransposedOp (#7422)
This operation transposes packed tensors, which is necessary because the shape of these tensors changes when the packing changes. It will be used to load FP4 values K-packed and K-contiguous when they are stored in shared memory packed along M/N and M/N-contiguous. The FP4 types are treated in the linear layout (LL) as if they were i8 types, since that is how they are also treated in the rest of the compiler pipeline. Note: for now the operation is introduced without being used (only tested); follow-up PRs will add the rest of the functionality.
1 parent 60a1996 commit 55db9b8
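
To make the shape change in the commit message concrete: repacking two-per-byte FP4 data from M/N-packed to K-packed doubles the M/N extent and halves the K extent of the i8 view. Below is a minimal standalone sketch of that repacking (illustrative only, not code from this PR; the nibble order is an arbitrary convention chosen for the example). It reproduces the memdesc<16x64xi8> -> tensor<32x32xi8> relationship exercised by the new lit tests further down.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Repack FP4 values (two per i8) from "packed along M" (fp4 rows 2r and 2r+1
// share byte row r) to "packed along K" (fp4 columns 2c and 2c+1 share byte
// column c). M and K are the i8 shape of the source tile.
std::vector<uint8_t> repackMToK(const std::vector<uint8_t> &src, int M, int K) {
  assert(K % 2 == 0);
  std::vector<uint8_t> dst(2 * M * (K / 2), 0); // (2*M) x (K/2) i8
  for (int m = 0; m < 2 * M; ++m) {   // logical fp4 row
    for (int k = 0; k < K; ++k) {     // logical fp4 column
      uint8_t byte = src[(m / 2) * K + k];
      uint8_t fp4 = (m % 2) ? (byte >> 4) : (byte & 0xF);
      uint8_t &out = dst[m * (K / 2) + k / 2];
      out |= (k % 2) ? (fp4 << 4) : fp4;
    }
  }
  return dst;
}

int main() {
  // A 16x64 i8 shared-memory tile (M/N packed) becomes a 32x32 i8 tile once
  // the values are K packed, i.e. M/N doubles and K halves.
  std::vector<uint8_t> a(16 * 64, 0x21);
  std::vector<uint8_t> b = repackMToK(a, 16, 64);
  std::printf("in: 16x64 i8, out: 32x%d i8\n", (int)(b.size() / 32));
}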

File tree

7 files changed: +478 -109 lines changed


lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 140 additions & 90 deletions

@@ -576,16 +576,21 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
   auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
   auto mDim = mfmaLayout.getMDim();
   assert(mDim == 16 || mDim == 32);
+
+  bool isFP4 = false;
+  if (elemBitWidth == 4) {
+    // When doing ds_read_tr4 we actually write the LL as if it were on i8
+    // elements this is becasue LL needs to be described for the i8 tensor
+    // elements.
+    elemBitWidth = 8;
+    isFP4 = true;
+  }
+
   assert(elemBitWidth == 16 || elemBitWidth == 8);

   auto rank = shape.size();
   bool hasBatchDim = rank == 3;
   int32_t kWidthDot = dotMfmaLayout.getKWidth();
-  // Number of bits loaded by an LDS read. ds_read_tr primarily supports 64-bit
-  // loads for most element sizes (16b, 8b, 4b).
-  const int32_t ldsReadWidth = 64;
-  int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
-  const int elemByteWidth = elemBitWidth / 8;
   auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;

   int32_t kSize = shape[kDim];

@@ -606,106 +611,151 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
   SmallVector<unsigned> order =
       getOrderForDotOperand(dotMfmaLayout.getOpIdx(), rank, /*kContig*/ false);

-  // For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes)
-  // of data. The smallest unit for transposition is a
-  // [non-K, K] = {16, kWidthTransRead} sub-tile of elements,
-  // where each thread reads kWidthTransRead elements along the non-K dimension.
-  // Due to the transposition mechanism, each thread ends up with
-  // kWidthTransRead elements along the K dimension.
-  //
-  // The MFMA selection logic prioritizes double-rate MFMA instructions whenever
-  // possible:
-  //
-  // - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k
-  //   is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice.
-  //
-  // - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is
-  //   selected; otherwise (blockK ≤ k), mfma32x32xk is used.
-  //
-  // NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA
-  //       instructions are used.
-  //
-  // In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead
-  // elements along the K dimension:
-  // - The first kWidthTransRead elements belong to the first sub-tile.
-  // - The next kWidthTransRead elements belong to the second sub-tile.
-  //
-  // These elements are then grouped into larger tiles, each consisting of
-  // 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data
-  // for one MFMA instruction. The shape of these tiles depends on the MFMA
-  // instruction used.
-  //
-  // For single-rate MFMA instructions, each thread holds kWidthTransRead
-  // elements along the K dimension. This means that the larger tile
-  // (corresponding to one MFMA instruction) consists of 4 {16, kWidthTransRead}
-  // sub-tiles.
   std::vector<std::vector<int32_t>> registerBase;
   std::vector<std::vector<int32_t>> laneBase;
+  auto populateFP4LL = [&registerBase, &laneBase](int kSize, int mDim) {
+    const bool isMfma32 = (mDim == 32);
+    // ds_read_b64_tr4 operates on FP4 values swapping the packing of them. Look
+    // at i8 values for the ownership of register/lane since it's the data type
+    // of the tensor. Register dimension: what i8 in the tile are held by thread
+    // 0? Lane dimension: what i8 in the tile are held in register 0 of each
+    // thread?
+    registerBase.push_back({1, 0});
+    registerBase.push_back({2, 0});
+    registerBase.push_back({4, 0});
+    registerBase.push_back({0, 16});
+
+    // If more than one tile needs to be loaded, populate registerBase
+    // dimension for the other tiles
+    const int kTileSize = isMfma32 ? 64 : 128;
+    for (int reg = kTileSize; reg < kSize; reg *= 2) {
+      registerBase.push_back({0, reg});
+    }

-  // Populate register base for first subtile
-  for (int i = 1; i < kWidthTransRead; i *= 2) {
-    registerBase.push_back({i, 0});
-  }
-
-  const int threadsPerSubtileNonK = 16 / kWidthTransRead;
-  const int threadsPerSubtileK = kWidthTransRead;
-
-  // Populate lane base for first subtile
-  for (int i = 1; i < threadsPerSubtileNonK; i *= 2) {
-    laneBase.push_back({i * kWidthTransRead, 0});
-  }
-  for (int i = 1; i < threadsPerSubtileK; i *= 2) {
-    laneBase.push_back({0, i});
-  }
-
-  // Function to extend register base for multiple tiles K dim.
-  auto extendRegisterBaseForKDim = [&](int kTileSize, int numSubtilesPerTile) {
-    const int regsPerTile = kWidthTransRead * numSubtilesPerTile;
-    int totalRegs = (kSize / kTileSize) * regsPerTile;
-
-    for (int reg = regsPerTile; reg < totalRegs; reg *= 2) {
-      registerBase.push_back({0, (reg / regsPerTile) * kTileSize});
+    // When mDim == 16 we have 16x128 mfma, otherwise it's 16x64
+    // The LL for the two is different
+    laneBase.push_back({0, 1});
+    laneBase.push_back({0, 2});
+    laneBase.push_back({0, 4});
+    laneBase.push_back({0, 8});
+    if (mDim == 16) {
+      laneBase.push_back({0, 32});
+      laneBase.push_back({0, 64});
+    } else {
+      assert(mDim == 32);
+      laneBase.push_back({8, 0});
+      laneBase.push_back({0, 32});
     }
   };
+  auto populateLL = [&registerBase, &laneBase](int elemBitWidth, int kSize,
+                                               int kWidthDot, int mDim) {
+    // Number of bits loaded by an LDS read. ds_read_tr primarily supports
+    // 64-bit loads for most element sizes (16b, 8b, 4b).
+    const int32_t ldsReadWidth = 64;
+    int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
+    const int elemByteWidth = elemBitWidth / 8;
+    const bool isMfma32 = (mDim == 32);
+
+    // For ds_read_b64_tr_* instructions, each thread accesses 64 bits (8 bytes)
+    // of data. The smallest unit for transposition is a
+    // [non-K, K] = {16, kWidthTransRead} sub-tile of elements,
+    // where each thread reads kWidthTransRead elements along the non-K
+    // dimension. Due to the transposition mechanism, each thread ends up with
+    // kWidthTransRead elements along the K dimension.
+    //
+    // The MFMA selection logic prioritizes double-rate MFMA instructions
+    // whenever possible:
+    //
+    // - For MFMA operations where M = N = 16, when blockK > k, mfma16x16x2*k
+    //   is selected; otherwise (blockK ≤ k), mfma16x16xk remains the choice.
+    //
+    // - For MFMA operations where M = N = 32, when blockK > k, mfma32x32x2*k is
+    //   selected; otherwise (blockK ≤ k), mfma32x32xk is used.
+    //
+    // NOTE: For fp8 and fp4, "double-rate" results in 4*k since scaled MFMA
+    //       instructions are used.
+    //
+    // In "double-rate" MFMA instructions, each thread holds 2*kWidthTransRead
+    // elements along the K dimension:
+    // - The first kWidthTransRead elements belong to the first sub-tile.
+    // - The next kWidthTransRead elements belong to the second sub-tile.
+    //
+    // These elements are then grouped into larger tiles, each consisting of
+    // 8 {16, kWidthTransRead} sub-tiles. These tiles correspond to the data
+    // for one MFMA instruction. The shape of these tiles depends on the MFMA
+    // instruction used.
+    //
+    // For single-rate MFMA instructions, each thread holds kWidthTransRead
+    // elements along the K dimension. This means that the larger tile
+    // (corresponding to one MFMA instruction) consists of 4 {16,
+    // kWidthTransRead} sub-tiles.

+    // Populate register base for first subtile
+    for (int i = 1; i < kWidthTransRead; i *= 2) {
+      registerBase.push_back({i, 0});
+    }

-  const bool isMfma32 = (mDim == 32);
-  const bool isMfma16 = (mDim == 16);
-
-  // kDoubleTileSize is the k dimension of a tile when double rated
-  // mfma instructions are used.
-  const int kDoubleTileSize =
-      isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth;
-  // kTileSize is the actually k dimention of a tile, which is
-  // determined by kWidthDot.
-  const int kTileSize = kWidthDot * 64 / mDim;
-  // We use kDoubleTileSize as a reference to check whether the given
-  // kWidthDot leads to double or single sub-tiles in each tile.
-  const int numSubtilesPerTile = (kTileSize == kDoubleTileSize) ? 2 : 1;
-
-  // Extend register base for large K sizes.
-  if (numSubtilesPerTile == 2)
-    registerBase.push_back({0, threadsPerSubtileK}); // Second subtile
-
-  extendRegisterBaseForKDim(kTileSize, numSubtilesPerTile);
+    const int threadsPerSubtileNonK = 16 / kWidthTransRead;
+    const int threadsPerSubtileK = kWidthTransRead;

-  // Extend lane base based on MFMA size.
-  std::vector<std::vector<int32_t>> laneBaseExt;
+    // Populate lane base for first subtile
+    for (int i = 1; i < threadsPerSubtileNonK; i *= 2) {
+      laneBase.push_back({i * kWidthTransRead, 0});
+    }
+    for (int i = 1; i < threadsPerSubtileK; i *= 2) {
+      laneBase.push_back({0, i});
+    }

-  if (isMfma32) {
-    laneBaseExt = {{16, 0}, {0, numSubtilesPerTile * threadsPerSubtileK}};
-  } else {
-    laneBaseExt = {{0, numSubtilesPerTile * threadsPerSubtileK},
-                   {0, 2 * numSubtilesPerTile * threadsPerSubtileK}};
-  }
+    // Function to extend register base for multiple tiles K dim.
+    auto extendRegisterBaseForKDim = [&](int kTileSize,
+                                         int numSubtilesPerTile) {
+      const int regsPerTile = kWidthTransRead * numSubtilesPerTile;
+      int totalRegs = (kSize / kTileSize) * regsPerTile;
+
+      for (int reg = regsPerTile; reg < totalRegs; reg *= 2) {
+        registerBase.push_back({0, (reg / regsPerTile) * kTileSize});
+      }
+    };
+
+    // kDoubleTileSize is the k dimension of a tile when double rated
+    // mfma instructions are used.
+    const int kDoubleTileSize =
+        isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth;
+    // kTileSize is the actually k dimention of a tile, which is
+    // determined by kWidthDot.
+    const int kTileSize = kWidthDot * 64 / mDim;
+    // We use kDoubleTileSize as a reference to check whether the given
+    // kWidthDot leads to double or single sub-tiles in each tile.
+    const int numSubtilesPerTile = (kTileSize == kDoubleTileSize) ? 2 : 1;
+
+    // Extend register base for large K sizes.
+    if (numSubtilesPerTile == 2)
+      registerBase.push_back({0, threadsPerSubtileK}); // Second subtile
+
+    extendRegisterBaseForKDim(kTileSize, numSubtilesPerTile);
+
+    // Extend lane base based on MFMA size.
+    std::vector<std::vector<int32_t>> laneBaseExt;
+
+    if (isMfma32) {
+      laneBaseExt = {{16, 0}, {0, numSubtilesPerTile * threadsPerSubtileK}};
+    } else {
+      laneBaseExt = {{0, numSubtilesPerTile * threadsPerSubtileK},
+                     {0, 2 * numSubtilesPerTile * threadsPerSubtileK}};
+    }
+    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+  };

-  laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+  if (isFP4)
+    populateFP4LL(kSize, mDim);
+  else
+    populateLL(elemBitWidth, kSize, kWidthDot, mDim);

   // Base vectors above are defined in a fixed order [non-k-dim, k-dim].
   // To assign them to actual matrix dimensions we associate with register
   // `order` which is also [nonk, k] given we set kContig to false.
   LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
                           {outDimNames[order[0]], outDimNames[order[1]]});
-
   if (hasBatchDim) {
     assert(order[2] == 0);
     // Extend the base vector with one value to accommodate for the batch
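
A note on reading the register/lane base vectors in populateFP4LL above: as I understand Triton's LinearLayout, an input index (register id or lane id) maps to an output coordinate by XOR-combining the base vectors selected by the set bits of the index. The standalone sketch below (illustrative only, not the LinearLayout class itself) applies that rule to the mDim == 16 FP4 bases to print which i8 elements of the tile land in each register of lane 0 and in register 0 of a few lanes.

#include <cstdio>
#include <utility>
#include <vector>

// XOR-combine the (nonK, K) base vectors selected by the set bits of `index`.
// The bases below are the i8-element coordinates used in populateFP4LL for the
// mDim == 16 case.
static std::pair<int, int>
applyBases(const std::vector<std::pair<int, int>> &bases, int index) {
  int nonK = 0, k = 0;
  for (size_t bit = 0; bit < bases.size(); ++bit) {
    if (index & (1 << bit)) {
      nonK ^= bases[bit].first;
      k ^= bases[bit].second;
    }
  }
  return {nonK, k};
}

int main() {
  std::vector<std::pair<int, int>> registerBase = {{1, 0}, {2, 0}, {4, 0}, {0, 16}};
  std::vector<std::pair<int, int>> laneBase = {{0, 1}, {0, 2}, {0, 4},
                                               {0, 8}, {0, 32}, {0, 64}};
  for (int r = 0; r < 16; ++r) {
    std::pair<int, int> c = applyBases(registerBase, r);
    std::printf("lane 0, register %2d -> i8 (%d, %d)\n", r, c.first, c.second);
  }
  for (int l = 0; l < 64; l += 8) {
    std::pair<int, int> c = applyBases(laneBase, l);
    std::printf("register 0, lane %2d -> i8 (%d, %d)\n", l, c.first, c.second);
  }
}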

test/Conversion/amd/ds_transpose.mlir

Lines changed: 27 additions & 0 deletions

@@ -378,4 +378,31 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     ttg.local_store %3, %arg2 : tensor<128x128xf32, #mma32> -> !ttg.memdesc<128x128xf32, #shared1, #smem, mutable>
     tt.return
   }
+
+  // CHECK-LABEL: ds_transpose_t_fp4_mfma32_small
+  tt.func @ds_transpose_t_fp4_mfma32_small(%arg0: !ttg.memdesc<16x64xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x16xi8, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-4: rocdl.ds.read.tr4.b64 %{{.*}} : <3> -> vector<2xi32>
+    // CHECK-NOT: rocdl.ds.read.tr4.b64
+    %1 = amdgpu.local_load_packed_tranposed %arg0 : !ttg.memdesc<16x64xi8, #shared, #smem, mutable> -> tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>>
+    %2 = amdgpu.local_load_packed_tranposed %arg1 : !ttg.memdesc<64x16xi8, #shared1, #smem, mutable> -> tensor<32x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_fp4_mfma16
+  tt.func @ds_transpose_t_fp4_mfma16(%arg0: !ttg.memdesc<8x128xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<128x8xi8, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-4: rocdl.ds.read.tr4.b64 %{{.*}} : <3> -> vector<2xi32>
+    // CHECK-NOT: rocdl.ds.read.tr4.b64
+    %1 = amdgpu.local_load_packed_tranposed %arg0 : !ttg.memdesc<8x128xi8, #shared, #smem, mutable> -> tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 16}>>
+    %2 = amdgpu.local_load_packed_tranposed %arg1 : !ttg.memdesc<128x8xi8, #shared1, #smem, mutable> -> tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 16}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_fp4_mfma32
+  tt.func @ds_transpose_t_fp4_mfma32(%arg0: !ttg.memdesc<256x256xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<256x256xi8, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-128: rocdl.ds.read.tr4.b64 %{{.*}} : <3> -> vector<2xi32>
+    // CHECK-NOT: rocdl.ds.read.tr4.b64
+    %1 = amdgpu.local_load_packed_tranposed %arg0 : !ttg.memdesc<256x256xi8, #shared, #smem, mutable> -> tensor<512x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>>
+    %2 = amdgpu.local_load_packed_tranposed %arg1 : !ttg.memdesc<256x256xi8, #shared1, #smem, mutable> -> tensor<128x512xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>>
+    tt.return
+  }
 }
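
A rough cross-check of the CHECK-COUNT values above (my own arithmetic, not part of the PR): each rocdl.ds.read.tr4.b64 returns a vector<2xi32>, i.e. 64 bits or 8 of the i8 elements per lane. For the 256x256 case, assuming the #mma32 layout uses warpsPerCTA = [2, 2] (as in the invalid.mlir test below) and that warps along the K dimension hold replicated data, the expected instruction count works out as follows:

#include <cstdio>

int main() {
  const long i8PerRead = 64 / 8;            // one ds_read_tr4.b64 per lane
  const long lanes = 64;                    // wavefront size
  const long distinctWarpsPerOperand = 2;   // assumption: K-dim warps replicate
  const long elemsPerOperand = 512L * 128;  // 512x128 i8 result of one load
  long elemsPerThread = elemsPerOperand / (lanes * distinctWarpsPerOperand);
  long readsPerOperand = elemsPerThread / i8PerRead;
  // Two local_load_packed_tranposed ops per test function.
  std::printf("reads per operand: %ld, per function: %ld\n", readsPerOperand,
              2 * readsPerOperand); // prints 64 and 128, matching CHECK-COUNT-128
}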

test/TritonGPU/amd/invalid.mlir

Lines changed: 34 additions & 0 deletions

@@ -93,3 +93,37 @@ module attributes {"ttg.target" = "hip:gfx942", "ttg.num-ctas" = 1 : i32, "ttg.n
     tt.return
   }
 }
+
+// -----
+
+#mma32 = #ttg.amd_mfma<{version = 4, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func @local_load_packed_tranposed_wrong_op_idx(%arg0: !ttg.memdesc<16x64xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x16xi8, #shared1, #smem, mutable>) {
+    // expected-error @+1 {{Order of dimensions don't match expected}}
+    %1 = amdgpu.local_load_packed_tranposed %arg0 : !ttg.memdesc<16x64xi8, #shared, #smem, mutable> -> tensor<32x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>>
+    tt.return
+  }
+
+  tt.func @local_load_packed_tranposed_wrong_op_idx2(%arg0: !ttg.memdesc<64x16xi8, #shared, #smem, mutable>) {
+    // expected-error @+1 {{Input and output dimensions don't match after packing changes}}
+    %1 = amdgpu.local_load_packed_tranposed %arg0 : !ttg.memdesc<64x16xi8, #shared, #smem, mutable> -> tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>>
+    tt.return
+  }
+  tt.func @local_load_packed_tranposed_wrong_attr(%arg1: !ttg.memdesc<128x8xi8, #blocked, #smem, mutable>) {
+    // expected-error @+1 {{only works with SwizzledSharedEncodingAttr src encoding}}
+    %1 = amdgpu.local_load_packed_tranposed %arg1 : !ttg.memdesc<128x8xi8, #blocked, #smem, mutable> -> tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>>
+    tt.return
+  }
+  // CHECK-LABEL: ds_transpose_t_fp4_mfma16
+  tt.func @local_load_packed_tranposed_wrong_shape(%arg0: !ttg.memdesc<8x128xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<128x8xi8, #shared1, #smem, mutable>) {
+    // expected-error @+1 {{only works with DotOperandEncodingAttr dst encoding}}
+    %1 = amdgpu.local_load_packed_tranposed %arg0 : !ttg.memdesc<8x128xi8, #shared, #smem, mutable> -> tensor<256x128xi32, #blocked>
+    tt.return
+  }
+
+}

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 28 additions & 0 deletions

@@ -522,4 +522,32 @@ def InThreadTransposeOp : TT_AMDGPU_Op<"in_thread_transpose", [Pure]> {
   }];
 }

+//===----------------------------------------------------------------------===//
+// LocalLoadPackedTransposedOp
+//===----------------------------------------------------------------------===//
+
+def LocalLoadPackedTransposedOp : TT_AMDGPU_Op<"local_load_packed_tranposed"> {
+  let summary = "Load a transposed packed tensor from shared memory into a distributed tensor";
+  let description = [{
+    Requires a M/N packed and M/N contiguous tensor in shared memory and will yield a K packed K contiguous tensor in registers.
+    The packing change will change the shape of the tensor by doubling the M/N dimension and halving the K dimension.
+    For example if A is 16x64 in shared memory, the result of this operation will be 32x32.
+  }];
+  let arguments = (ins
+    Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src,
+    Optional<TTG_AsyncToken>:$token
+  );
+  let results = (outs TT_Tensor:$result);
+
+  let builders = [
+      OpBuilder<(ins "Type":$retType, "Value":$src),
+      [{
+        build($_builder, $_state, retType, src, /*token=*/static_cast<mlir::Value>(nullptr));
+      }]>];
+
+  // Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
+  let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}];
+  let hasVerifier = 1;
+}
+
 #endif
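
For orientation, this is roughly how the new builder declared above might be used from C++ during lowering. It is a sketch under assumptions: that the generated class lives in the mlir::triton::amdgpu namespace like the other ops in this file, and that the relevant dialect headers are included; emitLocalLoadPackedTransposed and its arguments are hypothetical names for the example.

#include "mlir/IR/Builders.h"
// Also requires the generated TritonAMDGPU dialect op declarations.

// Sketch: create the op via the two-argument builder from the .td above; the
// optional async token is left unset (null Value).
mlir::Value emitLocalLoadPackedTransposed(mlir::OpBuilder &builder,
                                          mlir::Location loc,
                                          mlir::Type resultType,
                                          mlir::Value sharedMemSrc) {
  auto op = builder.create<mlir::triton::amdgpu::LocalLoadPackedTransposedOp>(
      loc, resultType, sharedMemSrc);
  return op.getResult();
}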
