
Commit 4b9efc5

[AMD][gfx12] Enable dot for f8 operands (#6814)
- Enable f8, bf8 dtype for gfx12
- Add related tests
- Fix matmul dtype matcher

Signed-off-by: Ilya Veselov <[email protected]>
1 parent: 9aa2c86 · commit: 4b9efc5
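Before the per-file diffs, a minimal usage sketch of what this enables: on a gfx12 GPU, tl.dot can now consume fp8 operands directly instead of requiring an upcast to f16. The kernel below is illustrative only; the kernel name, single-tile shape, and the use of torch.float8_e5m2 inputs are editorial assumptions, not part of this commit, and K-loop tiling and bounds masking are omitted.

import torch
import triton
import triton.language as tl


@triton.jit
def fp8_dot_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # Single-tile sketch: load one BLOCK_M x BLOCK_K tile of A and one
    # BLOCK_K x BLOCK_N tile of B, both stored as fp8.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # With this change, the fp8 operands feed the WMMA unit directly on gfx12;
    # accumulation happens in f32.
    acc = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)


# Host side (assumed mapping: torch.float8_e5m2 corresponds to Triton's float8e5).
M = N = K = 64
a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch.float8_e5m2)
b = torch.randn(K, N, device="cuda", dtype=torch.float16).to(torch.float8_e5m2)
c = torch.empty((M, N), device="cuda", dtype=torch.float32)
fp8_dot_kernel[(1,)](a, b, c, M, N, K, BLOCK_M=M, BLOCK_N=N, BLOCK_K=K)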

File tree (6 files changed, +30 −28 lines):

- python/test/unit/language/test_core.py
- python/triton/_internal_testing.py
- test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir
- third_party/amd/backend/compiler.py
- third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
- third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp


python/test/unit/language/test_core.py

Lines changed: 3 additions & 2 deletions
@@ -35,6 +35,7 @@
     is_hip_cdna2,
     is_hip_cdna3,
     is_hip_cdna4,
+    is_hip_gfx12,
     is_xpu,
     get_arch,
     torch_float8_dtypes,
@@ -3722,8 +3723,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
             pytest.skip("float8e4nv not supported on sm <= 80")

     if is_hip():
-        if in_dtype in ("float8e5", "float8e4nv") and not is_hip_cdna4():
-            pytest.skip(f"{in_dtype} only supported on CDNA4")
+        if in_dtype in ("float8e5", "float8e4nv") and not (is_hip_cdna4() or is_hip_gfx12()):
+            pytest.skip(f"{in_dtype} only supported on CDNA4 and gfx12")
         if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_cdna3():
             pytest.skip(f"{in_dtype} only supported on CDNA3")
         if not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_cdna3())):

python/triton/_internal_testing.py

Lines changed: 6 additions & 0 deletions
@@ -62,6 +62,12 @@ def is_hip_cdna4():
     return target is not None and target.backend == 'hip' and target.arch == 'gfx950'


+def is_hip_gfx12():
+    target = get_current_target()
+    print(target.arch)
+    return target is not None and target.backend == 'hip' and 'gfx12' in target.arch
+
+
 def is_hip_cdna():
     return is_hip_cdna2() or is_hip_cdna3() or is_hip_cdna4()

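Note: as committed, is_hip_gfx12 prints target.arch on every call and reads target.arch before the `target is not None` guard, so it would raise if no target is active. A hedged, side-effect-free sketch of the same check, reusing get_current_target from this module:

def is_hip_gfx12():
    # Guard against a missing target before touching its fields; no debug print.
    target = get_current_target()
    return target is not None and target.backend == 'hip' and 'gfx12' in target.arch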

test/TritonGPU/amd/accelerate-amd-matmul-wmma-gen2.mlir

Lines changed: 3 additions & 7 deletions
@@ -80,15 +80,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
     // CHECK: %[[DOT2_OP_C_EXT:.+]] = arith.extf %[[DOT2_OP_C]]
     // CHECK-SAME: to tensor<32x64xf32, #[[WMMA_0]]>
     %3 = arith.constant dense<0.000000e+00> : tensor<32x64xf16, #blocked>
-    // CHECK: %[[DOT2_OP_A_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_A]]
+    // CHECK: %[[DOT2_OP_A:.+]] = ttg.convert_layout %[[DOT2_ARG_A]]
     // CHECK-SAME: -> tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]]
-    // CHECK: %[[DOT2_OP_A_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_A_F8]]
-    // CHECK-SAME: -> tensor<32x128xf16, #ttg.dot_op<{opIdx = 0, parent = #[[WMMA_0]], kWidth = 8}>>
-    // CHECK: %[[DOT2_OP_B_F8:.+]] = ttg.convert_layout %[[DOT2_ARG_B]]
+    // CHECK: %[[DOT2_OP_B:.+]] = ttg.convert_layout %[[DOT2_ARG_B]]
     // CHECK-SAME: -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]]
-    // CHECK: %[[DOT2_OP_B_F16:.+]] = tt.fp_to_fp %[[DOT2_OP_B_F8]]
-    // CHECK-SAME: -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #[[WMMA_0]], kWidth = 8}>>
-    // CHECK: %[[DOT2_WMMA_RES:.+]] = tt.dot %[[DOT2_OP_A_F16]], %[[DOT2_OP_B_F16]], %[[DOT2_OP_C_EXT]]
+    // CHECK: %[[DOT2_WMMA_RES:.+]] = tt.dot %[[DOT2_OP_A]], %[[DOT2_OP_B]], %[[DOT2_OP_C_EXT]]
     // CHECK-SAME: -> tensor<32x64xf32, #[[WMMA_0]]
     %4 = tt.dot %0, %1, %3 : tensor<32x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x64xf16, #blocked>
     // CHECK: %[[CONVERTED_RES:.+]] = ttg.convert_layout %[[DOT2_WMMA_RES]]

third_party/amd/backend/compiler.py

Lines changed: 2 additions & 0 deletions
@@ -113,6 +113,8 @@ def parse_options(self, opts) -> Any:
             supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
         elif self.target.arch == 'gfx950':
             supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
+        elif 'gfx12' in self.target.arch:
+            supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
         args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

         if "enable_fp_fusion" not in opts:

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp

Lines changed: 4 additions & 0 deletions
@@ -172,6 +172,10 @@ std::string getTypeStr(Type ty) {
     scalarName = "iu8";
   } else if (ty.isInteger(4)) {
     scalarName = "iu4";
+  } else if (llvm::isa<Float8E4M3FNType>(ty)) {
+    scalarName = "fp8";
+  } else if (llvm::isa<Float8E5M2Type>(ty)) {
+    scalarName = "bf8";
   } else if (auto vecTy = dyn_cast<VectorType>(ty)) {
     auto elemType = vecTy.getElementType();
     auto numElems = vecTy.getNumElements();

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 12 additions & 19 deletions
@@ -299,20 +299,18 @@ OperandTypesVector getOperandTypesForWmmaOp(PatternRewriter &rewriter,
       // by WMMA instruction, but not supported by triton
       // clang-format on
   };
-  // TODO: support fp8 configurations for WMMAv2. The code should be as
-  // following:
-  // if (version == 2) {
-  //   Type fp8 = rewriter.getFp8Type();
-  //   Type bf8 = rewriter.getBF8Type();
-  //   applicableTypes.append({
-  //       // clang-format off
-  //       {fp8, fp8, f32, f32},
-  //       {fp8, bf8, f32, f32},
-  //       {bf8, fp8, f32, f32},
-  //       {bf8, bf8, f32, f32},
-  //       // clang-format on
-  //   });
-  // }
+  if (version == 2) {
+    Type fp8e4nv = rewriter.getType<Float8E4M3FNType>();
+    Type fp8e5 = rewriter.getType<Float8E5M2Type>();
+    applicableTypes.append({
+        // clang-format off
+        {fp8e4nv, fp8e4nv, f32, f32},
+        {fp8e4nv, fp8e5, f32, f32},
+        {fp8e5, fp8e4nv, f32, f32},
+        {fp8e5, fp8e5, f32, f32},
+        // clang-format on
+    });
+  }
   return selectMatrixCoreOperandTypes(dot, applicableTypes);
 }

@@ -1002,11 +1000,6 @@ class BlockedToWMMA : public OpRewritePattern<tt::DotOp> {
         aShape[rank - 1] % mnkDim[2] != 0) // k
       return failure();

-    if (wmmaVersion == 2 && llvm::isa<FloatType>(oldAType) &&
-        oldAType.getIntOrFloatBitWidth() == 8) {
-      return rewriter.notifyMatchFailure(dotOp, "not supported yet");
-    }
-
     // get operand types
     auto operandTypes = getOperandTypesForWmmaOp(rewriter, dotOp, wmmaVersion);
     if (operandTypes.empty())
