 # fmt: on
 
 
-def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:
-    """Exact port of ov::f16_to_f8e4m3_bits for a single float16 bit-pattern."""
+def f16_to_f8e4m3_bits_numpy(x: np.ndarray) -> np.ndarray:
+    """
+    Convert an array of f16 bit patterns (uint16) to f8e4m3 bit
+    patterns (uint8) using a fully vectorized NumPy port of
+    _f16_to_f8e4m3_bits_scalar.
+    """
     # f16 layout
-    f16_s_mask = 0x8000
-    f16_e_mask = 0x7C00
+    f16_s_mask = np.uint16(0x8000)
+    f16_e_mask = np.uint16(0x7C00)
     f16_e_bias = 15
     f16_e_size = 5
-    f16_m_mask = 0x03FF
+    f16_m_mask = np.uint16(0x03FF)
     f16_m_size = 10
 
     # f8 e4m3 layout
     f8e4m3_e_size = 4
-    f8e4m3_e_mask = 0x78
+    f8e4m3_e_mask = np.uint16(0x78)
     f8e4m3_e_bias = 7
     f8e4m3_e_max = 0x0F
     f8e4m3_m_size = 3
-    f8e4m3_m_mask = 0x07
+    f8e4m3_m_mask = np.uint16(0x07)
 
     byte_shift = 8
 
     # f8 masks in uint16 domain
-    f8_e_mask = f8e4m3_e_mask << byte_shift  # 0x7800
-    f8_m_mask = f8e4m3_m_mask << byte_shift  # 0x0700
-    f8_m_hidden_one_mask = 0x0800  # hidden 1 for subnormals
+    f8_e_mask = np.uint16(f8e4m3_e_mask << byte_shift)  # 0x7800
+    f8_m_mask = np.uint16(f8e4m3_m_mask << byte_shift)  # 0x0700
+    f8_m_hidden_one_mask = np.uint16(0x0800)  # hidden 1 for subnormals
 
-    # rounding constants (same as C++)
-    round_half = 0x01FF
-    round_norm = 0x007F
-    round_even = 0x0080
-    round_odd = 0x0180
+    # rounding constants
+    round_half = np.uint16(0x01FF)
+    round_norm = np.uint16(0x007F)
+    round_even = np.uint16(0x0080)
+    round_odd = np.uint16(0x0180)
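+    # Together these implement round-half-to-even in the uint16 domain:
+    # the kept f8 mantissa sits in bits 8..10, bit 7 (round_even) is the
+    # half/tie bit, and adding round_even carries into bit 8 only when
+    # bit 7 is already set; (frac & round_half) == round_odd detects an
+    # exact tie with an odd mantissa LSB, round_norm any bits below the
+    # tie bit.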
 
     # min exponent for which subnormals are representable
     f8_e_subnormal_min = -10
 
-    inp = int(h_bits) & 0xFFFF
-
     # sign bit: f16 sign -> f8 sign position (bit 15 -> bit 7)
-    f8_bits = (inp & f16_s_mask) >> byte_shift
-
-    f16_e_field = inp & f16_e_mask
-
-    if f16_e_field == f16_e_mask:
-        # f16 NaN / Inf -> f8 NaN (no Inf)
-        f8_bits |= f8e4m3_e_mask | f8e4m3_m_mask
-    elif f16_e_field != 0:
-        # normalized f16
-        f8_biased_exp = (f16_e_field >> f16_m_size) - (f16_e_bias - f8e4m3_e_bias)
-        # *** IMPORTANT FIX: shift by (f16_e_size - f8e4m3_e_size) = 5 - 4 = 1 ***
-        fractional = (inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size)
-
-        # normalized f8 part (exp >= 0)
-        if f8_biased_exp >= 0:
-            if (fractional & round_half) == round_odd or (fractional & round_norm) != 0:
-                fractional += round_even
-            if (fractional & f8_e_mask) != 0:
-                f8_biased_exp += 1
-            fractional &= f8_m_mask
-
-        # now set exponent & mantissa
-        if f8_biased_exp > f8e4m3_e_max:
-            # overflow -> NaN (no Inf)
-            f8_bits |= f8e4m3_e_mask | f8e4m3_m_mask
-        elif f8_biased_exp > 0:
-            # normalized f8
-            exp_field = (f8_biased_exp & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size
-            f8_bits |= exp_field
-            f8_bits |= fractional >> byte_shift
-        else:
-            # subnormal f8
-            fractional = f8_m_hidden_one_mask | ((inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size))
-            f8_exp = f8_biased_exp - f8e4m3_e_bias
-            shift = 1 - f8_exp
-            sticky_mask = 0 if f8_exp < f8_e_subnormal_min else ((1 << shift) - 1)
-            sticky = 1 if (fractional & sticky_mask) != 0 else 0
-
-            fractional = 0 if f8_exp < f8_e_subnormal_min else (fractional >> (1 - f8_biased_exp))
-
-            if (
-                ((fractional & round_half) == round_odd and sticky == 0)
-                or (fractional & round_norm) != 0
-                or sticky != 0
-            ):
-                fractional += round_even
-
-            f8_bits |= fractional >> byte_shift
-    else:
-        # f16 zero / subnormal -> sign + zero exponent/mantissa
-        # (f8_bits already contains the sign)
-        pass
-
-    return f8_bits & 0xFF
-
-
-_f16_to_f8e4m3_bits_vec = np.vectorize(_f16_to_f8e4m3_bits_scalar, otypes=[np.uint8])
-
-
-def fp32_to_fp8e4m3_values(x: np.ndarray) -> np.ndarray:
+    f8_bits = ((x & f16_s_mask) >> byte_shift).astype(np.uint16)
+
+    f16_e_field = x & f16_e_mask
+    is_naninf = f16_e_field == f16_e_mask
+    is_zero = f16_e_field == 0
+    is_normal = (~is_naninf) & (~is_zero)
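+    # is_naninf / is_zero / is_normal partition every lane exactly once;
+    # each case below writes only its own lanes via np.where and leaves
+    # the rest untouched.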
+
+    nan_pattern = np.uint16(f8e4m3_e_mask | f8e4m3_m_mask)
+
+    # --- Case 1: f16 NaN / Inf -> f8 NaN (no Inf) ---
+    f8_bits = np.where(is_naninf, f8_bits | nan_pattern, f8_bits)
+
+    # --- Case 2: normalized f16 ---
+    # f8_biased_exp = (f16_e_field >> f16_m_size) - (f16_e_bias - f8e4m3_e_bias)
+    f8_biased_exp = (f16_e_field >> f16_m_size).astype(np.int32) - (f16_e_bias - f8e4m3_e_bias)
+
+    # fractional = (inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size)
+    fractional_norm = ((x & f16_m_mask) << (f16_e_size - f8e4m3_e_size)).astype(np.uint16)
+
+    exp_ge0 = (f8_biased_exp >= 0) & is_normal
+
+    # Rounding for normalized part (exp >= 0)
+    # if (fractional & round_half) == round_odd or (fractional & round_norm) != 0:
+    cond_round_norm = (((fractional_norm & round_half) == round_odd) | ((fractional_norm & round_norm) != 0)) & exp_ge0
+
+    # fractional += round_even where cond_round_norm
+    frac_tmp = fractional_norm.astype(np.uint32) + np.where(cond_round_norm, round_even, np.uint16(0)).astype(np.uint32)
+    fractional_norm = (frac_tmp & 0xFFFF).astype(np.uint16)
+
+    # if (fractional & f8_e_mask) != 0: f8_biased_exp += 1
+    exp_inc = np.where(exp_ge0 & ((fractional_norm & f8_e_mask) != 0), 1, 0).astype(np.int32)
+    f8_biased_exp_after = f8_biased_exp + exp_inc
+
+    # fractional &= f8_m_mask
+    fractional_norm &= f8_m_mask
+
+    # Overflow / normalized / subnormal classification
+    overflow_mask = is_normal & (f8_biased_exp_after > f8e4m3_e_max)
+    normal_mask = is_normal & (f8_biased_exp_after > 0) & (~overflow_mask)
+    # For subnormals the scalar code uses f8_biased_exp after a possible increment,
+    # but the increment is only applied when exp >= 0, so the exp <= 0 path is unchanged.
+    subnormal_mask = is_normal & (f8_biased_exp_after <= 0) & (~overflow_mask)
+
+    # --- Overflow -> NaN ---
+    f8_bits = np.where(overflow_mask, f8_bits | nan_pattern, f8_bits)
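+    # e.g. the largest finite f8e4m3 magnitude is 448.0, so any lane that
+    # rounds past it lands here as NaN (the format has no Inf encoding)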
+
+    # --- Normalized f8 ---
+    # exp_field = (f8_biased_exp & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size
+    exp_field = ((f8_biased_exp_after & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size).astype(np.uint16)
+    mant_norm = (fractional_norm >> byte_shift).astype(np.uint16)
+
+    f8_bits_norm = f8_bits | exp_field | mant_norm
+    f8_bits = np.where(normal_mask, f8_bits_norm, f8_bits)
+
+    # --- Subnormal f8 ---
+    # fractional = f8_m_hidden_one_mask | ((inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size))
+    fractional_sub = f8_m_hidden_one_mask | ((x & f16_m_mask) << (f16_e_size - f8e4m3_e_size))
+
+    # f8_exp = f8_biased_exp - f8e4m3_e_bias
+    f8_exp = (f8_biased_exp_after - f8e4m3_e_bias).astype(np.int32)
+
+    # shift = 1 - f8_exp
+    shift = 1 - f8_exp
+
+    # sticky_mask = 0 if f8_exp < f8_e_subnormal_min else ((1 << shift) - 1)
+    # we avoid invalid shifts by clipping / masking
+    valid_sub = f8_exp >= f8_e_subnormal_min
+    shift_pos = np.maximum(shift, 0)
+    sticky_mask32 = np.where(valid_sub, (np.uint32(1) << shift_pos) - 1, 0).astype(np.uint32)
+    sticky_mask16 = (sticky_mask32 & np.uint32(0xFFFF)).astype(np.uint16)
+
+    # sticky = 1 if (fractional & sticky_mask) != 0 else 0
+    sticky = ((fractional_sub & sticky_mask16) != 0) & valid_sub
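+    # sticky records low-order ones that the shift below pushes at or
+    # below the tie position; it matters only for tie disambiguation,
+    # since adding round_even carries into the mantissa only when the
+    # half bit is already set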
+
+    # fractional = 0 if f8_exp < f8_e_subnormal_min else (fractional >> (1 - f8_biased_exp))
+    shift2 = 1 - f8_biased_exp_after
+    shift2_pos = np.maximum(shift2, 0)
+    frac_shifted = (fractional_sub.astype(np.uint32) >> shift2_pos).astype(np.uint16)
+    frac_shifted = np.where(valid_sub, frac_shifted, np.uint16(0))
+
+    # Rounding for subnormal:
+    # if (((fractional & round_half) == round_odd and sticky == 0)
+    #     or (fractional & round_norm) != 0
+    #     or sticky != 0):
+    cond_round_sub = (
+        (((frac_shifted & round_half) == round_odd) & (~sticky)) | ((frac_shifted & round_norm) != 0) | sticky
+    ) & subnormal_mask
+
+    frac_tmp_sub = frac_shifted.astype(np.uint32) + np.where(cond_round_sub, round_even, np.uint16(0)).astype(np.uint32)
+    fractional_sub_final = (frac_tmp_sub & 0xFFFF).astype(np.uint16)
+
+    mant_sub = (fractional_sub_final >> byte_shift).astype(np.uint16)
+    f8_bits = np.where(subnormal_mask, f8_bits | mant_sub, f8_bits)
+
+    # Case: f16 zero / subnormal -> sign + zero exponent/mantissa.
+    # Already handled by initialization: is_zero entries are never touched.
+
+    return (f8_bits & np.uint16(0x00FF)).astype(np.uint8)
+
+
+def fp32_to_fp8e4m3(x: np.ndarray) -> np.ndarray:
     """
     Bit-exact to ov::float8_e4m3(float):
     float32 -> float16 -> f8e4m3 bits -> float via LUT
@@ -141,7 +184,7 @@ def fp32_to_fp8e4m3_values(x: np.ndarray) -> np.ndarray:
     x_f16 = x.astype(np.float16)
     h_bits = x_f16.view(np.uint16)
 
-    f8_bits = _f16_to_f8e4m3_bits_vec(h_bits)
+    f8_bits = f16_to_f8e4m3_bits_numpy(h_bits)
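+    # one vectorized pass over the whole array, replacing the former
+    # per-element np.vectorize loop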
 
     # Decode exactly like C++: LUT for magnitude + sign bit
     idx = f8_bits & 0x7F
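
Since the port is meant to be bit-exact, the cheapest regression test is an
exhaustive sweep over all 65536 f16 bit patterns, comparing against the scalar
reference. A minimal sketch, assuming a copy of the removed
_f16_to_f8e4m3_bits_scalar is kept alongside the new function:

    import numpy as np

    all_bits = np.arange(0x10000, dtype=np.uint16)  # every possible f16 pattern
    vec = f16_to_f8e4m3_bits_numpy(all_bits)
    ref = np.array([_f16_to_f8e4m3_bits_scalar(int(b)) for b in all_bits], dtype=np.uint8)
    assert np.array_equal(vec, ref)  # all NaNs map to one payload, so plain equality works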