Commit 5389ed7

[AMD] Emulate Float8E4M3FN with Float16 on CDNA3 and below (#7186)
The fact that gfx942 has its own FP8 variants, rather than the OCP ones, is a common pitfall. Starting with gfx950 we switch to the OCP FP8 variants, so gfx942 is a one-generation special case. This commit enables emulating Float8E4M3FN with FP16, as we already do for Float8E5M2, for better portability, and emits a performance remark when the emulation kicks in.
1 parent 19c842c commit 5389ed7
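For context, below is a minimal user-level sketch (not part of this commit; the kernel name, shapes, and block sizes are illustrative assumptions) of the kind of Triton kernel this path serves: a tl.dot on OCP float8_e4m3fn operands, which on gfx942 (CDNA3) and older is now lowered by converting both operands to FP16 with a performance remark instead of failing to select an MFMA intrinsic.

# Hypothetical Triton kernel exercising the fp8e4m3fn dot path (sketch only).
import torch
import triton
import triton.language as tl


@triton.jit
def fp8_dot_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                   BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # Both operands are Float8E4M3FN; on CDNA3 and below this dot is now
        # emulated with FP16 MFMA instructions rather than rejected.
        a = tl.load(a_ptr + offs_m[:, None] * K + (k + offs_k)[None, :])
        b = tl.load(b_ptr + (k + offs_k)[:, None] * N + offs_n[None, :])
        acc = tl.dot(a, b, acc)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)


# Hypothetical launch: on a gfx942 GPU (ROCm exposes it via the "cuda" device
# in PyTorch) this now compiles and runs, emitting the fp16-emulation remark.
if torch.cuda.is_available():
    M = N = K = 256
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fn)
    c = torch.empty(M, N, device="cuda", dtype=torch.float32)
    grid = (M // 64, N // 64)
    fp8_dot_kernel[grid](a, b, c, M, N, K, BLOCK_M=64, BLOCK_N=64, BLOCK_K=64)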

6 files changed (+57, -40 lines)


test/TritonGPU/amd/accelerate-amd-matmul-mfma.mlir

Lines changed: 31 additions & 6 deletions
@@ -1,20 +1,45 @@
-// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=0" | FileCheck %s --check-prefixes MFMA0,CHECK
-// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=16" | FileCheck %s --check-prefixes MFMA16,CHECK
+// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=0" --verify-diagnostics | FileCheck %s --check-prefixes MFMA0,CHECK
+// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=16" --verify-diagnostics | FileCheck %s --check-prefixes MFMA16,CHECK
 
 #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
-// CHECK-LABEL: mfma_dot_fp8e5m2
+// CHECK-LABEL: mfma_dot_fp8e5m2_fp8e4m3fn
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
-tt.func public @mfma_dot_fp8e5m2(
+tt.func public @mfma_dot_fp8e5m2_fp8e4m3fn(
 %arg0: tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
-%arg1: tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
+%arg1: tensor<64x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
 %arg2: tensor<128x256x!tt.ptr<f32>, #blocked>) {
 %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
 // CHECK: %[[A0:.+]] = ttg.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
 // CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+// CHECK: %[[B0:.+]] = ttg.convert_layout %arg1 : {{.*}} -> tensor<64x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+// CHECK: %[[B1:.+]] = tt.fp_to_fp %[[B0]] : tensor<64x256xf8E4M3FN, {{.*}} -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+// CHECK: tt.dot %[[A1]], %[[B1]]
+// expected-remark @+2 {{missing native support for fp8 variant on current architecture; emulated with fp16 so low performance}}
+// expected-remark @+1 {{for gfx942 please use native supported fp8 variants}}
+%1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
+tt.store %arg2, %1 : tensor<128x256x!tt.ptr<f32>, #blocked>
+tt.return
+}
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [1, 0]}>
+// CHECK-LABEL: mfma_dot_fp8e4m3fn_fp8e5m2
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+tt.func public @mfma_dot_fp8e4m3fn_fp8e5m2(
+%arg0: tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>,
+%arg1: tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>,
+%arg2: tensor<128x256x!tt.ptr<f32>, #blocked>) {
+%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
+// CHECK: %[[A0:.+]] = ttg.convert_layout %arg0 : {{.*}} -> tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+// CHECK: %[[A1:.+]] = tt.fp_to_fp %[[A0]] : {{.*}} -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
 // CHECK: %[[B0:.+]] = ttg.convert_layout %arg1 : {{.*}} -> tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
 // CHECK: %[[B1:.+]] = tt.fp_to_fp %[[B0]] : tensor<64x256xf8E5M2, {{.*}} -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
 // CHECK: tt.dot %[[A1]], %[[B1]]
-%1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
+// expected-remark @+2 {{missing native support for fp8 variant on current architecture; emulated with fp16 so low performance}}
+// expected-remark @+1 {{for gfx942 please use native supported fp8 variants}}
+%1 = tt.dot %arg0, %arg1, %cst : tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
 tt.store %arg2, %1 : tensor<128x256x!tt.ptr<f32>, #blocked>
 tt.return
 }

test/TritonGPU/amd/accelerate-amd-matmul-unsupported.mlir

Lines changed: 0 additions & 15 deletions
This file was deleted.

third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h

Lines changed: 5 additions & 4 deletions
@@ -14,10 +14,11 @@ inline bool isF8F6F4(mlir::Type type) {
 
 struct MfmaIntrinsic {
 // Chooses a suitable mfma instrinsic for the given input case.
-static FailureOr<MfmaIntrinsic> selectFor(int version, unsigned mDim,
-unsigned nDim, unsigned inputKDim,
-Type aElemType, Type bElemType,
-bool withScale, bool useTF32);
+static FailureOr<MfmaIntrinsic> selectFor(Location loc, int version,
+unsigned mDim, unsigned nDim,
+unsigned inputKDim, Type aElemType,
+Type bElemType, bool withScale,
+bool useTF32);
 
 MfmaIntrinsic(StringRef symbol, unsigned m, unsigned n, unsigned k,
 unsigned kB, Type aET, Type bET)

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 2 additions & 2 deletions
@@ -272,7 +272,7 @@ struct DotOpMFMAConversionHelper {
 op.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3;
 StringRef intrinsicName;
 FailureOr<MfmaIntrinsic> maybeMfmaIntrinsic = MfmaIntrinsic::selectFor(
-mfmaVersion, mDim, nDim, kDimOperandSize, elemTyA, elemTyB,
+op.getLoc(), mfmaVersion, mDim, nDim, kDimOperandSize, elemTyA, elemTyB,
 /*withScale=*/false, allowXF32);
 if (failed(maybeMfmaIntrinsic))
 return op.emitError(

@@ -584,7 +584,7 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
 auto ctx = op.getContext();
 constexpr bool allowXF32 = false;
 FailureOr<MfmaIntrinsic> maybeMfmaIntrinsic = MfmaIntrinsic::selectFor(
-mfmaVersion, mDim, nDim,
+op.getLoc(), mfmaVersion, mDim, nDim,
 aElemType == ScaleDotElemType::E2M1 ? kDimOperandSize * 2
 : kDimOperandSize,
 scaleDotElemTypeToMLIRType(ctx, aElemType),

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 2 additions & 2 deletions
@@ -166,8 +166,8 @@ chooseMfmaInstruction(Location loc, int mfmaVersion, RankedTensorType cType,
 return failure();
 
 FailureOr<MfmaIntrinsic> maybeMfmaIntrinsic =
-MfmaIntrinsic::selectFor(mfmaVersion, mDim, nDim, inputKSize, aElemType,
-bElemType, withScale, allowXF32);
+MfmaIntrinsic::selectFor(loc, mfmaVersion, mDim, nDim, inputKSize,
+aElemType, bElemType, withScale, allowXF32);
 if (failed(maybeMfmaIntrinsic))
 return emitError(loc, "no matching matrix core intrinsic due to "
 "unsupported element type");

third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp

Lines changed: 17 additions & 11 deletions
@@ -1,6 +1,7 @@
 #include "TritonAMDGPUTransforms/MfmaGroup.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "llvm/ADT/DenseMap.h"
 #include <tuple>
 

@@ -23,9 +24,9 @@ using MfmaKey =
 //
 // This function adapts certain parameters so we can be flexible when trying
 // to query with "mismatches".
-MfmaKey composeMfmaKeyFor(unsigned version, unsigned mDim, unsigned nDim,
-Type &aElemType, Type &bElemType, bool withScale,
-bool useTF32) {
+MfmaKey composeMfmaKeyFor(Location loc, unsigned version, unsigned mDim,
+unsigned nDim, Type &aElemType, Type &bElemType,
+bool withScale, bool useTF32) {
 Type aET = aElemType, bET = bElemType;
 Builder b(aElemType.getContext());
 if (withScale) {

@@ -38,9 +39,14 @@ MfmaKey composeMfmaKeyFor(unsigned version, unsigned mDim, unsigned nDim,
 // In the MFMA map we use the proper TF32 type. So "fix" it here.
 assert(version == 3);
 aET = bET = b.getType<FloatTF32Type>();
-} else if (version <= 3 && isa<Float8E5M2Type>(aET) &&
-isa<Float8E5M2Type>(bET)) {
-// For the OCP FP8 E5M2 type, we can emulate the support for it with FP16.
+} else if (version <= 3 && isa<Float8E5M2Type, Float8E4M3FNType>(aET) &&
+isa<Float8E5M2Type, Float8E4M3FNType>(bET)) {
+emitRemark(loc, "missing native support for fp8 variant on current "
+"architecture; emulated with fp16 so low performance");
+if (version == 3)
+emitRemark(loc, "for gfx942 please use native supported fp8 variants");
+// For the OCP FP8 E5M2/E4M3FN type, we don't have native support until
+// CDNA4. So emulate with FP16.
 aElemType = bElemType = aET = bET = b.getF16Type();
 }
 return {version, mDim, nDim, aET.getTypeID(), bET.getTypeID()};

@@ -270,12 +276,12 @@ MfmaDatabase::MfmaDatabase(MLIRContext *context) {
 //===----------------------------------------------------------------------===//
 
 FailureOr<MfmaIntrinsic>
-MfmaIntrinsic::selectFor(int version, unsigned mDim, unsigned nDim,
-unsigned inputKDim, Type aElemType, Type bElemType,
-bool withScale, bool useTF32) {
+MfmaIntrinsic::selectFor(Location loc, int version, unsigned mDim,
+unsigned nDim, unsigned inputKDim, Type aElemType,
+Type bElemType, bool withScale, bool useTF32) {
 const MfmaMap &mfmaMap = MfmaDatabase::get(aElemType.getContext());
-MfmaKey key = composeMfmaKeyFor(version, mDim, nDim, aElemType, bElemType,
-withScale, useTF32);
+MfmaKey key = composeMfmaKeyFor(loc, version, mDim, nDim, aElemType,
+bElemType, withScale, useTF32);
 
 auto it = mfmaMap.find(key);
 if (it == mfmaMap.end())
