Skip to content

Commit 1053fca

Browse files
PMylon and antiagainst
authored
[AMD] Emulate fp8 *UZ variants on non-gfx942 architectures (#7401)
This commit emulates fp8 *UZ variants for non-gfx942 architectures. This makes it easier to support workloads targeting gfx942 on other generations. --------- Co-authored-by: Lei Zhang <[email protected]>
1 parent 77ba5d7 commit 1053fca

File tree

6 files changed

+426
-193
lines changed

6 files changed

+426
-193
lines changed

python/test/unit/language/test_compile_errors.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import triton.language as tl
88
from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
99
import traceback
10-
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3
10+
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna4
1111

1212

1313
def format_exception(type, value, tb):
@@ -364,17 +364,21 @@ def test_fp8_support(fresh_triton_cache, dtype):
364364
if cc >= (8, 9):
365365
supported_dtypes.append(tl.float8e4nv)
366366
elif is_hip():
367-
supported_dtypes.append(tl.float8e4nv)
368-
if is_hip_cdna3():
369-
supported_dtypes += [tl.float8e4b8, tl.float8e5b16]
367+
supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16]
368+
if is_hip_cdna4():
369+
warning_dtypes += [tl.float8e4b8, tl.float8e5b16]
370370

371371
@triton.jit
372372
def dtype_kernel(dtype: tl.constexpr):
373373
a = tl.full((64, 64), 0.0, dtype)
374374
tl.dot(a, a)
375375

376376
if dtype in warning_dtypes:
377-
ctx = pytest.warns(UserWarning, match=r"the use of fp8e4b15 is deprecated on Hopper and later architectures")
377+
if is_cuda():
378+
ctx = pytest.warns(UserWarning,
379+
match=r"the use of fp8e4b15 is deprecated on Hopper and later architectures")
380+
elif is_hip_cdna4():
381+
ctx = pytest.warns(UserWarning, match=r"AMD gfx942 specific and not supported on gfx950")
378382
elif dtype in supported_dtypes:
379383
ctx = contextlib.nullcontext()
380384
else:

python/test/unit/language/test_conversions.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import triton
88
import triton.language as tl
99

10-
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3, is_hip_cdna4
10+
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2, is_hip_cdna3, is_hip_cdna4
1111

1212

1313
def matching_int(dtype):
@@ -265,6 +265,7 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia
265265
('float8e4nv', 'float32'),
266266
267267
('float8e4b8', 'float32'),
268+
('float8e4b8', 'bfloat16'),
268269
('float8e4b8', 'float16'),
269270
270271
('float8e5b16', 'float32'),
@@ -284,12 +285,13 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device):
284285
elif is_hip():
285286
if (src_dtype == 'float8e4nv' and not (is_hip_cdna3() or is_hip_cdna4())):
286287
pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture")
287-
if (src_dtype in ('float8e4b15') or
288-
(src_dtype in ('float8e4b8', 'float8e5b16') and not is_hip_cdna3())):
288+
if src_dtype == 'float8e4b15':
289289
# If the dtype should error out in the given device, we assert that and return
290290
with pytest.raises(triton.CompilationError, match="not supported in this architecture"):
291291
launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device)
292292
return
293+
if src_dtype in ('float8e4b8', 'float8e5b16') and is_hip_cdna2():
294+
pytest.skip(f"{src_dtype} is not supported on AMDGPU CDNA2")
293295

