Skip to content

Commit 6415039

Browse files
authored
[AMD] Support 4x64 and 64x4 MFMA layout for dot (#7576)
Support 4x64 and 64x4 MFMA layout for dot, where `D(4x64) = A(4x64) * B(64x64) + C(4x64)` or `D(64x4) = A(64x64) * B(64x4) + C(64x4)`. This allows using small M/N-size GEMMs with better performance. Under the hood it uses a group of 16 4x4 mfma instructions. Other related changes: - Refactor mfma input and output linear layout implementation. - Remove unsupported mfma instructions: 1) do not support the 4x4 tile, as it requires a manual reduction at the end, which has bad performance; 2) do not support the 4x16 and 16x4 tiles for fp64 - mfma fp64 does not support broadcast for operand A.
1 parent 8a5862d commit 6415039

File tree

9 files changed

+256
-252
lines changed

9 files changed

+256
-252
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,9 +1421,12 @@ LogicalResult AMDMfmaEncodingAttr::verify(
14211421
if (!(version >= 0 && version <= 4)) {
14221422
return emitError() << "version must be in the [0, 4] range";
14231423
}
1424-
if (!((mDim == 32 && nDim == 32) || (mDim == 16 && nDim == 16))) {
1425-
return emitError()
1426-
<< "(M, N) cases other than (32, 32) or (16, 16) unimplemented";
1424+
1425+
const std::array<std::pair<unsigned, unsigned>, 4> validDims = {
1426+
{{32, 32}, {16, 16}, {64, 4}, {4, 64}}};
1427+
if (!llvm::is_contained(validDims, std::make_pair(mDim, nDim))) {
1428+
return emitError() << "invalid (mDim, nDim) combination: (" << mDim << ", "
1429+
<< nDim << ")";
14271430
}
14281431
if (elementType && !(elementType->isF64() || elementType->isF32() ||
14291432
elementType->isInteger(32))) {
@@ -1929,13 +1932,12 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
19291932
unsigned nDim = getNDim();
19301933
assert((mDim == nDim) && (mDim == 32 || mDim == 16 || mDim == 4) ||
19311934
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));
1935+
19321936
constexpr int warpSize = 64; // MFMA is always based on the 64-wide warps.
1933-
int kGroups = -1;
1934-
if (mDim == nDim)
1935-
kGroups = warpSize / mDim;
1936-
if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64)
1937-
kGroups = 1;
1937+
int kGroups = warpSize / std::min(mDim, nDim); // for 64x4 and 4x64,
1938+
// kGroups = 16
19381939
int64_t kDim = kWidth * kGroups;
1940+
19391941
if (opIdx == 0)
19401942
return {mDim, kDim};
19411943
else

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 91 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -396,10 +396,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
396396
int nIndex = 1 + hasBatchDim;
397397
(void)mIndex, (void)nIndex;
398398

399-
assert(((getMDim() == 32 && getNDim() == 32) ||
400-
(getMDim() == 16 && getNDim() == 16)) &&
401-
"Unsupported mfma type");
402-
403399
MLIRContext *ctx = getContext();
404400
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
405401

@@ -411,85 +407,56 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
411407
// register and lane layout for mfma instructions.
412408

413409
// We use the order from fastest varying to slowest varying. So each base
414-
// vector is a tuple of values mapping to matrix C's (N, M[, B]) indices.
410+
// vector is a tuple of values mapping to matrix C's (N, M[, B]) indices,
411+
// which will be [1, 0] / [2, 1, 0].
415412
SmallVector<unsigned> order = getDefaultMmaOrder(*this);
416-
auto tileLayout = LinearLayout::empty();
417-
418-
if (getMDim() == 32) {
419-
// For mfma with 32x32 output, each of the 64 threads holds 16 elements.
420-
//
421-
// For the register (i.e., element) dimension, these 16 elements are along
422-
// the matrix C's M dimension, with 4 consecutive elements spanning 4 rows
423-
// and then the next 4 rows being a gap.
424-
//
425-
// For the lane (i.e., thread) dimension, these threads are along the
426-
// matrix C's N dimension, with 32 consecutive threads covering a whole
427-
// row and the next 32 threads start after a gap spanning 4 rows.
428-
tileLayout = LinearLayout(
429-
{{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}},
430-
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}},
431-
{outDimNames[order[0]], outDimNames[order[1]]});
432-
// For mfma.transposed layout, the element ownership among threads are
433-
// "transposed" within each warp.
434-
if (getIsTransposed())
435-
tileLayout = LinearLayout(
436-
{{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}},
437-
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}},
438-
{outDimNames[order[0]], outDimNames[order[1]]});
413+
auto dimM = outDimNames[order[1]];
414+
auto dimN = outDimNames[order[0]];
415+
416+
unsigned mDim = getMDim();
417+
unsigned nDim = getNDim();
418+
auto elementType = getElementType();
419+
int height = (elementType && elementType->isF64()) ? 1 : 4;
420+
constexpr int warpSize = 64;
421+
422+
bool isTransposed = getIsTransposed();
423+
// Special case for 64x4 mfma: we always transpose the output to turn
424+
the 64x4 mfma into an equivalent 4x64 mfma and swap operands A and B, so
425+
// that we can use the mfma broadcast.
426+
if (mDim == 64 && nDim == 4)
427+
assert(isTransposed && "64x4 mfma must be transposed");
428+
429+
int tiles = (mDim * nDim) / (warpSize * height);
430+
431+
LinearLayout tileLayout = LinearLayout::empty();
432+
if (!isTransposed) {
433+
// Each lane holds 'height' elements along the M dimension.
434+
LinearLayout regs = LinearLayout::identity1D(height, kRegister, dimM);
435+
// First, distribute the lanes along the N dimension.
436+
// Then, distribute the lanes along the M dimension. If the #elements
437+
// exceeds the mDim, duplicate elements across lanes - this can happen for
438+
// 4x4 output.
439+
LinearLayout lanes = LinearLayout::identity1D(nDim, kLane, dimN) *
440+
LinearLayout::identity1D(warpSize / nDim, kLane, dimM);
441+
tileLayout = (regs * lanes);
442+
443+
// Repeat the above distribution along the M dimension to fit the tile.
444+
if (tiles > 0)
445+
tileLayout *= LinearLayout::identity1D(tiles, kRegister, dimM);
439446
} else {
440-
assert(getMDim() == 16);
441-
auto elementType = getElementType();
442-
if (!(elementType && elementType->isF64())) {
443-
// For mfma with 16x16 output (<= 32 bits), each of the 64 threads holds 4
444-
// elements.
445-
//
446-
// For the register (i.e., element) dimension, these 4 elements are along
447-
// the matrix C's M dimension, with 4 consecutive elements spanning 4
448-
// rows.
449-
//
450-
// For the lane (i.e., thread) dimension, these threads are along the
451-
// matrix C's N dimension, with 16 consecutive threads covering a whole
452-
// row and the next 16 threads start after a gap spanning 4 rows.
453-
tileLayout = LinearLayout(
454-
{{kRegister, {{0, 1}, {0, 2}}},
455-
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
456-
{outDimNames[order[0]], outDimNames[order[1]]});
457-
// For mfma.transposed layout, the element ownership among threads are
458-
// "transposed" within each warp.
459-
if (getIsTransposed())
460-
tileLayout = LinearLayout(
461-
{{kRegister, {{1, 0}, {2, 0}}},
462-
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
463-
{outDimNames[order[0]], outDimNames[order[1]]});
447+
// For the transposed output, we will use the same method for layout but
448+
// swap the order of the M and N dimensions.
449+
LinearLayout regs = LinearLayout::identity1D(height, kRegister, dimN);
450+
LinearLayout lanes = LinearLayout::identity1D(mDim, kLane, dimM) *
451+
LinearLayout::identity1D(warpSize / mDim, kLane, dimN);
452+
tileLayout = (regs * lanes);
464453

465-
} else {
466-
// For 64 bit mfma with 16x16 output, each of the 64 threads holds 4
467-
// elements across 8 VGPRs. each 64 bit element is split across pairs of 2
468-
// VGPRs each. The first VGPR holds the first 32 bits and second holding
469-
// the last 32 bits.
470-
//
471-
// For the register (i.e., element) dimension, these 4 elements are along
472-
// the matrix C's M dimension, with 4 consecutive elements spanning 4
473-
// rows.
474-
//
475-
// For the lane (i.e., thread) dimension, these threads are along the
476-
// matrix C's N dimension, with each group of 16 consecutive threads
477-
// covering a whole adjacent row. Unlike the <=32 bit cases, there's no
478-
// row gaps between the groups.
479-
tileLayout = LinearLayout(
480-
{{kRegister, {{0, 4}, {0, 8}}},
481-
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 1}, {0, 2}}}},
482-
{outDimNames[order[0]], outDimNames[order[1]]});
483-
// For mfma.transposed layout, the element ownership among threads are
484-
// "transposed" within each warp.
485-
if (getIsTransposed())
486-
tileLayout = LinearLayout(
487-
{{kRegister, {{4, 0}, {8, 0}}},
488-
{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}}},
489-
{outDimNames[order[0]], outDimNames[order[1]]});
490-
}
454+
if (tiles > 0)
455+
tileLayout *= LinearLayout::identity1D(tiles, kRegister, dimN);
491456
}
492457

