|
| 1 | +import torch |
| 2 | +import pytest |
| 3 | +import triton |
| 4 | +import triton.language as tl |
| 5 | + |
| 6 | + |
@triton.jit
def type_convert(src, dst, rounding: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # Each program instance converts one contiguous run of BLOCK_SIZE elements:
    # it loads them from `src`, casts them to the element type of `dst` using
    # the requested floating-point downcast rounding mode, and stores them at
    # the same offsets in `dst`.
    #
    # NOTE(review): neither the load nor the store is masked, so the launch
    # grid presumably has to cover the buffers in whole blocks -- confirm at
    # the call site.
    block_start = tl.program_id(0) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    values = tl.load(src + offsets)
    # The destination element type is taken from the `dst` pointer itself.
    converted = values.to(dst.dtype.element_ty, fp_downcast_rounding=rounding)
    tl.store(dst + offsets, converted)
| 13 | + |
| 14 | + |
@pytest.mark.parametrize("dst_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=["float8_e4m3fn", "float8_e5m2"])
@pytest.mark.parametrize("src_dtype", [torch.float16, torch.bfloat16, torch.float32],
                         ids=["float16", "bfloat16", "float32"])
def test_convert_to_fp8(src_dtype, dst_dtype, device):
    """Compare the Triton 'rtne' downcast to fp8 against PyTorch's conversion.

    2**16 source values are generated from an integer counter reinterpreted as
    ``src_dtype``, converted by the ``type_convert`` kernel, and compared
    byte-for-byte with ``src.to(dst_dtype)``. Where the two disagree, the
    expected byte is overridden for NaN, out-of-range/infinite and -0.0
    inputs before the test fails on the first remaining mismatch.

    NOTE(review): the overrides clamp to the finite min/max of ``dst_dtype``,
    which presumably reflects a saturating kernel conversion -- confirm.
    """
    is_fp32 = src_dtype == torch.float32
    src_idtype = torch.int32 if is_fp32 else torch.int16
    src_bits = 32 if is_fp32 else 16
    finfo = torch.finfo(dst_dtype)
    # Raw byte encodings of the most negative / most positive finite fp8 value.
    min_val = torch.tensor(finfo.min, dtype=dst_dtype).view(torch.uint8).item()
    max_val = torch.tensor(finfo.max, dtype=dst_dtype).view(torch.uint8).item()
    SIZE = 2**16
    BLOCK_SIZE = SIZE // 32
    src = torch.arange(0, SIZE, dtype=src_idtype, device=device)
    if is_fp32:
        # Replicate the 16-bit counter into both halves of the 32-bit word.
        src = src << 16 | src
    src = src.view(src_dtype)
    dst = torch.empty_like(src, dtype=dst_dtype, device=device)
    type_convert[(SIZE // BLOCK_SIZE, )](triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype), 'rtne',
                                         BLOCK_SIZE)

    dst = dst.view(torch.uint8)
    expect = src.to(dtype=dst_dtype).view(torch.uint8)
    diff_mask = dst != expect

    # Move only the mismatching elements to the host, once, so the Python loop
    # below does not trigger a device synchronisation per element.
    src = src[diff_mask].cpu()
    src_patterns = src.view(src_idtype).tolist()
    expect_bytes = expect[diff_mask].tolist()
    dst_bytes = dst[diff_mask].tolist()

    # The integer view is signed; masking recovers the unsigned two's-complement
    # bit pattern so the sign-bit test and the diagnostic show the real bits.
    pattern_mask = (1 << src_bits) - 1
    negative_zero = 1 << (src_bits - 1)  # only the sign bit set, i.e. -0.0
    is_e4m3fn = dst_dtype == torch.float8_e4m3fn
    sfmt = "032b" if is_fp32 else "016b"
    dfmt = "08b"

    for s, si, e, d in zip(src, src_patterns, expect_bytes, dst_bytes):
        bits = si & pattern_mask
        if torch.isnan(s):
            e = 0b01111111
        elif torch.isposinf(s) or (s >= 57344.) or (s >= 464. and is_e4m3fn):
            e = max_val
        elif torch.isneginf(s) or (s <= -57344.) or (s <= -464. and is_e4m3fn):
            e = min_val
        elif bits == negative_zero:
            e = 0b10000000

        if d != e:
            msg = f"Src={s}({format(bits, sfmt)}). Expected={format(e, dfmt)}. Actual={format(d, dfmt)}."
            pytest.fail(msg)
0 commit comments