294296
# dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr)
295297
stuff = {
@@ -341,8 +343,8 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
341343
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")
342344

343345
if is_hip():
344-
if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and not is_hip_cdna3():
345-
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")
346+
if dst_dtype in ('float8e4b8', 'float8e5b16') and is_hip_cdna2():
347+
pytest.skip(f"{dst_dtype} is not supported on AMDGPU CDNA2")
346348

347349
# dtype : (exponent_bits, mantissa_bits, exponent_bias)
348350
stuff = {

python/triton/language/semantic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,6 +1487,18 @@ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Opti
14871487
lhs = self.cast(lhs, tl.float16)
14881488
rhs = self.cast(rhs, tl.float16)
14891489

1490+
uses_fp8e4b8 = lhs.dtype.is_fp8e4b8() or rhs.dtype.is_fp8e4b8()
1491+
uses_fp8e5b16 = lhs.dtype.is_fp8e5b16() or rhs.dtype.is_fp8e5b16()
1492+
if uses_fp8e4b8 or uses_fp8e5b16:
1493+
type_name = "fp8e4b8" if uses_fp8e4b8 else "fp8e5b16"
1494+
if type_name in self.builder.options.deprecated_fp8_dot_operand_dtypes:
1495+
arch = self.builder.options.arch
1496+
warnings.warn(
1497+
f"{type_name} is AMD gfx942 specific and not supported on {arch} so it's upcasted to fp16 and can cause significant slow down. "
1498+
f"Please use OCP fp8 variants on {arch} for performance")
1499+
lhs = self.cast(lhs, tl.float16)
1500+
rhs = self.cast(rhs, tl.float16)
1501+
14901502
if input_precision is None:
14911503
input_precision = self.builder.options.default_dot_input_precision
14921504

test/Conversion/amd/fp_to_fp.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
9090
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
9191
tt.func @downcast_to_bf8(%arg0: tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) {
9292
// GFX942-COUNT-4: rocdl.cvt.pk.bf8.f32
93-
// GFX950-COUNT-4: llvm.fptrunc %{{.+}} : vector<2xf32> to vector<2xf16>
93+
// GFX950-COUNT-16: llvm.trunc %{{.+}} : i32 to i8
9494
%6 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E5M2FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
9595
tt.return
9696
}
@@ -103,7 +103,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
103103
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
104104
tt.func @f32_to_f8(%arg0: tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) {
105105
// GFX942-COUNT-4: rocdl.cvt.pk.fp8.f32
106-
// GFX950-COUNT-4: llvm.fptrunc %{{.+}} : vector<2xf32> to vector<2xf16>
106+
// GFX950-COUNT-16: llvm.trunc %{{.+}} : i32 to i8
107107
%7 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
108108
tt.return
109109
}

third_party/amd/backend/compiler.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@ class HIPOptions:
3737
debug: bool = False
3838
sanitize_overflow: bool = True
3939
arch: str = None
40-
# We have native support for OCP fp8 variants since CNDA4/RDNA4. For earlier generations,
40+
# We have native support for OCP fp8 variants since CDNA4/RDNA4. For earlier generations,
4141
# we software emulate the support for them.
42-
supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5")
42+
# UZ fp8 variants (fp8e4b8 and fp8e5b16) are natively supported for CDNA3. For other
43+
# architectures they are software emulated.
44+
supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e5b16", "fp8e4b8")
4345
deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
4446
default_dot_input_precision: str = "ieee"
4547
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
@@ -109,11 +111,12 @@ def parse_options(self, opts) -> Any:
109111
args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))
110112

111113
if "supported_fp8_dtypes" not in opts:
112-
supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
113-
if self.target.arch == 'gfx942':
114-
# CDNA3/gfx942 has native support for AMD specific FP8 types.
115-
supported_fp8_dtypes.update({'fp8e4b8', 'fp8e5b16'})
116-
args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
114+
args["supported_fp8_dtypes"] = tuple(sorted(HIPOptions.supported_fp8_dtypes))
115+
116+
if self.target.arch == 'gfx950':
117+
deprecated_fp8_dot_operand_dtypes = set(HIPOptions.deprecated_fp8_dot_operand_dtypes)
118+
deprecated_fp8_dot_operand_dtypes.update({"fp8e5b16", "fp8e4b8"})
119+
args["deprecated_fp8_dot_operand_dtypes"] = tuple(sorted(deprecated_fp8_dot_operand_dtypes))
117120

118121
if "enable_fp_fusion" not in opts:
119122
args["enable_fp_fusion"] = knobs.language.default_fp_fusion

0 commit comments

Comments
 (0)