Commit d8e39c3

Fixed FP to FP8 downcast (#4748)
Fixes #4630

This implementation is based on an arithmetic approach (see below) and provides an RTNE conversion from FP32/FP16 to FP8 types. Although it uses more expensive floating-point binary operations, I don't see any performance degradation (a simple benchmark is commented out in the test). The implementation is simple and template-based. It replaces four more complicated existing functions and adds a direct conversion from FP32 to FP8 without an intermediate FP16 step.

The approach is as follows. To convert a float from src to dst, we need to find an exponent and mantissa such that dst is equal or close to src. For normal numbers (dstExp != 0) this means:

```python
2^(srcExp - SrcBias) * (1 + srcMan/2^SrcMBits) = 2^(dstExp - DstBias) * (1 + dstMan/2^DstMBits)
```

and for subnormals (dstExp == 0):

```python
src = 2^(1 - DstBias) * (dstMan/2^DstMBits)
```

The exponent is calculated as:

```python
dstExp = max(0, srcExp - SrcBias + DstBias)
```

Simplifying the first formula gives:

```python
dstMan = srcMan * 2^(DstMBits - SrcMBits)
```

If `SrcBias == DstBias` (as in the FP16 to FP8E5 conversion), this formula also works for subnormals. In the general case, subnormals use:

```python
dstMan = src * 2^(DstMBits + DstBias - 1)
```

Thus, to get the mantissa, we use a simple multiplication by a constant and round the result to the nearest integer. If the mantissa overflows, it must be reset to zero and the exponent incremented; adding the rounded mantissa to the shifted exponent handles that carry automatically, and clamping keeps the result within range:

```python
dst = min((dstExp << DstMBits) + dstMan, DST_MAX)
```

This gives the required dst value without the sign bit.
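To make the recipe concrete, here is a minimal Python sketch of the same steps for FP32 → FP8E4M3 (4 exponent bits, 3 mantissa bits, bias 7). It is an illustration, not the committed code: the names and the `struct`-based bit extraction are mine, and NaN/Inf inputs are ignored for brevity. Python's built-in `round` rounds half to even, which is exactly RTNE:

```python
import struct

# Illustrative constants for FP32 -> FP8E4M3; not identifiers from the commit.
SRC_MBITS, SRC_BIAS = 23, 127
DST_MBITS, DST_BIAS = 3, 7
DST_MAX = 0x7E  # largest finite E4M3 encoding (448.0); 0x7F encodes NaN

def fp32_to_fp8e4m3(x: float) -> int:
    """Arithmetic RTNE downcast sketch; NaN/Inf handling omitted."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    sign = bits >> 31
    src_exp = (bits >> 23) & 0xFF
    # dstExp = max(0, srcExp - SrcBias + DstBias)
    dst_exp = max(0, src_exp - SRC_BIAS + DST_BIAS)
    if dst_exp > 0:
        # Normal result: dstMan = srcMan * 2^(DstMBits - SrcMBits), rounded to even.
        src_man = bits & ((1 << SRC_MBITS) - 1)
        dst_man = round(src_man * 2.0 ** (DST_MBITS - SRC_MBITS))
    else:
        # Subnormal result: dstMan = |src| * 2^(DstMBits + DstBias - 1), rounded to even.
        dst_man = round(abs(x) * 2.0 ** (DST_MBITS + DST_BIAS - 1))
    # A mantissa overflow carries into the exponent through the addition;
    # min() clamps the result to the largest finite encoding.
    dst = min((dst_exp << DST_MBITS) + dst_man, DST_MAX)
    return (sign << 7) | dst
```

For example, `hex(fp32_to_fp8e4m3(0.3))` should yield `0x2a`, the E4M3 encoding of 0.3125, the nearest representable value.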
1 parent 4bedc47 commit d8e39c3

File tree

10 files changed: +133 −306 lines

Lines changed: 54 additions & 0 deletions

```python
import torch
import pytest
import triton
import triton.language as tl


@triton.jit
def type_convert(src, dst, rounding: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # Copy one block, downcasting each element with the requested rounding mode.
    idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(src + idxs)
    y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding)
    tl.store(dst + idxs, y)


@pytest.mark.parametrize("dst_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=["float8_e4m3fn", "float8_e5m2"])
@pytest.mark.parametrize("src_dtype", [torch.float16, torch.bfloat16, torch.float32],
                         ids=["float16", "bfloat16", "float32"])
def test_convert_to_fp8(src_dtype, dst_dtype, device):
    src_idtype = torch.int32 if src_dtype == torch.float32 else torch.int16
    finfo = torch.finfo(dst_dtype)
    min_val = torch.tensor(finfo.min, dtype=dst_dtype).view(torch.uint8).item()
    max_val = torch.tensor(finfo.max, dtype=dst_dtype).view(torch.uint8).item()
    SIZE = 2**16
    BLOCK_SIZE = SIZE // 32
    # Enumerate every 16-bit pattern; for float32, replicate it into both halves.
    src = torch.arange(0, SIZE, dtype=src_idtype, device=device)
    if src_dtype == torch.float32:
        src = src << 16 | src
    src = src.view(src_dtype)
    dst = torch.empty_like(src, dtype=dst_dtype, device=device)
    type_convert[(SIZE // BLOCK_SIZE, )](triton.reinterpret(src, src_dtype), triton.reinterpret(dst, dst_dtype),
                                         'rtne', BLOCK_SIZE)

    # Compare bit patterns against PyTorch's reference conversion.
    dst = dst.view(torch.uint8)
    expect = src.to(dtype=dst_dtype).view(torch.uint8)
    diff_mask = dst != expect
    src = src[diff_mask]
    dst = dst[diff_mask]
    expect = expect[diff_mask]

    for s, si, e, d in zip(src, src.view(src_idtype), expect.view(torch.uint8), dst.view(torch.uint8)):
        # Adjust expectations where the Triton cast intentionally differs from
        # PyTorch's: NaN is canonicalized, out-of-range values saturate, and
        # -0.0 keeps its sign.
        if torch.isnan(s):
            e = 0b01111111
        elif torch.isposinf(s) or (s >= 57344.) or (s >= 464. and dst_dtype == torch.float8_e4m3fn):
            e = max_val
        elif torch.isneginf(s) or (s <= -57344.) or (s <= -464. and dst_dtype == torch.float8_e4m3fn):
            e = min_val
        elif si == 0b1000000000000000:  # -0.0
            e = 0b10000000

        if d != e:
            sfmt = "032b" if src_dtype == torch.float32 else "016b"
            dfmt = "08b"
            msg = f"Src={s}({format(si, sfmt)}). Expected={format(e, dfmt)}. Actual={format(d, dfmt)}."
            pytest.fail(msg)
```
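A note on the magic constants in the saturation branches: 57344 is the largest finite float8_e5m2 value, and for float8_e4m3fn 464 is the midpoint between the largest finite value (448) and the next would-be step (480, whose encoding is NaN), so inputs at or above it are expected to map to the maximum finite encoding. A quick sanity check of these numbers (my snippet, assuming PyTorch's finfo support for the float8 dtypes, which the test itself uses):

```python
import torch

# 57344 = 1.75 * 2**15 is the float8_e5m2 maximum; 448 = 1.75 * 2**8 for e4m3fn.
assert torch.finfo(torch.float8_e5m2).max == 57344.0
assert torch.finfo(torch.float8_e4m3fn).max == 448.0
# 464 is halfway between 448 and the next would-be value 480 (encoded as NaN),
# so every input at or above it maps to the maximum finite encoding.
assert (448 + 480) / 2 == 464
```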

python/test/unit/language/test_conversions.py

Lines changed: 2 additions & 10 deletions

```diff
@@ -234,17 +234,9 @@ def downcast_test(src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits,
 
     src = launch_exhaustive_populate(src_dtype, offset << 24, 2**24, False, src_dtype.primitive_bitwidth, max_repr, device)
     dst = launch_type_convert_triton(src, src_dtype, dst_dtype, device=device, rounding=rounding)
-    # Emulated cast always works on fp32. In XPU Triton kernels FP32 is casted to FP8 through FP16, which
-    # in some cases gives different results compared to direct FP32 to FP8 conversion (some precision might
-    # be lost due to two-step conversion). To get matching results, we convert FP32 source data to FP16 and
-    # back to FP32. This will need to be changed back when HW FP32->FP8 convertion is used for XPU.
-    if device=='xpu' and src_dtype.primitive_bitwidth == 32 and dst_dtype.primitive_bitwidth == 8:
-        src = launch_type_convert_triton(src, src_dtype, tl.float16, device=device, rounding=rounding)
-        src = launch_type_convert_triton(src, tl.float16, tl.float32, device=device)
-    else:
-        src = launch_type_convert_triton(src, src_dtype, tl.float32, device=device)
+    src = launch_type_convert_triton(src, src_dtype, tl.float32, device=device)
 
-    dst2 = launch_downcast_emulated(src, src_dtype, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device=device)
+    dst2 = launch_downcast_emulated(src, tl.float32, dst_dtype, rounding, exponent_bits, mantissa_bits, exponent_bias, device=device)
 
     dst = launch_upcast_emulated(dst, exponent_bits, mantissa_bits, exponent_bias, device=device)
     dst2 = launch_upcast_emulated(dst2, exponent_bits, mantissa_bits, exponent_bias, device=device)
```
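For context on the deleted workaround: casting FP32 to FP8 through FP16 double-rounds, which can disagree with a direct RTNE cast, so the test no longer needs to mimic the two-step path. A small illustration of the effect (my snippet, assuming PyTorch's float8 casts, which round to nearest even):

```python
import torch

# 1.0625 is the exact tie between the fp8e4m3 neighbors 1.0 and 1.125.
# A value just above the tie rounds up under a direct RTNE cast, but first
# rounding to fp16 collapses it back onto the tie, and ties-to-even then
# picks 1.0 instead.
x = torch.tensor(1.0625 + 2**-12, dtype=torch.float32)

direct = x.to(torch.float8_e4m3fn)                      # -> 1.125
via_fp16 = x.to(torch.float16).to(torch.float8_e4m3fn)  # -> 1.0

print(direct.float().item(), via_fp16.float().item())
```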

scripts/skiplist/a770/language.txt

Lines changed: 0 additions & 14 deletions

```diff
@@ -651,20 +651,6 @@ python/test/unit/language/test_matmul.py::test_lhs_in_tmem[float8e5-True-128-64-
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_store
 python/test/unit/language/test_tensor_descriptor.py::test_make_tensor_descriptor_matmul
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_rank_reducing_matmul
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/4630
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-max]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-min]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5--inf]
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/4289
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float16-add]
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float32-add]
```

scripts/skiplist/arl-h/language.txt

Lines changed: 0 additions & 14 deletions

```diff
@@ -524,20 +524,6 @@ python/test/unit/language/test_matmul.py::test_lhs_in_tmem[float8e5-True-128-64-
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_store
 python/test/unit/language/test_tensor_descriptor.py::test_make_tensor_descriptor_matmul
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_rank_reducing_matmul
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/4630
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-max]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-min]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5--inf]
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/4289
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float16-add]
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float32-add]
```

scripts/skiplist/arl-s/language.txt

Lines changed: 0 additions & 14 deletions

```diff
@@ -524,20 +524,6 @@ python/test/unit/language/test_matmul.py::test_lhs_in_tmem[float8e5-True-128-64-
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_store
 python/test/unit/language/test_tensor_descriptor.py::test_make_tensor_descriptor_matmul
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_rank_reducing_matmul
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/4630
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-max]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-min]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5--inf]
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/4289
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float16-add]
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float32-add]
```

scripts/skiplist/default/language.txt

Lines changed: 0 additions & 14 deletions

```diff
@@ -1,20 +1,6 @@
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/4665
 python/test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float64-float64]
 python/test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float64-float64]
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/4630
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-max]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-min]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5--inf]
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/4289
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float16-add]
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float32-add]
```

scripts/skiplist/lts/language.txt

Lines changed: 0 additions & 14 deletions

```diff
@@ -252,20 +252,6 @@ python/test/unit/language/test_core.py::test_convert_mma2mma[mma_pair0-float16-2
 python/test/unit/language/test_matmul.py::test_lhs_in_tmem
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_batched_gemm_2d_tma
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_batched_gemm_3d_tma
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/4630
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-max]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-min]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5--inf]
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/4289
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float16-add]
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float32-add]
```

scripts/skiplist/mtl/language.txt

Lines changed: 0 additions & 14 deletions

```diff
@@ -301,20 +301,6 @@ python/test/unit/language/test_core.py::test_dot[1-64-128-128-2-False-False-none
 python/test/unit/language/test_core.py::test_dot[1-64-128-128-2-False-False-none-tf32-float32-float32-1-None1]
 python/test/unit/language/test_core.py::test_dot[1-128-128-64-4-False-False-chain-dot-ieee-float8e5-float32-1-None]
 python/test/unit/language/test_core.py::test_dot[1-128-128-64-4-False-False-chain-dot-ieee-float8e4nv-float32-1-None]
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/4630
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-max]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-min]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5--inf]
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/4289
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float16-add]
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float32-add]
```

scripts/skiplist/xe2/language.txt

Lines changed: 0 additions & 14 deletions

```diff
@@ -1,17 +1,3 @@
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/4630
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float32-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[float16-float8e5--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-max]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-min]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv--inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e4nv-nan]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5-inf]
-python/test/unit/language/test_conversions.py::test_typeconvert_downcast_clamping[bfloat16-float8e5--inf]
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/4289
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float16-add]
 python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[1-1024-host-1-float32-add]
```
