Skip to content

Commit bf5913d

Browse files
authored
[AMD] NFC: Tidy up FP8 variant support cases (#7267)
1 parent 2c59df5 commit bf5913d

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

python/test/unit/language/test_compile_errors.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import triton.language as tl
88
from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
99
import traceback
10-
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3, is_hip_cdna4
10+
from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3
1111

1212

1313
def format_exception(type, value, tb):
@@ -364,10 +364,9 @@ def test_fp8_support(fresh_triton_cache, dtype):
364364
if cc >= (8, 9):
365365
supported_dtypes.append(tl.float8e4nv)
366366
elif is_hip():
367+
supported_dtypes.append(tl.float8e4nv)
367368
if is_hip_cdna3():
368-
supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16]
369-
if is_hip_cdna4():
370-
supported_dtypes += [tl.float8e4nv]
369+
supported_dtypes += [tl.float8e4b8, tl.float8e5b16]
371370

372371
@triton.jit
373372
def dtype_kernel(dtype: tl.constexpr):

third_party/amd/backend/compiler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ class HIPOptions:
3737
debug: bool = False
3838
sanitize_overflow: bool = True
3939
arch: str = None
40-
supported_fp8_dtypes: Tuple[str] = ("fp8e5", )
40+
# We have native support for OCP fp8 variants since CNDA4/RDNA4. For earlier generations,
41+
# we software emulate the support for them.
42+
supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5")
4143
deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
4244
default_dot_input_precision: str = "ieee"
4345
allowed_dot_input_precisions: Tuple[str] = ("ieee", )
@@ -113,11 +115,8 @@ def parse_options(self, opts) -> Any:
113115
if "supported_fp8_dtypes" not in opts:
114116
supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
115117
if self.target.arch == 'gfx942':
116-
supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
117-
elif self.target.arch == 'gfx950':
118-
supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
119-
elif 'gfx12' in self.target.arch:
120-
supported_fp8_dtypes.update({'fp8e4nv', 'fp8e5'})
118+
# CDNA3/gfx942 has native support for AMD specific FP8 types.
119+
supported_fp8_dtypes.update({'fp8e4b8', 'fp8e5b16'})
121120
args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
122121

123122
if "enable_fp_fusion" not in opts:

0 commit comments

Comments
 (0)