|
7 | 7 | import triton
|
8 | 8 | import triton.language as tl
|
9 | 9 |
|
10 |
| -from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3, is_hip_cdna4 |
| 10 | +from triton._internal_testing import is_cuda, is_hip, is_hip_cdna2, is_hip_cdna3, is_hip_cdna4 |
11 | 11 |
|
12 | 12 |
|
13 | 13 | def matching_int(dtype):
|
@@ -366,3 +366,74 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
|
366 | 366 |
|
367 | 367 | for i in range(256):
|
368 | 368 | downcast_test(getattr(tl, src_dtype), getattr(tl, dst_dtype), rounding, *stuff, max_repr, i, device=device)
|
| 369 | + |
@pytest.mark.parametrize("mode", [
    'max', 'min', 'inf', '-inf', 'nan',
])
@pytest.mark.parametrize("dst_dtype", ["float8e4nv", "float8e5"])
@pytest.mark.parametrize("src_dtype", ["float32", "float16", "bfloat16"])
def test_typeconvert_downcast_clamping(src_dtype, dst_dtype, mode, rounding="rtne", device="cuda"):
    """Check that downcasting out-of-range/special float values clamps as expected.

    For each (src_dtype -> dst_dtype) pair, fill a tensor with a value chosen by
    `mode` and verify the converted result:
      - 'max'/'min':  a value just beyond the destination's representable range
                      must clamp to the destination's +/- max finite value.
      - 'inf'/'-inf': +/- infinity must also clamp to the +/- max finite value.
      - 'nan':        NaN must propagate as NaN.
    """
    if is_cuda():
        if src_dtype != 'float32' and torch.cuda.get_device_capability(0) < (9, 0):
            pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+")

        if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0):
            pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")
    elif is_hip():
        if is_hip_cdna2():
            # NOTE: first placeholder is the source dtype (the original message
            # repeated dst_dtype twice, which produced a misleading skip reason).
            pytest.skip(f"{src_dtype} downcast to {dst_dtype} with clamping is not fully tested on AMDGPU CDNA2")

        if is_hip_cdna3():
            if src_dtype == 'bfloat16' and dst_dtype == 'float8e4nv':
                pytest.skip(f"{src_dtype} downcast to {dst_dtype} with clamping is not fully tested on AMDGPU CDNA3")
            if dst_dtype == 'float8e5' and mode in ('inf', '-inf'):
                pytest.skip(f"Downcast to {dst_dtype} with clamping for `inf` or `-inf` "
                            "is not fully tested on AMDGPU CDNA3")

    # Map triton dtypes to the torch dtypes used to allocate/inspect tensors.
    converter = {
        tl.float8e4nv: torch.float8_e4m3fn,
        tl.float8e5: torch.float8_e5m2,
        tl.float16: torch.float16,
        tl.bfloat16: torch.bfloat16,
        tl.float32: torch.float32
    }

    tl_src_dtype = getattr(tl, src_dtype)
    tl_dst_dtype = getattr(tl, dst_dtype)

    torch_src_dtype = converter[tl_src_dtype]
    torch_dst_dtype = converter[tl_dst_dtype]

    if mode in ('max', 'min'):
        # Offset added to the destination's max finite value so the input falls
        # outside the representable range and must be clamped (not rounded).
        exceed_value = 100.0
        test_value = torch.finfo(torch_dst_dtype).max + exceed_value
        expected_result = torch.finfo(torch_dst_dtype).max
    elif mode in ('inf', '-inf'):
        test_value = torch.inf
        expected_result = torch.finfo(torch_dst_dtype).max
    else:
        assert mode == 'nan'
        test_value = torch.nan
        expected_result = torch.nan

    # Negative-direction modes mirror the positive cases around zero.
    if mode in ('min', '-inf'):
        test_value *= -1.0
        expected_result *= -1.0

    BLOCK_SIZE = 1024
    shape = (BLOCK_SIZE * 2,)
    src = torch.full(shape, test_value, dtype=torch_src_dtype, device=device)
    dst = torch.empty(shape, dtype=torch_dst_dtype, device=device)

    type_convert_triton[(src.shape[0] // BLOCK_SIZE,)](
        triton.reinterpret(src, torch_src_dtype),
        triton.reinterpret(dst, torch_dst_dtype),
        rounding,
        BLOCK_SIZE
    )

    if mode == 'nan':
        # NaN != NaN, so assert_close cannot be used here.
        assert torch.all(torch.isnan(dst))
    else:
        torch.testing.assert_close(dst, torch.full_like(dst, expected_result))
0 commit comments