458+
tileLayout = tileLayout.transposeOuts({dimN, dimM});
459+
493460
// Instead of defining the layout on a CTA tile and using the
494461
// combineCtaCgaWithShape function to extend it to the whole tensor, we take a
495462
// different approach. Suppose tilesPerWarp is 2x2—meaning a warp computes a
@@ -511,8 +478,6 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
511478
const unsigned tilesPerWarpN = tilesPerWarp[nIndex];
512479
const unsigned warpsPerCTAM = warpsPerCTA[mIndex];
513480
const unsigned warpsPerCTAN = warpsPerCTA[nIndex];
514-
const auto &dimN = outDimNames[order[0]];
515-
const auto &dimM = outDimNames[order[1]];
516481

517482
// First, extend the layout along the N dimension:
518483
// - registers are distributed across tilesPerWarpN
@@ -766,11 +731,19 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
766731
int mIndex = 0 + hasBatchDim;
767732

768733
int32_t kWidth = dotMfmaLayout.getKWidth();
769-
auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
770-
int32_t kSize = shape[kDim];
734+
auto kDimIndex = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
735+
771736
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
772737
auto tilesPerWarp = mfmaLayout.getTilesPerWarp();
773-
auto tilePerWarpNonK = tilesPerWarp[kDim];
738+
auto tilePerWarpNonK = tilesPerWarp[kDimIndex];
739+
740+
auto mDim = mfmaLayout.getMDim();
741+
auto nDim = mfmaLayout.getNDim();
742+
auto opIdx = dotMfmaLayout.getOpIdx();
743+
auto nonKDim = opIdx == 0 ? mDim : nDim;
744+
constexpr int warpSize = 64;
745+
746+
int32_t kSize = shape[kDimIndex];
774747

