Skip to content

Commit 54411be

Browse files
authored
[AMD] Enable more passing fp8 downcast clamping tests (#7363)
With triton-lang/triton#7361, these tests are now passing.
1 parent 20a8ac9 commit 54411be

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

python/test/unit/language/test_conversions.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -376,16 +376,9 @@ def test_typeconvert_downcast_clamping(src_dtype, dst_dtype, mode, rounding="rtn
376376

377377
if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0):
378378
pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")
379-
elif is_hip():
380-
if is_hip_cdna2():
381-
pytest.skip(f"{dst_dtype} downcast to {dst_dtype} with clamping is not fully tested on AMDGPU CDNA2")
382-
383-
if is_hip_cdna3():
384-
if src_dtype == 'bfloat16' and dst_dtype == 'float8e4nv':
385-
pytest.skip(f"{src_dtype} downcast to {dst_dtype} with clamping is not fully tested on AMDGPU CDNA3")
386-
if dst_dtype == 'float8e5' and mode in ('inf', '-inf'):
387-
pytest.skip(f"Downcast to {dst_dtype} with clamping for `inf` or `-inf` "
388-
"is not fully tested on AMDGPU CDNA3")
379+
elif is_hip_cdna2() or is_hip_cdna3():
380+
if src_dtype == 'bfloat16' and dst_dtype == 'float8e4nv':
381+
pytest.skip(f"{src_dtype} downcast to {dst_dtype} with clamping is not fully tested on AMDGPU CDNA2/3")
389382

390383
converter = {
391384
tl.float8e4nv: torch.float8_e4m3fn,

0 commit comments

Comments
 (0)