 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from nncf.tensor import Tensor, TensorDataType
+from nncf.tensor import functions as fns
 import numpy as np
 
 # fmt: off
 # fmt: on
 
 
-def f16_to_f8e4m3_bits_numpy(x: np.ndarray) -> np.ndarray:
+def _f16_to_f8e4m3_bits_numpy(x: Tensor) -> Tensor:
     """
-    Convert an array of f16 values (or their uint16 bit patterns) to
-    f8e4m3 bit patterns (uint8) using a fully vectorized NumPy
-    port of _f16_to_f8e4m3_bits_scalar.
+    Convert a Tensor of f16 values to f8e4m3 bit patterns (uint8).
+    Adapted from the OpenVINO C++ implementation:
+    https://github.com/openvinotoolkit/openvino/blame/master/src/core/src/type/float8_e4m3.cpp
+
+    :param x: Input tensor with float16 values.
+    :return: Tensor with uint8 values representing f8e4m3 bit patterns.
     """
+
+    def to_u16_const(val: int) -> Tensor:
+        return fns.from_numpy(np.uint16(val), backend=x.backend)
+
+    def to_u32_const(val: int) -> Tensor:
+        return fns.from_numpy(np.uint32(val), backend=x.backend)
+
+    x = x.view(TensorDataType.uint16)
+
+    u16_zero = to_u16_const(0)
+    u32_zero = to_u32_const(0)
+
     # f16 layout
-    f16_s_mask = np.uint16(0x8000)
-    f16_e_mask = np.uint16(0x7C00)
+    f16_s_mask = to_u16_const(0x8000)
+    f16_e_mask = to_u16_const(0x7C00)
     f16_e_bias = 15
     f16_e_size = 5
-    f16_m_mask = np.uint16(0x03FF)
+    f16_m_mask = to_u16_const(0x03FF)
     f16_m_size = 10
 
     # f8 e4m3 layout
     f8e4m3_e_size = 4
-    f8e4m3_e_mask = np.uint16(0x78)
+    f8e4m3_e_mask = to_u16_const(0x78)
     f8e4m3_e_bias = 7
     f8e4m3_e_max = 0x0F
     f8e4m3_m_size = 3
-    f8e4m3_m_mask = np.uint16(0x07)
+    f8e4m3_m_mask = to_u16_const(0x07)
 
     byte_shift = 8
 
     # f8 masks in uint16 domain
-    f8_e_mask = np.uint16(f8e4m3_e_mask << byte_shift)  # 0x7800
-    f8_m_mask = np.uint16(f8e4m3_m_mask << byte_shift)  # 0x0700
-    f8_m_hidden_one_mask = np.uint16(0x0800)  # hidden 1 for subnormals
+    f8_e_mask = (f8e4m3_e_mask << byte_shift).astype(TensorDataType.uint16)  # 0x7800
+    f8_m_mask = (f8e4m3_m_mask << byte_shift).astype(TensorDataType.uint16)  # 0x0700
+    f8_m_hidden_one_mask = to_u16_const(0x0800)  # hidden 1 for subnormals
 
     # rounding constants
-    round_half = np.uint16(0x01FF)
-    round_norm = np.uint16(0x007F)
-    round_even = np.uint16(0x0080)
-    round_odd = np.uint16(0x0180)
+    round_half = to_u16_const(0x01FF)
+    round_norm = to_u16_const(0x007F)
+    round_even = to_u16_const(0x0080)
+    round_odd = to_u16_const(0x0180)
 
     # min exponent for which subnormals are representable
     f8_e_subnormal_min = -10
 
     # sign bit: f16 sign -> f8 sign position (bit 15 -> bit 7)
-    f8_bits = ((x & f16_s_mask) >> byte_shift).astype(np.uint16)
+    f8_bits = ((x & f16_s_mask) >> byte_shift).astype(TensorDataType.uint16)
 
     f16_e_field = x & f16_e_mask
     is_naninf = f16_e_field == f16_e_mask
-    is_zero = f16_e_field == 0
+    is_zero = f16_e_field == u16_zero
     is_normal = (~is_naninf) & (~is_zero)
 
-    nan_pattern = np.uint16(f8e4m3_e_mask | f8e4m3_m_mask)
+    nan_pattern = (f8e4m3_e_mask | f8e4m3_m_mask).astype(TensorDataType.uint16)
 
     # --- Case 1: f16 NaN / Inf -> f8 NaN (no Inf) ---
-    f8_bits = np.where(is_naninf, f8_bits | nan_pattern, f8_bits)
+    f8_bits = fns.where(is_naninf, f8_bits | nan_pattern, f8_bits)
 
     # --- Case 2: normalized f16 ---
     # f8_biased_exp = (f16_e_field >> f16_m_size) - (f16_e_bias - f8e4m3_e_bias)
-    f8_biased_exp = (f16_e_field >> f16_m_size).astype(np.int32) - (f16_e_bias - f8e4m3_e_bias)
+    f8_biased_exp = (f16_e_field >> f16_m_size).astype(TensorDataType.int32) - (f16_e_bias - f8e4m3_e_bias)
 
     # fractional = (inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size)
-    fractional_norm = ((x & f16_m_mask) << (f16_e_size - f8e4m3_e_size)).astype(np.uint16)
+    fractional_norm = ((x & f16_m_mask) << (f16_e_size - f8e4m3_e_size)).astype(TensorDataType.uint16)
 
     exp_ge0 = (f8_biased_exp >= 0) & is_normal
 
     # Rounding for normalized part (exp >= 0)
     # if (fractional & round_half) == round_odd or (fractional & round_norm) != 0:
-    cond_round_norm = (((fractional_norm & round_half) == round_odd) | ((fractional_norm & round_norm) != 0)) & exp_ge0
+    cond_round_norm = (
+        ((fractional_norm & round_half) == round_odd) |
+        ((fractional_norm & round_norm) != 0)
+    ) & exp_ge0
 
     # fractional += round_even where cond_round_norm
-    frac_tmp = fractional_norm.astype(np.uint32) + np.where(cond_round_norm, round_even, np.uint16(0)).astype(np.uint32)
-    fractional_norm = (frac_tmp & 0xFFFF).astype(np.uint16)
+    frac_tmp = fractional_norm.astype(TensorDataType.uint32) + \
+        fns.where(cond_round_norm, round_even, u16_zero).astype(TensorDataType.uint32)
+    fractional_norm = (frac_tmp & 0xFFFF).astype(TensorDataType.uint16)
 
     # if (fractional & f8_e_mask) != 0: f8_biased_exp += 1
-    exp_inc = np.where(exp_ge0 & ((fractional_norm & f8_e_mask) != 0), 1, 0).astype(np.int32)
+    exp_inc = fns.where(exp_ge0 & ((fractional_norm & f8_e_mask) != 0), 1, 0).astype(TensorDataType.int32)
     f8_biased_exp_after = f8_biased_exp + exp_inc
 
     # fractional &= f8_m_mask
@@ -119,77 +140,90 @@ def f16_to_f8e4m3_bits_numpy(x: np.ndarray) -> np.ndarray:
     subnormal_mask = is_normal & (f8_biased_exp_after <= 0) & (~overflow_mask)
 
     # --- Overflow -> NaN ---
-    f8_bits = np.where(overflow_mask, f8_bits | nan_pattern, f8_bits)
+    f8_bits = fns.where(overflow_mask, f8_bits | nan_pattern, f8_bits)
 
     # --- Normalized f8 ---
     # exp_field = (f8_biased_exp & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size
-    exp_field = ((f8_biased_exp_after & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size).astype(np.uint16)
-    mant_norm = (fractional_norm >> byte_shift).astype(np.uint16)
+    exp_field = ((f8_biased_exp_after & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size).astype(TensorDataType.uint16)
+    mant_norm = (fractional_norm >> byte_shift).astype(TensorDataType.uint16)
 
     f8_bits_norm = f8_bits | exp_field | mant_norm
-    f8_bits = np.where(normal_mask, f8_bits_norm, f8_bits)
+    f8_bits = fns.where(normal_mask, f8_bits_norm, f8_bits)
 
     # --- Subnormal f8 ---
     # fractional = f8_m_hidden_one_mask | ((inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size))
     fractional_sub = f8_m_hidden_one_mask | ((x & f16_m_mask) << (f16_e_size - f8e4m3_e_size))
 
     # f8_exp = f8_biased_exp - f8e4m3_e_bias
-    f8_exp = (f8_biased_exp_after - f8e4m3_e_bias).astype(np.int32)
+    f8_exp = (f8_biased_exp_after - f8e4m3_e_bias).astype(TensorDataType.int32)
 
     # shift = 1 - f8_exp
     shift = 1 - f8_exp
 
     # sticky_mask = 0 if f8_exp < f8_e_subnormal_min else ((1 << shift) - 1)
     # we avoid invalid shifts by clipping / masking
     valid_sub = f8_exp >= f8_e_subnormal_min
-    shift_pos = np.maximum(shift, 0)
-    sticky_mask32 = np.where(valid_sub, (np.uint32(1) << shift_pos) - 1, 0).astype(np.uint32)
-    sticky_mask16 = (sticky_mask32 & np.uint32(0xFFFF)).astype(np.uint16)
+    shift_pos = fns.maximum(shift, 0)
+
+    one_u32 = to_u32_const(1)
+    mask_u32_full = to_u32_const(0xFFFF)
+
+    sticky_mask32 = fns.where(
+        valid_sub,
+        (one_u32 << shift_pos) - 1,
+        u32_zero,
+    ).astype(TensorDataType.uint32)
+    sticky_mask16 = (sticky_mask32 & mask_u32_full).astype(TensorDataType.uint16)
 
     # sticky = 1 if (fractional & sticky_mask) != 0 else 0
     sticky = ((fractional_sub & sticky_mask16) != 0) & valid_sub
 
     # fractional = 0 if f8_exp < f8_e_subnormal_min else (fractional >> (1 - f8_biased_exp))
     shift2 = 1 - f8_biased_exp_after
-    shift2_pos = np.maximum(shift2, 0)
-    frac_shifted = (fractional_sub.astype(np.uint32) >> shift2_pos).astype(np.uint16)
-    frac_shifted = np.where(valid_sub, frac_shifted, np.uint16(0))
+    shift2_pos = fns.maximum(shift2, 0)
+    frac_shifted = (fractional_sub.astype(TensorDataType.uint32) >> shift2_pos).astype(TensorDataType.uint16)
+    frac_shifted = fns.where(valid_sub, frac_shifted, u16_zero)
 
     # Rounding for subnormal:
     # if (((fractional & round_half) == round_odd and sticky == 0)
     #     or (fractional & round_norm) != 0
     #     or sticky != 0):
     cond_round_sub = (
-        (((frac_shifted & round_half) == round_odd) & (~sticky)) | ((frac_shifted & round_norm) != 0) | sticky
+        (((frac_shifted & round_half) == round_odd) & (~sticky)) |
+        ((frac_shifted & round_norm) != 0) |
+        sticky
     ) & subnormal_mask
 
-    frac_tmp_sub = frac_shifted.astype(np.uint32) + np.where(cond_round_sub, round_even, np.uint16(0)).astype(np.uint32)
-    fractional_sub_final = (frac_tmp_sub & 0xFFFF).astype(np.uint16)
+    frac_tmp_sub = frac_shifted.astype(TensorDataType.uint32) + \
+        fns.where(cond_round_sub, round_even, u16_zero).astype(TensorDataType.uint32)
+    fractional_sub_final = (frac_tmp_sub & 0xFFFF).astype(TensorDataType.uint16)
 
-    mant_sub = (fractional_sub_final >> byte_shift).astype(np.uint16)
-    f8_bits = np.where(subnormal_mask, f8_bits | mant_sub, f8_bits)
+    mant_sub = (fractional_sub_final >> byte_shift).astype(TensorDataType.uint16)
+    f8_bits = fns.where(subnormal_mask, f8_bits | mant_sub, f8_bits)
 
     # Case: f16 zero / subnormal -> sign + zero exponent/mantissa
     # Already handled by initialization + not touching zero_mask entries.
 
-    return (f8_bits & np.uint16(0x00FF)).astype(np.uint8)
+    return (f8_bits & to_u16_const(0x00FF)).astype(TensorDataType.uint8)
 
 
-def fp32_to_fp8e4m3(x: np.ndarray) -> np.ndarray:
-    """
-    Bit-exact to ov::float8_e4m3(float):
-    float32 -> float16 -> f8e4m3 bits -> float via LUT
-    """
-    x = np.asarray(x, dtype=np.float32)
-    x_f16 = x.astype(np.float16)
-    h_bits = x_f16.view(np.uint16)
 
-    f8_bits = f16_to_f8e4m3_bits_numpy(h_bits)
 
-    # Decode exactly like C++: LUT for magnitude + sign bit
-    idx = f8_bits & 0x7F
-    mag = F8E4M3_LUT[idx.astype(np.int32)]
+def fp32_to_fp8e4m3(x: Tensor) -> Tensor:
+    """
+    Convert float32 to float8 e4m3 via float16.
+    Adapted from the OpenVINO C++ implementation:
+    https://github.com/openvinotoolkit/openvino/blame/master/src/core/src/type/float8_e4m3.cpp
 
-    sign = np.where((f8_bits & 0x80) != 0, -1.0, 1.0)
-    out = sign * mag
-    return out.astype(np.float32)
+    :param x: Input tensor with float32 values.
+    :return: Tensor with float8 e4m3 values stored in a float32 tensor.
+    """
+    x_f16 = x.astype(TensorDataType.float16)
+    f8_bits = _f16_to_f8e4m3_bits_numpy(x_f16)
+
+    indexes = f8_bits & 0x7F
+    look_up_table = fns.from_numpy(F8E4M3_LUT, backend=x.backend)
+    magnitude = look_up_table[indexes.astype(TensorDataType.int32)]
+    sign = fns.where((f8_bits & 0x80) != 0, -1.0, 1.0)
+    result = sign * magnitude
+    return result.astype(TensorDataType.float32)
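
# Worked example of the e4m3 bit layout the conversion above targets
# (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits; exponent and
# mantissa all ones encodes NaN, and there is no Inf). Illustrative note,
# not part of the diff:
#
#   1.0   = 2^0 * 1.000 -> sign 0, exp 0+7=7  (0b0111), mant 0b000 -> 0x38
#   448.0 = 2^8 * 1.750 -> sign 0, exp 8+7=15 (0b1111), mant 0b110 -> 0x7E
#           (largest finite e4m3 value; larger magnitudes hit overflow_mask
#           in the code above and are emitted as the NaN pattern 0x7F)
#   -1.0  -> same as 1.0 with the sign bit set -> 0xB8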
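
# Usage sketch (illustrative, not part of the diff), assuming
# nncf.tensor.Tensor can wrap a NumPy array directly:
#
#   import numpy as np
#   from nncf.tensor import Tensor
#
#   x = Tensor(np.array([0.07, 1.0, 448.0, 500.0], dtype=np.float32))
#   y = fp32_to_fp8e4m3(x)
#   # y holds the inputs snapped to the e4m3 grid, still stored as float32;
#   # 448.0 survives unchanged, while 500.0 overflows and comes back as NaN.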