1111
1212import numpy as np
1313
14-
14+ # fmt: off
1515F8E4M3_LUT = np .array (
1616 [
17- 0.0 , 0.001953125 , 0.00390625 , 0.005859375 , 0.0078125 , 0.009765625 , 0.01171875 , 0.013671875 ,
18- 0.015625 , 0.017578125 , 0.01953125 , 0.021484375 , 0.0234375 , 0.025390625 , 0.02734375 , 0.029296875 ,
19- 0.03125 , 0.03515625 , 0.0390625 , 0.04296875 , 0.046875 , 0.05078125 , 0.0546875 , 0.05859375 ,
20- 0.0625 , 0.0703125 , 0.078125 , 0.0859375 , 0.09375 , 0.1015625 , 0.109375 , 0.1171875 ,
21- 0.125 , 0.140625 , 0.15625 , 0.171875 , 0.1875 , 0.203125 , 0.21875 , 0.234375 ,
22- 0.25 , 0.28125 , 0.3125 , 0.34375 , 0.375 , 0.40625 , 0.4375 , 0.46875 ,
23- 0.5 , 0.5625 , 0.625 , 0.6875 , 0.75 , 0.8125 , 0.875 , 0.9375 ,
24- 1.0 , 1.125 , 1.25 , 1.375 , 1.5 , 1.625 , 1.75 , 1.875 ,
25- 2.0 , 2.25 , 2.5 , 2.75 , 3.0 , 3.25 , 3.5 , 3.75 ,
26- 4.0 , 4.5 , 5.0 , 5.5 , 6.0 , 6.5 , 7.0 , 7.5 ,
27- 8.0 , 9.0 , 10.0 , 11.0 , 12.0 , 13.0 , 14.0 , 15.0 ,
28- 16.0 , 18.0 , 20.0 , 22.0 , 24.0 , 26.0 , 28.0 , 30.0 ,
29- 32.0 , 36.0 , 40.0 , 44.0 , 48.0 , 52.0 , 56.0 , 60.0 ,
30- 64.0 , 72.0 , 80.0 , 88.0 , 96.0 , 104.0 , 112.0 , 120.0 ,
31- 128.0 , 144.0 , 160.0 , 176.0 , 192.0 , 208.0 , 224.0 , 240.0 ,
32- 256.0 , 288.0 , 320.0 , 352.0 , 384.0 , 416.0 , 448.0 , np .nan ,
17+ 0.0 , 0.001953125 , 0.00390625 , 0.005859375 , 0.0078125 , 0.009765625 , 0.01171875 , 0.013671875 , # noqa
18+ 0.015625 , 0.017578125 , 0.01953125 , 0.021484375 , 0.0234375 , 0.025390625 , 0.02734375 , 0.029296875 , # noqa
19+ 0.03125 , 0.03515625 , 0.0390625 , 0.04296875 , 0.046875 , 0.05078125 , 0.0546875 , 0.05859375 , # noqa
20+ 0.0625 , 0.0703125 , 0.078125 , 0.0859375 , 0.09375 , 0.1015625 , 0.109375 , 0.1171875 , # noqa
21+ 0.125 , 0.140625 , 0.15625 , 0.171875 , 0.1875 , 0.203125 , 0.21875 , 0.234375 , # noqa
22+ 0.25 , 0.28125 , 0.3125 , 0.34375 , 0.375 , 0.40625 , 0.4375 , 0.46875 , # noqa
23+ 0.5 , 0.5625 , 0.625 , 0.6875 , 0.75 , 0.8125 , 0.875 , 0.9375 , # noqa
24+ 1.0 , 1.125 , 1.25 , 1.375 , 1.5 , 1.625 , 1.75 , 1.875 , # noqa
25+ 2.0 , 2.25 , 2.5 , 2.75 , 3.0 , 3.25 , 3.5 , 3.75 , # noqa
26+ 4.0 , 4.5 , 5.0 , 5.5 , 6.0 , 6.5 , 7.0 , 7.5 , # noqa
27+ 8.0 , 9.0 , 10.0 , 11.0 , 12.0 , 13.0 , 14.0 , 15.0 , # noqa
28+ 16.0 , 18.0 , 20.0 , 22.0 , 24.0 , 26.0 , 28.0 , 30.0 , # noqa
29+ 32.0 , 36.0 , 40.0 , 44.0 , 48.0 , 52.0 , 56.0 , 60.0 , # noqa
30+ 64.0 , 72.0 , 80.0 , 88.0 , 96.0 , 104.0 , 112.0 , 120.0 , # noqa
31+ 128.0 , 144.0 , 160.0 , 176.0 , 192.0 , 208.0 , 224.0 , 240.0 , # noqa
32+ 256.0 , 288.0 , 320.0 , 352.0 , 384.0 , 416.0 , 448.0 , np .nan , # noqa
3333 ],
3434 dtype = np .float32 ,
3535)
36+ # fmt: on
3637
3738
3839def _f16_to_f8e4m3_bits_scalar (h_bits : int ) -> int :
@@ -46,7 +47,6 @@ def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:
4647 f16_m_size = 10
4748
4849 # f8 e4m3 layout
49- f8e4m3_s_mask = 0x80
5050 f8e4m3_e_size = 4
5151 f8e4m3_e_mask = 0x78
5252 f8e4m3_e_bias = 7
@@ -57,9 +57,9 @@ def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:
5757 byte_shift = 8
5858
5959 # f8 masks in uint16 domain
60- f8_e_mask = f8e4m3_e_mask << byte_shift # 0x7800
61- f8_m_mask = f8e4m3_m_mask << byte_shift # 0x0700
62- f8_m_hidden_one_mask = 0x0800 # hidden 1 for subnormals
60+ f8_e_mask = f8e4m3_e_mask << byte_shift # 0x7800
61+ f8_m_mask = f8e4m3_m_mask << byte_shift # 0x0700
62+ f8_m_hidden_one_mask = 0x0800 # hidden 1 for subnormals
6363
6464 # rounding constants (same as C++)
6565 round_half = 0x01FF
@@ -79,7 +79,7 @@ def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:
7979
8080 if f16_e_field == f16_e_mask :
8181 # f16 NaN / Inf -> f8 NaN (no Inf)
82- f8_bits |= ( f8e4m3_e_mask | f8e4m3_m_mask )
82+ f8_bits |= f8e4m3_e_mask | f8e4m3_m_mask
8383 elif f16_e_field != 0 :
8484 # normalized f16
8585 f8_biased_exp = (f16_e_field >> f16_m_size ) - (f16_e_bias - f8e4m3_e_bias )
@@ -97,12 +97,12 @@ def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:
9797 # now set exponent & mantissa
9898 if f8_biased_exp > f8e4m3_e_max :
9999 # overflow -> NaN (no Inf)
100- f8_bits |= ( f8e4m3_e_mask | f8e4m3_m_mask )
100+ f8_bits |= f8e4m3_e_mask | f8e4m3_m_mask
101101 elif f8_biased_exp > 0 :
102102 # normalized f8
103103 exp_field = (f8_biased_exp & (f8e4m3_e_mask >> f8e4m3_m_size )) << f8e4m3_m_size
104104 f8_bits |= exp_field
105- f8_bits |= ( fractional >> byte_shift )
105+ f8_bits |= fractional >> byte_shift
106106 else :
107107 # subnormal f8
108108 fractional = f8_m_hidden_one_mask | ((inp & f16_m_mask ) << (f16_e_size - f8e4m3_e_size ))
@@ -113,11 +113,14 @@ def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:
113113
114114 fractional = 0 if f8_exp < f8_e_subnormal_min else (fractional >> (1 - f8_biased_exp ))
115115
116- if (((fractional & round_half ) == round_odd and sticky == 0 ) or
117- (fractional & round_norm ) != 0 or sticky != 0 ):
116+ if (
117+ ((fractional & round_half ) == round_odd and sticky == 0 )
118+ or (fractional & round_norm ) != 0
119+ or sticky != 0
120+ ):
118121 fractional += round_even
119122
120- f8_bits |= ( fractional >> byte_shift )
123+ f8_bits |= fractional >> byte_shift
121124 else :
122125 # f16 zero / subnormal -> sign + zero exponent/mantissa
123126 # (f8_bits already contains the sign)
0 commit comments