775748
MLIRContext *ctx = dotMfmaLayout.getContext();
776749
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
@@ -785,57 +758,43 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
785758
// for both cases it is [k, nonk]/[k, nonk, batch]
786759
auto order =
787760
getOrderForDotOperand(dotMfmaLayout.getOpIdx(), rank, /*kContig*/ true);
761+
auto dimK = outDimNames[order[0]];
762+
auto dimNonK = outDimNames[order[1]];
788763

789764
// warp order
790765
// common for both operand A and B: [0, 1] / [0, 1, 2]
791766
// in both cases it is [M dim, N dim]/[batch, M dim, N dim]
792767
auto warpOrder = getDefaultMmaOrder(mfmaLayout);
793768

794-
// Lane holds kWidth consecutive elements along k dimension, so
795-
// base register vectors for one tile are initialized in following way:
796-
// {1, 0}, {2, 0} ... {kWidth/2, 0}
797-
std::vector<std::vector<int32_t>> registerBase;
798-
for (int32_t elem = 1; elem < kWidth; elem *= 2)
799-
registerBase.emplace_back(std::vector<int32_t>{elem, 0});
800-
801-
std::vector<std::vector<int32_t>> laneBase;
802-
int32_t kTileSize = -1;
803-
804-
if (mfmaLayout.getMDim() == 32) {
805-
// Canonical MFMA linear layout handles 4 consecutive elements along
806-
// the register dimension. Dot operand handles variable kWidth consecutive
807-
// elements. For lane dim, since the MFMA thread arrangement is {K, N} = {2,
808-
// 32}, this means that mapping of first 5 base (up to thread 16) vectors
809-
// will be an identity along N dim. Thread 32 will be mapped to element
810-
// kWidth in K dimension.
811-
laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {kWidth, 0}};
812-
kTileSize = kWidth * 2;
813-
} else {
814-
assert(mfmaLayout.getMDim() == 16);
815-
// For lane dim, since the MFMA thread arrangement is {K, N} = {4, 16}, this
816-
// means that mapping of first 4 base (up to thread 16) vectors will be an
817-
// identity along N dim. Thread 16 will be mapped to element kWisth in K
818-
// dimension. Thread 32 is mapped to element 2*kWidth in K dim.
819-
laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {kWidth, 0}, {kWidth * 2, 0}};
820-
kTileSize = kWidth * 4;
821-
}
822-
assert(kTileSize != -1);
823-
// Add repeats of registers along K dimension to register base vectors
824-
for (int32_t elem = kTileSize; elem < kSize; elem *= 2)
825-
registerBase.emplace_back(std::vector<int32_t>{elem, 0});
769+
// Each lane holds kWidth elements along the K dimension
770+
LinearLayout regs = LinearLayout::identity1D(kWidth, kRegister, dimK);
771+
// First distribute nonKDim elements along the non-K dimension,
772+
// then distribute remaining elements along the K dimension
773+
LinearLayout lanes =
774+
LinearLayout::identity1D(nonKDim, kLane, dimNonK) *
775+
LinearLayout::identity1D(warpSize / nonKDim, kLane, dimK);
776+
LinearLayout tileLayout = regs * lanes;
826777

