Commit d664a09

[AMD] Support gfx950 double rate mfma ops (#5831)

This patch adds support for the new double-rate MFMA operations on gfx950:

- Double-rate ops are preferred on gfx950 when the input K size is large enough.
- The double-rate bf16 MFMA ops must keep their operands in bf16 rather than bitcasting them to i16, since the LLVM backend expects bf16 inputs.
- kpack is always 1 on gfx950.
1 parent 196a08f commit d664a09
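
For illustration only (not part of this diff): on a gfx950 device, a tl.dot whose operand tile matches the new instruction shapes, for example fp16/bf16 with a 16x16 output tile and K of at least 32, or a 32x32 tile and K of at least 16, is the kind of dot that can now be lowered to the double-rate MFMAs exercised by the lit tests below. The kernel and helper names in this sketch are hypothetical, and it assumes a ROCm PyTorch build where the GPU is exposed through the "cuda" device string.

# Hypothetical usage sketch (not from this commit). With fp16 operands, a
# 16x16 output tile, and K = 32, this dot is eligible for the double-rate
# rocdl.mfma.f32.16x16x32.f16 lowering added below.
import torch
import triton
import triton.language as tl


@triton.jit
def single_tile_dot(a_ptr, b_ptr, c_ptr,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    c = tl.dot(a, b)  # fp16 x fp16, accumulated in fp32
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], c)


def run_single_tile(M=16, N=16, K=32):
    a = torch.randn((M, K), dtype=torch.float16, device="cuda")
    b = torch.randn((K, N), dtype=torch.float16, device="cuda")
    c = torch.empty((M, N), dtype=torch.float32, device="cuda")
    single_tile_dot[(1,)](a, b, c, BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, num_warps=4)
    torch.testing.assert_close(c, a.float() @ b.float(), rtol=1e-2, atol=1e-2)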

9 files changed: +272, -116 lines

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 1 addition & 0 deletions
@@ -1409,6 +1409,7 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
       if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed())
         return {};
     }
+
     if (attr.getName() == "isTransposed") {
       if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
         return {};

python/test/unit/language/test_core.py

Lines changed: 11 additions & 0 deletions
@@ -3420,9 +3420,20 @@ def get_test_dot_small_mn_fma_cases():
             for in_dtype, out_dtype in [('float16', 'float16'), ('float32', 'float32')]]


+def get_test_dot_double_rate_cases():
+    if not is_hip_cdna():
+        return []
+    return [(32, 32, 16, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None),
+            (32, 32, 16, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None),
+            (16, 16, 32, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None),
+            (16, 16, 32, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None)]
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize(
     "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size",
+    get_test_dot_double_rate_cases() + \
+    get_test_dot_base_cases() + \
     get_test_dot_base_cases() + \
     get_test_dot_mixed_sizes_cases() + \
     get_test_dot_transposed_op_base_cases() + \

python/test/unit/language/test_matmul.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@
 import triton.tools.experimental_descriptor
 from test_mxfp import MXFP4Tensor, MXScaleTensor
 import re
-from triton._internal_testing import is_cuda, is_hip, is_hip_mi200, is_hip_mi350, is_hip_cdna
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi300, is_hip_mi350, is_hip_cdna


 def f8_to_f16(x, dtype):
@@ -84,8 +84,8 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
     if is_hip() and ((BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str)
                      > 65536):
         pytest.skip("HIP path requires less than 64KB of shared memory")
-    if is_hip_mi200() and dtype_src_str == "tensorfloat32":
-        pytest.skip("HIP MI200 does not support tensorfloat32")
+    if is_hip() and (not is_hip_mi300()) and dtype_src_str == "tensorfloat32":
+        pytest.skip("tensorfloat32 is only supported on HIP MI300")
     if dtype_src_str == "float8e5" and BLOCK_K == 16:
         pytest.skip("Skipping cases small K for float8")
     if dtype_src_str == "float8e5" and device == "cuda" and torch.cuda.get_device_capability()[0] < 9:

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm='arch=gfx950' | FileCheck %s
+
+// CHECK-LABEL: mfma_16x16x32_f16
+
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = false}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_16x16x32_f16(%arg0: tensor<16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>,
+                                    %arg1: tensor<32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
+    // CHECK: rocdl.mfma.f32.16x16x32.f16 {{.*}} : (vector<8xf16>, vector<8xf16>
+    %dot = tt.dot %arg0, %arg1, %cst : tensor<16x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<32x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<16x16xf32, #mma>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: mfma_16x16x32_bf16
+
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = false}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_16x16x32_bf16(%arg0: tensor<16x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>,
+                                     %arg1: tensor<32x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma>
+    // CHECK: rocdl.mfma.f32.16x16x32.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>
+    %dot = tt.dot %arg0, %arg1, %cst : tensor<16x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<32x16xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<16x16xf32, #mma>
+    tt.return
+  }
+}
+
+// -----
+
+// CHECK-LABEL: mfma_32x32x16_f16
+
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = false}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_32x32x16_f16(%arg0: tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>,
+                                    %arg1: tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    // CHECK: rocdl.mfma.f32.32x32x16.f16 {{.*}} : (vector<8xf16>, vector<8xf16>
+    %dot = tt.dot %arg0, %arg1, %cst : tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<32x32xf32, #mma>
+    tt.return
+  }
+}
+
+
+// -----
+
+// CHECK-LABEL: mfma_32x32x16_bf16
+
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = false}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_32x32x16_bf16(%arg0: tensor<32x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>,
+                                     %arg1: tensor<16x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
+    // CHECK: rocdl.mfma.f32.32x32x16.bf16 {{.*}} : (vector<8xbf16>, vector<8xbf16>
+    %dot = tt.dot %arg0, %arg1, %cst : tensor<32x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<16x32xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<32x32xf32, #mma>
+    tt.return
+  }
+}

third_party/amd/backend/compiler.py

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ def __post_init__(self):
         # Ignore user-defined warp size for gfx9
         warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
         object.__setattr__(self, 'warp_size', warp_size)
+        # Only kpack=1 is supported on gfx950
+        kpack = 1 if self.arch == 'gfx950' else self.kpack
+        object.__setattr__(self, 'kpack', kpack)
         libs = ["ocml", "ockl"]
         for lib in libs:
             extern_libs[lib] = str(default_libdir / f'{lib}.bc')
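
Side note on the pattern above: the use of object.__setattr__ suggests the options object is a frozen dataclass, so the gfx950 kpack override cannot be a plain field assignment. A minimal standalone sketch of that override pattern, with hypothetical class name and field defaults:

# Minimal sketch (hypothetical names): a frozen dataclass cannot assign fields
# directly in __post_init__, so the gfx950 kpack override goes through
# object.__setattr__, silently replacing whatever the user requested.
from dataclasses import dataclass


@dataclass(frozen=True)
class ExampleOptions:
    arch: str
    kpack: int = 2

    def __post_init__(self):
        # Only kpack=1 is supported on gfx950.
        kpack = 1 if self.arch == "gfx950" else self.kpack
        object.__setattr__(self, "kpack", kpack)


assert ExampleOptions("gfx950", kpack=2).kpack == 1
assert ExampleOptions("gfx942", kpack=2).kpack == 2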

third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h

Lines changed: 9 additions & 6 deletions
@@ -25,7 +25,7 @@ enum class MfmaTypeId : uint32_t {
 };

 struct MfmaInsnGroupSelectKey {
-  unsigned mDim, nDim;
+  unsigned mDim, nDim, kDim;
   MfmaTypeId elemType;
   int mfmaVersion;
 };
@@ -51,21 +51,23 @@ constexpr typename std::underlying_type<T>::type cast_as_underlying(T t) {
 struct MfmaInsnGroupSelectKeyInfo
     : public llvm::DenseMapInfo<MfmaInsnGroupSelectKey> {
   static inline MfmaInsnGroupSelectKey getEmptyKey() {
-    return {32, 32, MfmaTypeId::Fp32TyId, 0};
+    return {32, 32, 0, MfmaTypeId::Fp32TyId, 0};
   }

   static inline MfmaInsnGroupSelectKey getTombstoneKey() {
-    return {32, 32, MfmaTypeId::Fp32TyId, -1};
+    return {32, 32, 0, MfmaTypeId::Fp32TyId, -1};
   }

   static inline bool isEqual(const MfmaInsnGroupSelectKey &lhs,
                              const MfmaInsnGroupSelectKey &rhs) {
     return lhs.mDim == rhs.mDim && lhs.nDim == rhs.nDim &&
-           lhs.elemType == rhs.elemType && lhs.mfmaVersion == rhs.mfmaVersion;
+           lhs.kDim == rhs.kDim && lhs.elemType == rhs.elemType &&
+           lhs.mfmaVersion == rhs.mfmaVersion;
   }

   static unsigned getHashValue(const MfmaInsnGroupSelectKey &key) {
     auto dimHash = llvm::detail::combineHashValue(key.mDim, key.nDim);
+    dimHash = llvm::detail::combineHashValue(dimHash, key.kDim);
     auto verHash = llvm::detail::combineHashValue(dimHash, key.mfmaVersion);
     auto elemHash = cast_as_underlying(key.elemType);
     return llvm::detail::combineHashValue(elemHash, verHash);
@@ -80,8 +82,9 @@ class MfmaInsn {

 public:
   static FailureOr<MfmaInsn> selectMfma(unsigned mDim, unsigned nDim,
-                                        Type elementTypeA, Type elementTypeB,
-                                        int mfmaVersion, bool allowXF32);
+                                        unsigned kDim, Type elementTypeA,
+                                        Type elementTypeB, int mfmaVersion,
+                                        bool allowXF32);
   MfmaInsn(Type elementTypeA, Type elementTypeB, const MfmaInsnAttr &attr);
   unsigned getKDim();
   unsigned getMDim();
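
Conceptually, folding kDim into the select key lets two instructions with the same M/N tile but different K coexist in the MFMA database. A rough Python analogue of the lookup (the table entries are illustrative examples taken from the lit tests above, not the real MfmaInsn database):

# Illustrative lookup keyed the same way as MfmaInsnGroupSelectKey after this
# change: (mDim, nDim, kDim, elemType, mfmaVersion).
MFMA_TABLE = {
    (16, 16, 32, "f16", 4): "rocdl.mfma.f32.16x16x32.f16",    # gfx950 double rate
    (16, 16, 32, "bf16", 4): "rocdl.mfma.f32.16x16x32.bf16",  # gfx950 double rate
    (32, 32, 16, "f16", 4): "rocdl.mfma.f32.32x32x16.f16",    # gfx950 double rate
    (32, 32, 16, "bf16", 4): "rocdl.mfma.f32.32x32x16.bf16",  # gfx950 double rate
}


def select_mfma(m_dim, n_dim, k_dim, elem_type, mfma_version):
    # Without kDim in the key, one (mDim, nDim, elemType, version) tuple could
    # only name a single instruction; with it, single-rate and double-rate
    # variants of the same tile can live side by side.
    return MFMA_TABLE.get((m_dim, n_dim, k_dim, elem_type, mfma_version))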

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 26 additions & 21 deletions
@@ -271,11 +271,13 @@
     auto elemTyA = aTensorTy.getElementType();
     auto elemTyB = bTensorTy.getElementType();

+    const auto kDimOperandSize = aTensorTy.getShape().back();
+
     bool allowXF32 =
         op.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3;
     StringRef mfmaInsnName;
-    auto maybeMfmaInsn = MfmaInsn::selectMfma(mDim, nDim, elemTyA, elemTyB,
-                                              mfmaVersion, allowXF32);
+    auto maybeMfmaInsn = MfmaInsn::selectMfma(
+        mDim, nDim, kDimOperandSize, elemTyA, elemTyB, mfmaVersion, allowXF32);
     if (failed(maybeMfmaInsn))
       llvm::report_fatal_error("No match found in MFMA database\n");

@@ -290,8 +292,6 @@
     if (aTensorTy.getElementType().isF32() && allowXF32)
       kWidth *= 2;

-    auto rank = aTensorTy.getShape().size();
-    const auto kDimOperandSize = aTensorTy.getShape()[rank - 1];
     const auto kDimInstrSize = mfmaLayout.getInstrShapeForOperand(kWidth, 0)[1];

     auto repA = mfmaLayout.getRepForOperand(aTensorTy.getShape(), kWidth, 0);
@@ -309,12 +309,13 @@
     auto numRepB = repA[0];
     assert(repA[0] == repB[0]);

+    bool preserveBF16 = mfmaInsnName.contains(".bf16") && mfmaVersion >= 4;
     auto operandA = getValuesFromDotOperandLayoutStruct(
         loadedA, numRepB, numRepM, numRepK, kWidth, kBase,
-        aTensorTy.getElementType(), allowXF32);
+        aTensorTy.getElementType(), allowXF32, preserveBF16);
     auto operandB = getValuesFromDotOperandLayoutStruct(
         loadedB, numRepB, numRepN, numRepK, kWidth, kBase,
-        aTensorTy.getElementType(), allowXF32);
+        aTensorTy.getElementType(), allowXF32, preserveBF16);

     auto dstElemTy = dTensorTy.getElementType();
     auto fc = unpackLLElements(loc, loadedC, rewriter);
@@ -379,19 +380,19 @@
   /// rawElems is a vector of kWidth elements. We need to prepare vector(s) of
   /// kBase elements for each mfma instruction
   SmallVector<Value> extractOperands(Value rawElems, int kWidth, int kBase,
-                                     Type type) const {
+                                     Type type, bool preserveBF16) const {
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     int kpack = kWidth / kBase;
     SmallVector<Value> results;
     auto vecTy = vec_ty(type, kBase);
-    if (type.isBF16())
+    if (type.isBF16() && !preserveBF16)
      vecTy = vec_ty(i16_ty, kBase);
     for (int k = 0; k < kpack; ++k) {
       Value vec = b.undef(vecTy);
       for (int elemId = 0; elemId < kBase; ++elemId) {
         auto val =
             b.extract_element(type, rawElems, b.i32_val(elemId + k * kBase));
-        if (type.isBF16()) {
+        if (type.isBF16() && !preserveBF16) {
           // rocdl.mfma.f32.32x32x8bf16.1k calls for input of i16 type
           auto cast = b.bitcast(val, i16_ty);
           vec = b.insert_element(vecTy, vec, cast, b.i32_val(elemId));
@@ -423,7 +424,7 @@
   virtual SmallVector<ValueTable>
   getValuesFromDotOperandLayoutStruct(Value value, int batch, int n0, int n1,
                                       int kWidth, int kBase, Type type,
-                                      bool allowXF32) const {
+                                      bool allowXF32, bool preserveBF16) const {
     auto tb = TritonLLVMOpBuilder(loc, rewriter);
     auto elems = unpackLLElements(loc, value, rewriter);
     int kpack = kWidth / kBase;
@@ -449,14 +450,18 @@
       } else {
         SmallVector<Value> vals;
         if (type.isF32() && allowXF32) {
-          vals = extractOperands(rawElems, kWidth, kBase, f32_ty);
+          vals = extractOperands(rawElems, kWidth, kBase, f32_ty,
+                                 preserveBF16);
         } else if (type.getIntOrFloatBitWidth() == 8) {
-          vals = extractOperands(rawElems, kWidth, kBase, i8_ty);
+          vals =
+              extractOperands(rawElems, kWidth, kBase, i8_ty, preserveBF16);
         } else if (type.isBF16()) {
-          vals = extractOperands(rawElems, kWidth, kBase, bf16_ty);
+          vals = extractOperands(rawElems, kWidth, kBase, bf16_ty,
+                                 preserveBF16);
         } else {
           assert(type.isF16() && "Unsupported data type");
-          vals = extractOperands(rawElems, kWidth, kBase, f16_ty);
+          vals = extractOperands(rawElems, kWidth, kBase, f16_ty,
+                                 preserveBF16);
         }
         for (int k = 0; k < kpack; ++k) {
           dotOpVals[k][{b, i, j}] = vals[k];
@@ -518,6 +523,8 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     ScaleDotElemType aElemType = op.getLhsType();
     ScaleDotElemType bElemType = op.getRhsType();

+    const auto kDimOperandSize = aTensorTy.getShape().back();
+
     auto supportsTypes = [](ScaleDotElemType elemType) {
       return elemType == ScaleDotElemType::E2M1;
     };
@@ -529,7 +536,7 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     auto ctx = op.getContext();
     constexpr bool allowXF32 = false;
     auto maybeMfmaInsn = MfmaInsn::selectMfma(
-        mDim, nDim, scaleDotElemTypeToMLIRType(ctx, aElemType),
+        mDim, nDim, kDimOperandSize, scaleDotElemTypeToMLIRType(ctx, aElemType),
         scaleDotElemTypeToMLIRType(ctx, bElemType), mfmaVersion, allowXF32);
     if (failed(maybeMfmaInsn))
       llvm::report_fatal_error("No match found in MFMA database\n");
@@ -544,8 +551,6 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
     auto aEncoding = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
     auto bEncoding = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
     int kWidth = aEncoding.getKWidth();
-    auto rank = aTensorTy.getShape().size();
-    const auto kDimOperandSize = aTensorTy.getShape()[rank - 1];
     const auto kDimInstrSize = mfmaLayout.getInstrShapeForOperand(kWidth, 0)[1];

     auto repA = mfmaLayout.getRepForOperand(aTensorTy.getShape(), kWidth, 0);
@@ -575,19 +580,19 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {

     auto operandA = getValuesFromDotOperandLayoutStruct(
         loadedA, numRepB, numRepM, numRepK, kWidth, kBase,
-        aTensorTy.getElementType(), allowXF32);
+        aTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);
     auto operandB = getValuesFromDotOperandLayoutStruct(
         loadedB, numRepB, numRepN, numRepK, kWidth, kBase,
-        bTensorTy.getElementType(), allowXF32);
+        bTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);

     // Scales have the same replica distributions as their corresponding
     // operands.
     auto operandAScale = getValuesFromDotOperandLayoutStruct(
         loadedAScale, numRepB, numRepM, numRepK, scaleKWidth, scaleKBase,
-        aScaleTensorTy.getElementType(), allowXF32);
+        aScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);
     auto operandBScale = getValuesFromDotOperandLayoutStruct(
         loadedBScale, numRepB, numRepN, numRepK, scaleKWidth, scaleKBase,
-        bScaleTensorTy.getElementType(), allowXF32);
+        bScaleTensorTy.getElementType(), allowXF32, /*preserveBF16=*/false);

     auto dstElemTy = dTensorTy.getElementType();
     auto fc = unpackLLElements(loc, loadedC, rewriter);
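
In short, the bf16 handling above now depends on which intrinsic was selected: the older bf16 intrinsics (e.g. rocdl.mfma.f32.32x32x8bf16.1k) take i16 vectors, while the gfx950 double-rate bf16 intrinsics take real bf16 vectors. A small Python restatement of that decision (the function name is illustrative, the logic mirrors the C++ above):

def operand_vector_elem_type(elem_type: str, mfma_insn_name: str, mfma_version: int) -> str:
    """Element type of the packed operand vector fed to the MFMA intrinsic."""
    preserve_bf16 = ".bf16" in mfma_insn_name and mfma_version >= 4
    if elem_type == "bf16" and not preserve_bf16:
        # Pre-gfx950 bf16 intrinsics expect i16 lanes, so each bf16 lane is bitcast.
        return "i16"
    # gfx950 double-rate bf16 ops keep bf16 lanes, as the LLVM backend expects.
    return elem_type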

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 2 additions & 3 deletions
@@ -136,8 +136,8 @@ FailureOr<MfmaInsn> chooseMfmaInstruction(RankedTensorType cType,
   if (mDim == 0 || nDim == 0)
     return failure();

-  auto maybeMfmaInsn = MfmaInsn::selectMfma(mDim, nDim, aElemType, bElemType,
-                                            mfmaVersion, allowXF32);
+  auto maybeMfmaInsn = MfmaInsn::selectMfma(mDim, nDim, inputKSize, aElemType,
+                                            bElemType, mfmaVersion, allowXF32);
   if (failed(maybeMfmaInsn))
     llvm::report_fatal_error("No match found in MFMA database\n");

@@ -511,7 +511,6 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     Value dotOutput =
         convertAndCastTensor(rewriter, newDot, oldRetType.getEncoding(),
                              oldRetType.getElementType());
-
     rewriter.replaceOp(dotOp, dotOutput);

     return success();
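
The net effect of threading inputKSize through chooseMfmaInstruction is the preference stated in the commit message: when the dot's K dimension is at least as large as the double-rate instruction's K, gfx950 can pick the wider op, otherwise it falls back to the narrower shape. A rough sketch of that preference for 16-bit float operands; the double-rate K values come from the new lit tests, the single-rate K values are the pre-existing 16x16x16 / 32x32x8 shapes, and the exact selection policy should be treated as an assumption:

def preferred_instr_k(m_dim: int, n_dim: int, input_k: int, mfma_version: int) -> int:
    """Sketch of double- vs single-rate K selection for f16/bf16 MFMAs."""
    single_rate_k = {(16, 16): 16, (32, 32): 8}[(m_dim, n_dim)]
    double_rate_k = {(16, 16): 32, (32, 32): 16}[(m_dim, n_dim)]
    # gfx950 (mfmaVersion 4): prefer the double-rate op when K is large enough.
    if mfma_version >= 4 and input_k >= double_rate_k:
        return double_rate_k
    return single_rate_k


assert preferred_instr_k(16, 16, 32, 4) == 32   # double rate on gfx950
assert preferred_instr_k(16, 16, 16, 4) == 16   # K too small: single rate
assert preferred_instr_k(32, 32, 64, 3) == 8    # pre-gfx950: single rate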
