
Commit 14d7bcc

[AMD] Rework MFMA intrinsic mapping queries (#5937)
This commit reworks how we encode MFMA intrinsics and how we query them. The map is now keyed on (version, mDim, nDim, kDim, aElemType, bElemType), and the value is a vector containing only (symbol, kDim, kBase) tuples. This lets us drop the dummy kDim of 0 in the key for older generations and avoids duplicating data in the map. Along the way, it fixes the fp8 types for gfx942 and gfx950: gfx942 uses the AMD variants, while gfx950 uses the OCP ones.
1 parent: 4f30282
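
Sketched in isolation, the new encoding looks roughly like the following. This is a minimal sketch assuming string-valued element types and placeholder symbol names; the actual table keys on mlir::Type and real ROCDL intrinsic names (see MfmaGroup.h below).

#include <map>
#include <string>
#include <tuple>
#include <vector>

// Key: (version, mDim, nDim, kDim, aElemType, bElemType). Element types are
// strings here purely for illustration; the real code keys on mlir::Type.
struct MfmaKey {
  int version;
  unsigned mDim, nDim, kDim;
  std::string aElemType, bElemType;

  bool operator<(const MfmaKey &o) const {
    return std::tie(version, mDim, nDim, kDim, aElemType, bElemType) <
           std::tie(o.version, o.mDim, o.nDim, o.kDim, o.aElemType,
                    o.bElemType);
  }
};

// Value: only (symbol, kDim, kBase) tuples. Listing the per-variant kDim and
// kBase in the value is what removes the duplicated rows and the dummy
// kDim-of-0 keys the old scheme needed for older generations.
struct MfmaVariant {
  std::string symbol; // placeholder name, not a real ROCDL intrinsic symbol
  unsigned kDim;
  unsigned kBase;
};

using MfmaTable = std::map<MfmaKey, std::vector<MfmaVariant>>;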

6 files changed (+392, -527 lines)

python/test/unit/language/test_core.py

Lines changed: 4 additions & 3 deletions
@@ -6328,7 +6328,8 @@ def matmul_kernel( #
 @pytest.mark.interpreter
 @pytest.mark.parametrize("M, N, K", [(128, 256, 256)])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), (64, 64, 64)])
-@pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv', 'float8e4b15'])
+@pytest.mark.parametrize(
+    "in_type_str", ['float8e5', 'float8e5b16', 'float8e4b8'] if is_hip() else ['float8e5', 'float8e4nv', 'float8e4b15'])
 @pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128])
 def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device):
     num_stages = 3
@@ -6338,8 +6339,8 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
         pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90")
     elif is_hip():
         num_stages = 2
-        if in_type_str != 'float8e5':
-            pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.')
+        if in_type_str in ("float8e5b16", "float8e4b8") and not is_hip_mi300():
+            pytest.skip(f"{in_type_str} only supported on mi300")
 
     check_type_supported(in_type_str, device)
     A = numpy_random((M, K), dtype_str=in_type_str)

test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
   tt.func public @mfma_dot_fp8e5m2(
       %arg0: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
       %arg1: tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
-      %arg2: tensor<128x256x!tt.ptr<f32>, #blocked> ) {
+      %arg2: tensor<128x256x!tt.ptr<f32>, #blocked>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
     // CHECK: %[[A0:.+]] = ttg.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
     // CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>

third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h

Lines changed: 29 additions & 87 deletions
@@ -1,98 +1,40 @@
 #ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_MFMAGROUP_H_
 #define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_MFMAGROUP_H_
 
-#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallString.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace mlir {
 
-//===----------------------------------------------------------------------===//
-// AMDGPU MFMA instruction selection utilities
-//===----------------------------------------------------------------------===//
-
-enum class MfmaTypeId : uint32_t {
-  Fp32TyId = 0,
-  Xf32TyId,
-  Fp16TyId,
-  Bf16TyId,
-  I8TyId,
-  Fp8Fp8TyId,
-  Fp8Bf8TyId,
-  Bf8Fp8TyId,
-  Bf8Bf8TyId,
-  F8F6F4TyId,
-};
-
-struct MfmaInsnGroupSelectKey {
-  unsigned mDim, nDim, kDim;
-  MfmaTypeId elemType;
-  int mfmaVersion;
-};
-
-struct MfmaInsnAttr {
-  // m,n,k refer to the shapes of the two operands of mfma instructions.
-  // Operand A has shape m x k. Operand B has shape k x n.
-  // For mfma32 and mfma16 instructions, they are the same as
-  // the dims in the instruction name, i.e. mfma_DType_mxnxkxABType
-  unsigned m;
-  unsigned n;
-  unsigned k;
-  // kBase refers to the number of elements per thread
+struct MfmaIntrinsic {
+  // Chooses a suitable mfma instrinsic for the given input case.
+  static FailureOr<MfmaIntrinsic> selectFor(int version, unsigned mDim,
+                                            unsigned nDim, unsigned inputKDim,
+                                            Type aElemType, Type bElemType,
+                                            bool withScale, bool useTF32);
+
+  MfmaIntrinsic(StringRef symbol, unsigned m, unsigned n, unsigned k,
+                unsigned kB, Type aET, Type bET)
+      : name(symbol), mDim(m), nDim(n), kDim(k), kBase(kB), aElementType(aET),
+        bElementType(bET) {}
+  MfmaIntrinsic(const MfmaIntrinsic &other) = default;
+  MfmaIntrinsic(MfmaIntrinsic &&other) = default;
+
+  llvm::StringRef name;
+
+  // m, n, and k refer to the shapes of the two operands of an mfma intrinsic:
+  // Operand A has shape [m]x[k]; operand B has shape [k]x[n].
+  // For mfma32 and mfma16 intrinsics, they are encoded in the instruction
+  // name, i.e. mfma_DType_[m]x[n]x[k]xABType.
+  unsigned mDim;
+  unsigned nDim;
+  unsigned kDim;
+
+  // kBase is the number of elements each thread holds.
   unsigned kBase;
-  llvm::StringRef insn;
-};
-
-template <typename T>
-constexpr typename std::underlying_type<T>::type cast_as_underlying(T t) {
-  return static_cast<typename std::underlying_type<T>::type>(t);
-}
-
-struct MfmaInsnGroupSelectKeyInfo
-    : public llvm::DenseMapInfo<MfmaInsnGroupSelectKey> {
-  static inline MfmaInsnGroupSelectKey getEmptyKey() {
-    return {32, 32, 0, MfmaTypeId::Fp32TyId, 0};
-  }
-
-  static inline MfmaInsnGroupSelectKey getTombstoneKey() {
-    return {32, 32, 0, MfmaTypeId::Fp32TyId, -1};
-  }
-
-  static inline bool isEqual(const MfmaInsnGroupSelectKey &lhs,
-                             const MfmaInsnGroupSelectKey &rhs) {
-    return lhs.mDim == rhs.mDim && lhs.nDim == rhs.nDim &&
-           lhs.kDim == rhs.kDim && lhs.elemType == rhs.elemType &&
-           lhs.mfmaVersion == rhs.mfmaVersion;
-  }
-
-  static unsigned getHashValue(const MfmaInsnGroupSelectKey &key) {
-    auto dimHash = llvm::detail::combineHashValue(key.mDim, key.nDim);
-    dimHash = llvm::detail::combineHashValue(dimHash, key.kDim);
-    auto verHash = llvm::detail::combineHashValue(dimHash, key.mfmaVersion);
-    auto elemHash = cast_as_underlying(key.elemType);
-    return llvm::detail::combineHashValue(elemHash, verHash);
-  }
-};
-
-class MfmaInsn {
-private:
-  Type elementTypeA;
-  Type elementTypeB;
-  MfmaInsnAttr attr;
 
-public:
-  static FailureOr<MfmaInsn> selectMfma(unsigned mDim, unsigned nDim,
-                                        unsigned kDim, Type elementTypeA,
-                                        Type elementTypeB, int mfmaVersion,
-                                        bool allowXF32);
-  MfmaInsn(Type elementTypeA, Type elementTypeB, const MfmaInsnAttr &attr);
-  unsigned getKDim();
-  unsigned getMDim();
-  unsigned getNDim();
-  StringRef getInsnName();
-  unsigned getKBase();
-  Type getElementTypeA();
-  Type getElementTypeB();
+
+  Type aElementType;
+  Type bElementType;
 };
 } // namespace mlir
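
For orientation, here is a minimal sketch of querying the new interface. The version number, tile shape, and element types are illustrative assumptions, and ctx is assumed to be an in-scope mlir::MLIRContext; compare the real call sites in the MFMA.cpp diff below.

// Hypothetical query: f16 operands, a 32x32 tile, and an input kDim of 64,
// on an assumed mfmaVersion of 3. All values here are illustrative.
Type f16Ty = Float16Type::get(&ctx);
FailureOr<MfmaIntrinsic> intrinsic = MfmaIntrinsic::selectFor(
    /*version=*/3, /*mDim=*/32, /*nDim=*/32, /*inputKDim=*/64,
    /*aElemType=*/f16Ty, /*bElemType=*/f16Ty,
    /*withScale=*/false, /*useTF32=*/false);
if (succeeded(intrinsic))
  llvm::errs() << intrinsic->name << ": kDim=" << intrinsic->kDim
               << ", kBase=" << intrinsic->kBase << "\n";

If no entry matches, selectFor returns failure, which the lowering below turns into a fatal "No match found in MFMA database" error.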

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 41 additions & 35 deletions
@@ -37,7 +37,6 @@ using ::mlir::LLVM::AMD::shuffleXor;
 using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::LinearEncodingAttr;
-using ::mlir::triton::gpu::SwizzledSharedEncodingAttr;
 
 using ValueTable = std::map<std::array<int, 3>, Value>;
 
@@ -75,12 +74,12 @@ struct DotOpMFMAConversionHelper {
       : mfmaLayout(mfmaLayout), rewriter(rewriter),
         typeConverter(typeConverter), loc(loc), ctx(mfmaLayout.getContext()) {}
 
-  Value generateMFMAOp(StringRef mfmaInsnName, Value valA, Value valB,
+  Value generateMFMAOp(StringRef intrinsicName, Value valA, Value valB,
                        Value valC) const {
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto resType = valC.getType();
     Value zeroFlag = b.i32_val(0);
-    OperationState loweredOp(loc, mfmaInsnName);
+    OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(resType);
     loweredOp.addOperands({valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
     return rewriter.create(loweredOp)->getResult(0);
@@ -228,14 +227,15 @@ struct DotOpMFMAConversionHelper {
 
   template <typename T>
   void packAndReplaceResult(T &op, SmallVector<Value> &fc,
-                            FailureOr<MfmaInsn> maybeMfmaInsn, Type dstElemTy,
-                            Type elemtTy, size_t mmaCount) const {
+                            const FailureOr<MfmaIntrinsic> &maybeMfmaIntrinsic,
+                            Type dstElemTy, Type elemtTy,
+                            size_t mmaCount) const {
     Type structTy = LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(fc.size(), dstElemTy));
     Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);
 
-    setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(),
-                        maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(),
+    setNumGeneratedMMAs(op, mmaCount, maybeMfmaIntrinsic->mDim,
+                        maybeMfmaIntrinsic->nDim, maybeMfmaIntrinsic->kDim,
                         elemtTy);
 
     rewriter.replaceOp(op, res);
@@ -267,14 +267,15 @@ struct DotOpMFMAConversionHelper {
 
     bool allowXF32 =
         op.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3;
-    StringRef mfmaInsnName;
-    auto maybeMfmaInsn = MfmaInsn::selectMfma(
-        mDim, nDim, kDimOperandSize, elemTyA, elemTyB, mfmaVersion, allowXF32);
-    if (failed(maybeMfmaInsn))
+    StringRef intrinsicName;
+    FailureOr<MfmaIntrinsic> maybeMfmaIntrinsic = MfmaIntrinsic::selectFor(
+        mfmaVersion, mDim, nDim, kDimOperandSize, elemTyA, elemTyB,
+        /*withScale=*/false, allowXF32);
+    if (failed(maybeMfmaIntrinsic))
       llvm::report_fatal_error("No match found in MFMA database\n");
 
-    mfmaInsnName = maybeMfmaInsn->getInsnName();
-    unsigned kBase = maybeMfmaInsn->getKBase();
+    intrinsicName = maybeMfmaIntrinsic->name;
+    unsigned kBase = maybeMfmaIntrinsic->kBase;
 
     auto aEncoding = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
     auto bEncoding = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
@@ -301,7 +302,7 @@ struct DotOpMFMAConversionHelper {
     auto numRepB = repA[0];
     assert(repA[0] == repB[0]);
 
-    bool preserveBF16 = mfmaInsnName.contains(".bf16") && mfmaVersion >= 4;
+    bool preserveBF16 = intrinsicName.contains(".bf16") && mfmaVersion >= 4;
     auto operandA = getValuesFromDotOperandLayoutStruct(
         loadedA, numRepB, numRepM, numRepK, kWidth, kBase,
         aTensorTy.getElementType(), allowXF32, preserveBF16);
@@ -335,12 +336,13 @@ struct DotOpMFMAConversionHelper {
         acc = zeroAuxiliarBlocks(subBlocks, acc);
         for (int k = 0; k < numRepK; k++) {
           for (int kPack = 0; kPack < kWidth / kBase; ++kPack) {
-            acc =
-                mfmaLayout.getIsTransposed()
-                    ? generateMFMAOp(mfmaInsnName, operandB[kPack][{b, n, k}],
-                                     operandA[kPack][{b, m, k}], acc)
-                    : generateMFMAOp(mfmaInsnName, operandA[kPack][{b, m, k}],
-                                     operandB[kPack][{b, n, k}], acc);
+            acc = mfmaLayout.getIsTransposed()
+                      ? generateMFMAOp(intrinsicName,
+                                       operandB[kPack][{b, n, k}],
+                                       operandA[kPack][{b, m, k}], acc)
+                      : generateMFMAOp(intrinsicName,
+                                       operandA[kPack][{b, m, k}],
+                                       operandB[kPack][{b, n, k}], acc);
             if (!firstMfma)
               firstMfma = acc;
           }
@@ -363,7 +365,8 @@ struct DotOpMFMAConversionHelper {
 
     const size_t mmaCount =
         numRepB * numRepM * numRepN * numRepK * kWidth / kBase;
-    packAndReplaceResult(op, fc, maybeMfmaInsn, dstElemTy, elemTyA, mmaCount);
+    packAndReplaceResult(op, fc, maybeMfmaIntrinsic, dstElemTy, elemTyA,
+                         mmaCount);
 
     return success();
   }
@@ -485,15 +488,15 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
                                   Location loc)
       : DotOpMFMAConversionHelper(mfmaLayout, rewriter, typeConverter, loc) {}
 
-  Value generateScaledMFMAOp(MfmaInsn &mfmaInsn, Value valA, Value valB,
-                             Value valC, Value valScaleA,
+  Value generateScaledMFMAOp(const MfmaIntrinsic &mfmaIntrinsic, Value valA,
+                             Value valB, Value valC, Value valScaleA,
                              Value valScaleB) const {
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto resType = valC.getType();
     Value zeroFlag = b.i32_val(0);
-    OperationState loweredOp(loc, mfmaInsn.getInsnName());
-    int32_t cbsz = getMfmaF8F6F4MatrixFormat(mfmaInsn.getElementTypeA());
-    int32_t blgp = getMfmaF8F6F4MatrixFormat(mfmaInsn.getElementTypeB());
+    OperationState loweredOp(loc, mfmaIntrinsic.name);
+    int32_t cbsz = getMfmaF8F6F4MatrixFormat(mfmaIntrinsic.aElementType);
+    int32_t blgp = getMfmaF8F6F4MatrixFormat(mfmaIntrinsic.bElementType);
     assert((cbsz != -1) && (blgp != -1));
     loweredOp.addTypes(resType);
     loweredOp.addOperands({valA, valB, valC, b.i32_val(cbsz), b.i32_val(blgp),
@@ -540,14 +543,16 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
 
     auto ctx = op.getContext();
     constexpr bool allowXF32 = false;
-    auto maybeMfmaInsn = MfmaInsn::selectMfma(
-        mDim, nDim, kDimOperandSize, scaleDotElemTypeToMLIRType(ctx, aElemType),
-        scaleDotElemTypeToMLIRType(ctx, bElemType), mfmaVersion, allowXF32);
-    if (failed(maybeMfmaInsn))
+    FailureOr<MfmaIntrinsic> maybeMfmaIntrinsic =
+        MfmaIntrinsic::selectFor(mfmaVersion, mDim, nDim, kDimOperandSize,
+                                 scaleDotElemTypeToMLIRType(ctx, aElemType),
+                                 scaleDotElemTypeToMLIRType(ctx, bElemType),
+                                 /*withScale=*/false, allowXF32);
+    if (failed(maybeMfmaIntrinsic))
       llvm::report_fatal_error("No match found in MFMA database\n");
 
-    StringRef mfmaInsnName = maybeMfmaInsn->getInsnName();
-    unsigned kBase = maybeMfmaInsn->getKBase();
+    StringRef intrinsicName = maybeMfmaIntrinsic->name;
+    unsigned kBase = maybeMfmaIntrinsic->kBase;
     // Two fp4 are packed into an uint8.
     if (aElemType == ScaleDotElemType::E2M1) {
       kBase /= 2;
@@ -629,12 +634,12 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
       for (int k = 0; k < numRepK; k++) {
         for (int kPack = 0; kPack < kWidth / kBase; ++kPack) {
           acc = mfmaLayout.getIsTransposed()
-                    ? generateScaledMFMAOp(maybeMfmaInsn.value(),
+                    ? generateScaledMFMAOp(maybeMfmaIntrinsic.value(),
                                            operandB[kPack][{b, n, k}],
                                            operandA[kPack][{b, m, k}], acc,
                                            operandBScale[kPack][{b, n, k}],
                                            operandAScale[kPack][{b, m, k}])
-                    : generateScaledMFMAOp(maybeMfmaInsn.value(),
+                    : generateScaledMFMAOp(maybeMfmaIntrinsic.value(),
                                            operandA[kPack][{b, m, k}],
                                            operandB[kPack][{b, n, k}], acc,
                                            operandAScale[kPack][{b, m, k}],
@@ -661,7 +666,8 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
 
     const size_t mmaCount =
         numRepB * numRepM * numRepN * numRepK * kWidth / kBase;
-    packAndReplaceResult(op, fc, maybeMfmaInsn, dstElemTy, elemTyA, mmaCount);
+    packAndReplaceResult(op, fc, maybeMfmaIntrinsic, dstElemTy, elemTyA,
+                         mmaCount);
 
     return success();
   }
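
To make the kPack accounting above concrete: a dot-operand slice holds kWidth elements per thread while a single MFMA consumes kBase of them, so each (b, m, n, k) repetition issues kWidth / kBase intrinsics, which is exactly the mmaCount expression. A toy check with illustrative numbers (not taken from the commit):

#include <cassert>
#include <cstddef>

int main() {
  // Illustrative values: one batch rep, four reps along M, N, and K, and a
  // dot-operand layout holding kWidth = 16 elements per thread where each
  // MFMA intrinsic consumes kBase = 8 of them.
  std::size_t numRepB = 1, numRepM = 4, numRepN = 4, numRepK = 4;
  std::size_t kWidth = 16, kBase = 8;

  // Mirrors the mmaCount expression above: every (b, m, n, k) tile issues
  // kWidth / kBase MFMA ops, the kPack loop's trip count.
  std::size_t mmaCount = numRepB * numRepM * numRepN * numRepK * kWidth / kBase;
  assert(mmaCount == 4 * 4 * 4 * 2);
  return 0;
}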