827-
// Base vectors above are defined in a fixed order [k-dim, non-k-dim].
828-
// To assign them to actual matrix dimensions we assoicate with register
829-
// `order` which is also also [k, nonk].
778+
int kTileSize = warpSize / nonKDim * kWidth;
779+
// Special case for 4x64 and 64x4 mfma: for the 64x64 operand,
780+
// we need to repeat the layout 16 times along the K dimension
781+
if ((mDim == 64 && nDim == 4 && opIdx == 0) ||
782+
(mDim == 4 && nDim == 64 && opIdx == 1)) {
783+
tileLayout *= LinearLayout::identity1D(16, kRegister, dimK);
784+
kTileSize *= 16;
785+
}
830786

831-
// Add repeats of registers along non-K dimension to register base vectors
832-
for (int32_t elem = mfmaLayout.getMDim();
833-
elem < tilePerWarpNonK * mfmaLayout.getMDim(); elem *= 2)
834-
registerBase.emplace_back(std::vector<int32_t>{0, elem});
787+
// If shape K is larger than the tile size, repeat the tile
788+
// along the K dimension.
789+
if (kSize > kTileSize) {
790+
tileLayout *= LinearLayout::identity1D(kSize / kTileSize, kRegister, dimK);
791+
}
835792

836-
LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
837-
{outDimNames[order[0]], outDimNames[order[1]]});
793+
// Following the tilesPerWarp property, repeat the tile layout
794+
// along the non-K dimension.
795+
tileLayout *= LinearLayout::identity1D(tilePerWarpNonK, kRegister, dimNonK);
838796

797+
tileLayout = tileLayout.transposeOuts({dimK, dimNonK});
839798
if (hasBatchDim) {
840799
assert(order[2] == 0);
841800
// Extend the base vector with one value to accommodate for the batch
@@ -845,11 +804,14 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
845804
}
846805

847806
LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);
848-
849-
LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
850-
warpLayout.transposeOuts(outDimNames);
851-
852-
return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
807+
LinearLayout ctaLayout = tileLayout * warpLayout;
808+
809+
// Note that the current output order is [k, nonk]/[k, nonk, batch]. If the
810+
// layout's out-size is smaller than the shape, we follow this order to
811+
// extend each dimension to match the shape. After that, we can transpose
812+
// to match the standard output order.
813+
return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape)
814+
.transposeOuts(outDimNames);
853815
}
854816

855817
LinearLayout

python/test/unit/language/test_core.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3833,9 +3833,11 @@ def get_test_dot_small_k_mfma_cases():
38333833

38343834
# M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size
38353835
# introduced in #4516
3836-
def get_test_dot_small_mn_fma_cases():
3836+
def get_test_dot_small_mn_mfma_cases():
3837+
if not is_hip_cdna():
3838+
return []
38373839
return [(*shape_nw, False, False, epilogue, 'ieee', in_dtype, out_dtype, 1, None)
3838-
for shape_nw in [(2, 2, 16, 1), (1, 64, 64, 1), (64, 2, 64, 2), (64, 64, 4, 4)]
3840+
for shape_nw in [(4, 64, 64, 1), (64, 4, 64, 1)]
38393841
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
38403842
for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]]
38413843

@@ -3875,7 +3877,7 @@ def get_test_small_dots_cases():
38753877
get_test_dot_mfma_edge_cases() + \
38763878
get_test_dot_fp8_output_cases() + \
38773879
get_test_dot_small_k_mfma_cases() + \
3878-
get_test_dot_small_mn_fma_cases() + \
3880+
get_test_dot_small_mn_mfma_cases() + \
38793881
get_test_dot_softmax() + \
38803882
get_test_small_dots_cases())
38813883
@pytest.mark.parametrize("num_ctas", num_ctas_list)
@@ -4070,13 +4072,15 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
40704072
return
40714073

40724074
if is_hip_cdna():
4073-
if M != 4:
4074-
return
40754075
amdgcn = pgm.asm['amdgcn']
4076-
if in_dtype == 'float16':
4077-
assert 'v_dot2c_f32_f16' in amdgcn
4078-
elif (in_dtype == 'bfloat16') and is_hip_cdna4():
4079-
assert 'v_dot2c_f32_bf16' in amdgcn
4076+
4077+
if (M, N) == (4, 64) or (M, N) == (64, 4):
4078+
assert 'v_mfma_f32_4x4' in amdgcn
4079+
elif (M, N) == (4, 32):
4080+
if in_dtype == 'float16':
4081+
assert 'v_dot2c_f32_f16' in amdgcn
4082+
elif (in_dtype == 'bfloat16') and is_hip_cdna4():
4083+
assert 'v_dot2c_f32_bf16' in amdgcn
40804084
return
40814085

40824086
# make sure ld/st are vectorized

0 commit comments

Comments
 (0)