
Commit 0266a3e
Add pure numpy conversion
1 parent: 948fad9

File tree: 2 files changed, +123 -80 lines

src/nncf/quantization/algorithms/weight_compression/fp8_conversion.py
Lines changed: 121 additions & 78 deletions

@@ -36,103 +36,146 @@
 # fmt: on


-def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:
-    """Exact port of ov::f16_to_f8e4m3_bits for a single float16 bit-pattern."""
+def f16_to_f8e4m3_bits_numpy(x: np.ndarray) -> np.ndarray:
+    """
+    Convert an array of f16 values (or their uint16 bit patterns) to
+    f8e4m3 bit patterns (uint8) using a fully vectorized NumPy
+    port of _f16_to_f8e4m3_bits_scalar.
+    """
     # f16 layout
-    f16_s_mask = 0x8000
-    f16_e_mask = 0x7C00
+    f16_s_mask = np.uint16(0x8000)
+    f16_e_mask = np.uint16(0x7C00)
     f16_e_bias = 15
     f16_e_size = 5
-    f16_m_mask = 0x03FF
+    f16_m_mask = np.uint16(0x03FF)
     f16_m_size = 10

     # f8 e4m3 layout
     f8e4m3_e_size = 4
-    f8e4m3_e_mask = 0x78
+    f8e4m3_e_mask = np.uint16(0x78)
     f8e4m3_e_bias = 7
     f8e4m3_e_max = 0x0F
     f8e4m3_m_size = 3
-    f8e4m3_m_mask = 0x07
+    f8e4m3_m_mask = np.uint16(0x07)

     byte_shift = 8

     # f8 masks in uint16 domain
-    f8_e_mask = f8e4m3_e_mask << byte_shift  # 0x7800
-    f8_m_mask = f8e4m3_m_mask << byte_shift  # 0x0700
-    f8_m_hidden_one_mask = 0x0800  # hidden 1 for subnormals
+    f8_e_mask = np.uint16(f8e4m3_e_mask << byte_shift)  # 0x7800
+    f8_m_mask = np.uint16(f8e4m3_m_mask << byte_shift)  # 0x0700
+    f8_m_hidden_one_mask = np.uint16(0x0800)  # hidden 1 for subnormals

-    # rounding constants (same as C++)
-    round_half = 0x01FF
-    round_norm = 0x007F
-    round_even = 0x0080
-    round_odd = 0x0180
+    # rounding constants
+    round_half = np.uint16(0x01FF)
+    round_norm = np.uint16(0x007F)
+    round_even = np.uint16(0x0080)
+    round_odd = np.uint16(0x0180)

     # min exponent for which subnormals are representable
     f8_e_subnormal_min = -10

-    inp = int(h_bits) & 0xFFFF
-
     # sign bit: f16 sign -> f8 sign position (bit 15 -> bit 7)
-    f8_bits = (inp & f16_s_mask) >> byte_shift
-
-    f16_e_field = inp & f16_e_mask
-
-    if f16_e_field == f16_e_mask:
-        # f16 NaN / Inf -> f8 NaN (no Inf)
-        f8_bits |= f8e4m3_e_mask | f8e4m3_m_mask
-    elif f16_e_field != 0:
-        # normalized f16
-        f8_biased_exp = (f16_e_field >> f16_m_size) - (f16_e_bias - f8e4m3_e_bias)
-        # *** IMPORTANT FIX: shift by (f16_e_size - f8e4m3_e_size) = 5 - 4 = 1 ***
-        fractional = (inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size)
-
-        # normalized f8 part (exp >= 0)
-        if f8_biased_exp >= 0:
-            if (fractional & round_half) == round_odd or (fractional & round_norm) != 0:
-                fractional += round_even
-            if (fractional & f8_e_mask) != 0:
-                f8_biased_exp += 1
-            fractional &= f8_m_mask
-
-        # now set exponent & mantissa
-        if f8_biased_exp > f8e4m3_e_max:
-            # overflow -> NaN (no Inf)
-            f8_bits |= f8e4m3_e_mask | f8e4m3_m_mask
-        elif f8_biased_exp > 0:
-            # normalized f8
-            exp_field = (f8_biased_exp & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size
-            f8_bits |= exp_field
-            f8_bits |= fractional >> byte_shift
-        else:
-            # subnormal f8
-            fractional = f8_m_hidden_one_mask | ((inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size))
-            f8_exp = f8_biased_exp - f8e4m3_e_bias
-            shift = 1 - f8_exp
-            sticky_mask = 0 if f8_exp < f8_e_subnormal_min else ((1 << shift) - 1)
-            sticky = 1 if (fractional & sticky_mask) != 0 else 0
-
-            fractional = 0 if f8_exp < f8_e_subnormal_min else (fractional >> (1 - f8_biased_exp))
-
-            if (
-                ((fractional & round_half) == round_odd and sticky == 0)
-                or (fractional & round_norm) != 0
-                or sticky != 0
-            ):
-                fractional += round_even
-
-            f8_bits |= fractional >> byte_shift
-    else:
-        # f16 zero / subnormal -> sign + zero exponent/mantissa
-        # (f8_bits already contains the sign)
-        pass
-
-    return f8_bits & 0xFF
-
-
-_f16_to_f8e4m3_bits_vec = np.vectorize(_f16_to_f8e4m3_bits_scalar, otypes=[np.uint8])
-
-
-def fp32_to_fp8e4m3_values(x: np.ndarray) -> np.ndarray:
+    f8_bits = ((x & f16_s_mask) >> byte_shift).astype(np.uint16)
+
+    f16_e_field = x & f16_e_mask
+    is_naninf = f16_e_field == f16_e_mask
+    is_zero = f16_e_field == 0
+    is_normal = (~is_naninf) & (~is_zero)
+
+    nan_pattern = np.uint16(f8e4m3_e_mask | f8e4m3_m_mask)
+
+    # --- Case 1: f16 NaN / Inf -> f8 NaN (no Inf) ---
+    f8_bits = np.where(is_naninf, f8_bits | nan_pattern, f8_bits)
+
+    # --- Case 2: normalized f16 ---
+    # f8_biased_exp = (f16_e_field >> f16_m_size) - (f16_e_bias - f8e4m3_e_bias)
+    f8_biased_exp = (f16_e_field >> f16_m_size).astype(np.int32) - (f16_e_bias - f8e4m3_e_bias)
+
+    # fractional = (inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size)
+    fractional_norm = ((x & f16_m_mask) << (f16_e_size - f8e4m3_e_size)).astype(np.uint16)
+
+    exp_ge0 = (f8_biased_exp >= 0) & is_normal
+
+    # Rounding for normalized part (exp >= 0)
+    # if (fractional & round_half) == round_odd or (fractional & round_norm) != 0:
+    cond_round_norm = (((fractional_norm & round_half) == round_odd) | ((fractional_norm & round_norm) != 0)) & exp_ge0
+
+    # fractional += round_even where cond_round_norm
+    frac_tmp = fractional_norm.astype(np.uint32) + np.where(cond_round_norm, round_even, np.uint16(0)).astype(np.uint32)
+    fractional_norm = (frac_tmp & 0xFFFF).astype(np.uint16)
+
+    # if (fractional & f8_e_mask) != 0: f8_biased_exp += 1
+    exp_inc = np.where(exp_ge0 & ((fractional_norm & f8_e_mask) != 0), 1, 0).astype(np.int32)
+    f8_biased_exp_after = f8_biased_exp + exp_inc
+
+    # fractional &= f8_m_mask
+    fractional_norm &= f8_m_mask
+
+    # Overflow / normalized / subnormal classification
+    overflow_mask = is_normal & (f8_biased_exp_after > f8e4m3_e_max)
+    normal_mask = is_normal & (f8_biased_exp_after > 0) & (~overflow_mask)
+    # For subnormals, the scalar code uses f8_biased_exp (after a possible increment),
+    # but the increment is only applied when exp >= 0, so the exp <= 0 path is unchanged.
+    subnormal_mask = is_normal & (f8_biased_exp_after <= 0) & (~overflow_mask)
+
+    # --- Overflow -> NaN ---
+    f8_bits = np.where(overflow_mask, f8_bits | nan_pattern, f8_bits)
+
+    # --- Normalized f8 ---
+    # exp_field = (f8_biased_exp & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size
+    exp_field = ((f8_biased_exp_after & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size).astype(np.uint16)
+    mant_norm = (fractional_norm >> byte_shift).astype(np.uint16)
+
+    f8_bits_norm = f8_bits | exp_field | mant_norm
+    f8_bits = np.where(normal_mask, f8_bits_norm, f8_bits)
+
+    # --- Subnormal f8 ---
+    # fractional = f8_m_hidden_one_mask | ((inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size))
+    fractional_sub = f8_m_hidden_one_mask | ((x & f16_m_mask) << (f16_e_size - f8e4m3_e_size))
+
+    # f8_exp = f8_biased_exp - f8e4m3_e_bias
+    f8_exp = (f8_biased_exp_after - f8e4m3_e_bias).astype(np.int32)
+
+    # shift = 1 - f8_exp
+    shift = 1 - f8_exp
+
+    # sticky_mask = 0 if f8_exp < f8_e_subnormal_min else ((1 << shift) - 1)
+    # we avoid invalid shifts by clipping / masking
+    valid_sub = f8_exp >= f8_e_subnormal_min
+    shift_pos = np.maximum(shift, 0)
+    sticky_mask32 = np.where(valid_sub, (np.uint32(1) << shift_pos) - 1, 0).astype(np.uint32)
+    sticky_mask16 = (sticky_mask32 & np.uint32(0xFFFF)).astype(np.uint16)
+
+    # sticky = 1 if (fractional & sticky_mask) != 0 else 0
+    sticky = ((fractional_sub & sticky_mask16) != 0) & valid_sub
+
+    # fractional = 0 if f8_exp < f8_e_subnormal_min else (fractional >> (1 - f8_biased_exp))
+    shift2 = 1 - f8_biased_exp_after
+    shift2_pos = np.maximum(shift2, 0)
+    frac_shifted = (fractional_sub.astype(np.uint32) >> shift2_pos).astype(np.uint16)
+    frac_shifted = np.where(valid_sub, frac_shifted, np.uint16(0))
+
+    # Rounding for subnormal:
+    # if (((fractional & round_half) == round_odd and sticky == 0)
+    #     or (fractional & round_norm) != 0
+    #     or sticky != 0):
+    cond_round_sub = (
+        (((frac_shifted & round_half) == round_odd) & (~sticky)) | ((frac_shifted & round_norm) != 0) | sticky
+    ) & subnormal_mask
+
+    frac_tmp_sub = frac_shifted.astype(np.uint32) + np.where(cond_round_sub, round_even, np.uint16(0)).astype(np.uint32)
+    fractional_sub_final = (frac_tmp_sub & 0xFFFF).astype(np.uint16)
+
+    mant_sub = (fractional_sub_final >> byte_shift).astype(np.uint16)
+    f8_bits = np.where(subnormal_mask, f8_bits | mant_sub, f8_bits)
+
+    # Case 3: f16 zero / subnormal -> sign + zero exponent/mantissa.
+    # Already handled by the initialization; is_zero entries are never touched again.
+
+    return (f8_bits & np.uint16(0x00FF)).astype(np.uint8)
+
+
+def fp32_to_fp8e4m3(x: np.ndarray) -> np.ndarray:
     """
     Bit-exact to ov::float8_e4m3(float):
     float32 -> float16 -> f8e4m3 bits -> float via LUT
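
A quick sanity check of the new vectorized converter (an added sketch, not part of the commit). The expected bit patterns below were worked out by hand from the algorithm in the hunk above; since f8e4m3 has no Inf, f16 Inf/NaN and overflow all collapse to the single NaN pattern 0x7F:

import numpy as np

from nncf.quantization.algorithms.weight_compression.fp8_conversion import f16_to_f8e4m3_bits_numpy

# Sample values: zeros keep only the sign bit, 448.0 is the e4m3 maximum,
# and 1e4 overflows the e4m3 range.
x = np.array([0.0, -0.0, 0.5, 1.0, 448.0, 1e4, np.inf, np.nan], dtype=np.float32)
h_bits = x.astype(np.float16).view(np.uint16)  # uint16 bit patterns, as the function expects

f8_bits = f16_to_f8e4m3_bits_numpy(h_bits)

# 0x30 = 0.5, 0x38 = 1.0, 0x7E = 448.0, 0x7F = NaN
expected = np.array([0x00, 0x80, 0x30, 0x38, 0x7E, 0x7F, 0x7F, 0x7F], dtype=np.uint8)
assert np.array_equal(f8_bits, expected)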
@@ -141,7 +184,7 @@ def fp32_to_fp8e4m3_values(x: np.ndarray) -> np.ndarray:
     x_f16 = x.astype(np.float16)
     h_bits = x_f16.view(np.uint16)

-    f8_bits = _f16_to_f8e4m3_bits_vec(h_bits)
+    f8_bits = f16_to_f8e4m3_bits_numpy(h_bits)

     # Decode exactly like C++: LUT for magnitude + sign bit
     idx = f8_bits & 0x7F
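
The decode step above indexes a 128-entry magnitude LUT defined earlier in fp8_conversion.py (outside this diff) and applies the sign bit separately. As a hypothetical sketch (the helper name is mine), such a table can be derived directly from the e4m3 layout: bias 7, 3 mantissa bits, subnormals at exponent field 0, a single NaN at 0x7F, and no Inf:

import numpy as np

def build_f8e4m3_magnitude_lut() -> np.ndarray:
    # Decode all 128 non-negative f8e4m3 bit patterns to float32.
    lut = np.zeros(128, dtype=np.float32)
    for bits in range(128):
        e = (bits >> 3) & 0x0F  # 4-bit exponent field
        m = bits & 0x07  # 3-bit mantissa field
        if e == 0x0F and m == 0x07:
            lut[bits] = np.nan  # the single NaN pattern; e4m3 has no Inf
        elif e == 0:
            lut[bits] = m * 2.0**-3 * 2.0 ** (1 - 7)  # subnormal: (m/8) * 2^(1-bias)
        else:
            lut[bits] = (1.0 + m * 2.0**-3) * 2.0 ** (e - 7)  # normal: (1+m/8) * 2^(e-bias)
    return lut

# Decoding then mirrors the C++ logic: magnitude from the LUT, sign from bit 7:
# values = np.where(f8_bits & 0x80, -1.0, 1.0) * build_f8e4m3_magnitude_lut()[f8_bits & 0x7F]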

src/nncf/quantization/algorithms/weight_compression/weight_lowering.py
Lines changed: 2 additions & 2 deletions

@@ -520,9 +520,9 @@ def _calculate_float_quantized_weight(norm_weight: Tensor, compression_dtype: Te
     assert compression_dtype in [TensorDataType.f8e4m3, TensorDataType.f4e2m1, TensorDataType.nf4]

     if compression_dtype == TensorDataType.f8e4m3:
-        from nncf.quantization.algorithms.weight_compression.fp8_conversion import fp32_to_fp8e4m3_values
+        from nncf.quantization.algorithms.weight_compression.fp8_conversion import fp32_to_fp8e4m3

-        quantiles_np = fp32_to_fp8e4m3_values(norm_weight.as_numpy_tensor().data)
+        quantiles_np = fp32_to_fp8e4m3(norm_weight.as_numpy_tensor().data)
         return fns.from_numpy(quantiles_np, backend=norm_weight.backend)

     is_nf4 = compression_dtype == TensorDataType.nf4
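
For context, an added sketch (not part of the commit) of what this call site does: fp32_to_fp8e4m3 snaps the normalized weights onto the f8e4m3 value grid. Plain NumPy arrays stand in for NNCF's Tensor wrapper, and the input values are hypothetical:

import numpy as np

from nncf.quantization.algorithms.weight_compression.fp8_conversion import fp32_to_fp8e4m3

norm_weight = np.array([[0.13, -0.9], [0.5003, 447.0]], dtype=np.float32)
quantiles = fp32_to_fp8e4m3(norm_weight)

# Every entry of `quantiles` is now exactly representable in f8e4m3, so
# re-quantizing is a no-op: round-to-nearest is idempotent on the e4m3 grid.
assert np.array_equal(fp32_to_fp8e4m3(quantiles), quantiles)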
