Commit d5cc3f8

Rewrite f32->f8 conversion using nncf.Tensor
1 parent: 0266a3e

11 files changed: +234 −69 lines

src/nncf/quantization/algorithms/weight_compression/fp8_conversion.py
Lines changed: 92 additions & 58 deletions

@@ -9,6 +9,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from nncf.tensor import Tensor, TensorDataType
+from nncf.tensor import functions as fns
 import numpy as np
 
 # fmt: off
@@ -36,76 +38,95 @@
 # fmt: on
 
 
-def f16_to_f8e4m3_bits_numpy(x: np.ndarray) -> np.ndarray:
+def _f16_to_f8e4m3_bits_numpy(x: Tensor) -> Tensor:
     """
-    Convert an array of f16 values (or their uint16 bit patterns) to
-    f8e4m3 bit patterns (uint8) using a fully vectorized NumPy
-    port of _f16_to_f8e4m3_bits_scalar.
+    Convert a Tensor of f16 values to f8e4m3 bit patterns (uint8).
+    Adopted from OpenVINO C++ implementation
+    https://github.com/openvinotoolkit/openvino/blame/master/src/core/src/type/float8_e4m3.cpp
+
+    :param x: Input tensor with float16 values.
+    :return: Tensor with uint8 values representing f8e4m3 bit patterns.
     """
+
+    def to_u16_const(val: int) -> Tensor:
+        return fns.from_numpy(np.uint16(val), backend=x.backend)
+
+    def to_u32_const(val: int) -> Tensor:
+        return fns.from_numpy(np.uint32(val), backend=x.backend)
+
+    x = x.view(TensorDataType.uint16)
+
+    u16_zero = to_u16_const(0)
+    u32_zero = to_u32_const(0)
+
     # f16 layout
-    f16_s_mask = np.uint16(0x8000)
-    f16_e_mask = np.uint16(0x7C00)
+    f16_s_mask = to_u16_const(0x8000)
+    f16_e_mask = to_u16_const(0x7C00)
     f16_e_bias = 15
     f16_e_size = 5
-    f16_m_mask = np.uint16(0x03FF)
+    f16_m_mask = to_u16_const(0x03FF)
     f16_m_size = 10
 
     # f8 e4m3 layout
     f8e4m3_e_size = 4
-    f8e4m3_e_mask = np.uint16(0x78)
+    f8e4m3_e_mask = to_u16_const(0x78)
     f8e4m3_e_bias = 7
     f8e4m3_e_max = 0x0F
     f8e4m3_m_size = 3
-    f8e4m3_m_mask = np.uint16(0x07)
+    f8e4m3_m_mask = to_u16_const(0x07)
 
     byte_shift = 8
 
     # f8 masks in uint16 domain
-    f8_e_mask = np.uint16(f8e4m3_e_mask << byte_shift)  # 0x7800
-    f8_m_mask = np.uint16(f8e4m3_m_mask << byte_shift)  # 0x0700
-    f8_m_hidden_one_mask = np.uint16(0x0800)  # hidden 1 for subnormals
+    f8_e_mask = (f8e4m3_e_mask << byte_shift).astype(TensorDataType.uint16)  # 0x7800
+    f8_m_mask = (f8e4m3_m_mask << byte_shift).astype(TensorDataType.uint16)  # 0x0700
+    f8_m_hidden_one_mask = to_u16_const(0x0800)  # hidden 1 for subnormals
 
     # rounding constants
-    round_half = np.uint16(0x01FF)
-    round_norm = np.uint16(0x007F)
-    round_even = np.uint16(0x0080)
-    round_odd = np.uint16(0x0180)
+    round_half = to_u16_const(0x01FF)
+    round_norm = to_u16_const(0x007F)
+    round_even = to_u16_const(0x0080)
+    round_odd = to_u16_const(0x0180)
 
     # min exponent for which subnormals are representable
     f8_e_subnormal_min = -10
 
     # sign bit: f16 sign -> f8 sign position (bit 15 -> bit 7)
-    f8_bits = ((x & f16_s_mask) >> byte_shift).astype(np.uint16)
+    f8_bits = ((x & f16_s_mask) >> byte_shift).astype(TensorDataType.uint16)
 
     f16_e_field = x & f16_e_mask
     is_naninf = f16_e_field == f16_e_mask
-    is_zero = f16_e_field == 0
+    is_zero = f16_e_field == u16_zero
     is_normal = (~is_naninf) & (~is_zero)
 
-    nan_pattern = np.uint16(f8e4m3_e_mask | f8e4m3_m_mask)
+    nan_pattern = (f8e4m3_e_mask | f8e4m3_m_mask).astype(TensorDataType.uint16)
 
     # --- Case 1: f16 NaN / Inf -> f8 NaN (no Inf) ---
-    f8_bits = np.where(is_naninf, f8_bits | nan_pattern, f8_bits)
+    f8_bits = fns.where(is_naninf, f8_bits | nan_pattern, f8_bits)
 
     # --- Case 2: normalized f16 ---
     # f8_biased_exp = (f16_e_field >> f16_m_size) - (f16_e_bias - f8e4m3_e_bias)
-    f8_biased_exp = (f16_e_field >> f16_m_size).astype(np.int32) - (f16_e_bias - f8e4m3_e_bias)
+    f8_biased_exp = (f16_e_field >> f16_m_size).astype(TensorDataType.int32) - (f16_e_bias - f8e4m3_e_bias)
 
     # fractional = (inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size)
-    fractional_norm = ((x & f16_m_mask) << (f16_e_size - f8e4m3_e_size)).astype(np.uint16)
+    fractional_norm = ((x & f16_m_mask) << (f16_e_size - f8e4m3_e_size)).astype(TensorDataType.uint16)
 
     exp_ge0 = (f8_biased_exp >= 0) & is_normal
 
     # Rounding for normalized part (exp >= 0)
     # if (fractional & round_half) == round_odd or (fractional & round_norm) != 0:
-    cond_round_norm = (((fractional_norm & round_half) == round_odd) | ((fractional_norm & round_norm) != 0)) & exp_ge0
+    cond_round_norm = (
+        ((fractional_norm & round_half) == round_odd) |
+        ((fractional_norm & round_norm) != 0)
+    ) & exp_ge0
 
     # fractional += round_even where cond_round_norm
-    frac_tmp = fractional_norm.astype(np.uint32) + np.where(cond_round_norm, round_even, np.uint16(0)).astype(np.uint32)
-    fractional_norm = (frac_tmp & 0xFFFF).astype(np.uint16)
+    frac_tmp = fractional_norm.astype(TensorDataType.uint32) + \
+        fns.where(cond_round_norm, round_even, u16_zero).astype(TensorDataType.uint32)
+    fractional_norm = (frac_tmp & 0xFFFF).astype(TensorDataType.uint16)
 
     # if (fractional & f8_e_mask) != 0: f8_biased_exp += 1
-    exp_inc = np.where(exp_ge0 & ((fractional_norm & f8_e_mask) != 0), 1, 0).astype(np.int32)
+    exp_inc = fns.where(exp_ge0 & ((fractional_norm & f8_e_mask) != 0), 1, 0).astype(TensorDataType.int32)
     f8_biased_exp_after = f8_biased_exp + exp_inc
 
     # fractional &= f8_m_mask
@@ -119,77 +140,90 @@ def f16_to_f8e4m3_bits_numpy(x: np.ndarray) -> np.ndarray:
     subnormal_mask = is_normal & (f8_biased_exp_after <= 0) & (~overflow_mask)
 
     # --- Overflow -> NaN ---
-    f8_bits = np.where(overflow_mask, f8_bits | nan_pattern, f8_bits)
+    f8_bits = fns.where(overflow_mask, f8_bits | nan_pattern, f8_bits)
 
     # --- Normalized f8 ---
     # exp_field = (f8_biased_exp & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size
-    exp_field = ((f8_biased_exp_after & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size).astype(np.uint16)
-    mant_norm = (fractional_norm >> byte_shift).astype(np.uint16)
+    exp_field = ((f8_biased_exp_after & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size).astype(TensorDataType.uint16)
+    mant_norm = (fractional_norm >> byte_shift).astype(TensorDataType.uint16)
 
     f8_bits_norm = f8_bits | exp_field | mant_norm
-    f8_bits = np.where(normal_mask, f8_bits_norm, f8_bits)
+    f8_bits = fns.where(normal_mask, f8_bits_norm, f8_bits)
 
     # --- Subnormal f8 ---
     # fractional = f8_m_hidden_one_mask | ((inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size))
     fractional_sub = f8_m_hidden_one_mask | ((x & f16_m_mask) << (f16_e_size - f8e4m3_e_size))
 
     # f8_exp = f8_biased_exp - f8e4m3_e_bias
-    f8_exp = (f8_biased_exp_after - f8e4m3_e_bias).astype(np.int32)
+    f8_exp = (f8_biased_exp_after - f8e4m3_e_bias).astype(TensorDataType.int32)
 
     # shift = 1 - f8_exp
    shift = 1 - f8_exp
 
     # sticky_mask = 0 if f8_exp < f8_e_subnormal_min else ((1 << shift) - 1)
     # we avoid invalid shifts by clipping / masking
     valid_sub = f8_exp >= f8_e_subnormal_min
-    shift_pos = np.maximum(shift, 0)
-    sticky_mask32 = np.where(valid_sub, (np.uint32(1) << shift_pos) - 1, 0).astype(np.uint32)
-    sticky_mask16 = (sticky_mask32 & np.uint32(0xFFFF)).astype(np.uint16)
+    shift_pos = fns.maximum(shift, 0)
+
+    one_u32 = to_u32_const(1)
+    mask_u32_full = to_u32_const(0xFFFF)
+
+    sticky_mask32 = fns.where(
+        valid_sub,
+        (one_u32 << shift_pos) - 1,
+        u32_zero,
+    ).astype(TensorDataType.uint32)
+    sticky_mask16 = (sticky_mask32 & mask_u32_full).astype(TensorDataType.uint16)
 
     # sticky = 1 if (fractional & sticky_mask) != 0 else 0
     sticky = ((fractional_sub & sticky_mask16) != 0) & valid_sub
 
     # fractional = 0 if f8_exp < f8_e_subnormal_min else (fractional >> (1 - f8_biased_exp))
     shift2 = 1 - f8_biased_exp_after
-    shift2_pos = np.maximum(shift2, 0)
-    frac_shifted = (fractional_sub.astype(np.uint32) >> shift2_pos).astype(np.uint16)
-    frac_shifted = np.where(valid_sub, frac_shifted, np.uint16(0))
+    shift2_pos = fns.maximum(shift2, 0)
+    frac_shifted = (fractional_sub.astype(TensorDataType.uint32) >> shift2_pos).astype(TensorDataType.uint16)
+    frac_shifted = fns.where(valid_sub, frac_shifted, u16_zero)
 
     # Rounding for subnormal:
     # if (((fractional & round_half) == round_odd and sticky == 0)
     #         or (fractional & round_norm) != 0
     #         or sticky != 0):
     cond_round_sub = (
-        (((frac_shifted & round_half) == round_odd) & (~sticky)) | ((frac_shifted & round_norm) != 0) | sticky
+        (((frac_shifted & round_half) == round_odd) & (~sticky)) |
+        ((frac_shifted & round_norm) != 0) |
+        sticky
     ) & subnormal_mask
 
-    frac_tmp_sub = frac_shifted.astype(np.uint32) + np.where(cond_round_sub, round_even, np.uint16(0)).astype(np.uint32)
-    fractional_sub_final = (frac_tmp_sub & 0xFFFF).astype(np.uint16)
+    frac_tmp_sub = frac_shifted.astype(TensorDataType.uint32) + \
+        fns.where(cond_round_sub, round_even, u16_zero).astype(TensorDataType.uint32)
+    fractional_sub_final = (frac_tmp_sub & 0xFFFF).astype(TensorDataType.uint16)
 
-    mant_sub = (fractional_sub_final >> byte_shift).astype(np.uint16)
-    f8_bits = np.where(subnormal_mask, f8_bits | mant_sub, f8_bits)
+    mant_sub = (fractional_sub_final >> byte_shift).astype(TensorDataType.uint16)
+    f8_bits = fns.where(subnormal_mask, f8_bits | mant_sub, f8_bits)
 
     # Case: f16 zero / subnormal -> sign + zero exponent/mantissa
     # Already handled by initialization + not touching zero_mask entries.
 
-    return (f8_bits & np.uint16(0x00FF)).astype(np.uint8)
+    return (f8_bits & to_u16_const(0x00FF)).astype(TensorDataType.uint8)
 
 
-def fp32_to_fp8e4m3(x: np.ndarray) -> np.ndarray:
-    """
-    Bit-exact to ov::float8_e4m3(float):
-    float32 -> float16 -> f8e4m3 bits -> float via LUT
-    """
-    x = np.asarray(x, dtype=np.float32)
-    x_f16 = x.astype(np.float16)
-    h_bits = x_f16.view(np.uint16)
-
-    f8_bits = f16_to_f8e4m3_bits_numpy(h_bits)
-
-    # Decode exactly like C++: LUT for magnitude + sign bit
-    idx = f8_bits & 0x7F
-    mag = F8E4M3_LUT[idx.astype(np.int32)]
-
-    sign = np.where((f8_bits & 0x80) != 0, -1.0, 1.0)
-    out = sign * mag
-    return out.astype(np.float32)
+def fp32_to_fp8e4m3(x: Tensor) -> Tensor:
+    """
+    Convert float32 to float8 e4m3 via float16.
+    Adopted from OpenVINO C++ implementation
+    https://github.com/openvinotoolkit/openvino/blame/master/src/core/src/type/float8_e4m3.cpp
+
+    :param x: Input tensor with float32 values.
+    :return: Tensor with float8 e4m3 values as float32 type.
+    """
+    x_f16 = x.astype(TensorDataType.float16)
+    f8_bits = _f16_to_f8e4m3_bits_numpy(x_f16)
+
+    indexes = f8_bits & 0x7F
+    look_up_table = fns.from_numpy(F8E4M3_LUT, backend=x.backend)
+    magnitude = look_up_table[indexes.astype(TensorDataType.int32)]
+    sign = fns.where((f8_bits & 0x80) != 0, -1.0, 1.0)
+    result = sign * magnitude
+    return result.astype(TensorDataType.float32)
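
For orientation, a minimal usage sketch of the rewritten entry point (illustrative, not part of the commit; the Tensor(np.ndarray) constructor call and the sample values are assumptions): fp32_to_fp8e4m3 now takes and returns an nncf Tensor, so callers no longer unwrap to NumPy. Recall that e4m3 has no infinity and a maximum finite magnitude of 448, so larger inputs map to NaN here.

import numpy as np

from nncf.quantization.algorithms.weight_compression.fp8_conversion import fp32_to_fp8e4m3
from nncf.tensor import Tensor

# NumPy-backed Tensor; the rewrite routes every op through nncf.tensor
# functions, so other backends are expected to follow the same path.
weights = Tensor(np.array([0.0, 0.1234, -0.3, 448.0, 1000.0], dtype=np.float32))

snapped = fp32_to_fp8e4m3(weights)
# Each value is rounded to the nearest representable e4m3 number but kept
# as float32; 1000.0 overflows the format (max finite value 448) -> NaN.
print(snapped.data)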

src/nncf/quantization/algorithms/weight_compression/openvino_backend.py
Lines changed: 5 additions & 5 deletions

@@ -225,10 +225,11 @@ def _create_compression_subgraph(
         precomputed_compressed_weight: Optional[CompressedWeight] = None,
     ):
         compression_dtype = DTYPE_MAP[compression_config.compression_dtype]
-        if compression_config.mode in [CompressWeightsMode.MXFP4, CompressWeightsMode.MXFP8_E4M3]:
-            scale_dtype = ov.Type.f8e8m0
-        else:
-            scale_dtype = ov.Type.f16
+        scale_dtype = (
+            ov.Type.f8e8m0
+            if compression_config.mode in [CompressWeightsMode.MXFP4, CompressWeightsMode.MXFP8_E4M3]
+            else ov.Type.f16
+        )
 
         original_shape = weight.shape
 
@@ -241,7 +242,6 @@
         )
 
         if compression_config.is_codebook:
-            compression_dtype = DTYPE_MAP[compression_config.compression_dtype]
             converted_const = create_ov_codebook_subgraph(
                 codebook=compressed_weight.codebook
                 if compression_config.mode == CompressWeightsMode.CODEBOOK

src/nncf/quantization/algorithms/weight_compression/weight_lowering.py
Lines changed: 1 addition & 2 deletions

@@ -522,8 +522,7 @@ def _calculate_float_quantized_weight(norm_weight: Tensor, compression_dtype: Te
     if compression_dtype == TensorDataType.f8e4m3:
         from nncf.quantization.algorithms.weight_compression.fp8_conversion import fp32_to_fp8e4m3
 
-        quantiles_np = fp32_to_fp8e4m3(norm_weight.as_numpy_tensor().data)
-        return fns.from_numpy(quantiles_np, backend=norm_weight.backend)
+        return fp32_to_fp8e4m3(norm_weight)
 
     is_nf4 = compression_dtype == TensorDataType.nf4
     quantiles_np = NF4_QUANTILES if is_nf4 else F4E2M1_QUANTILES

src/nncf/tensor/definitions.py
Lines changed: 2 additions & 0 deletions

@@ -50,6 +50,7 @@ class TensorDataType(StrEnum):
     int32 = auto()
     int64 = auto()
     uint16 = auto()
+    uint32 = auto()
     uint8 = auto()
     uint4 = auto()
     int4 = auto()
@@ -83,6 +84,7 @@ def itemsize(self) -> int:
             TensorDataType.int8: 8,
             TensorDataType.uint8: 8,
             TensorDataType.uint16: 16,
+            TensorDataType.uint32: 32,
             TensorDataType.float16: 16,
             TensorDataType.bfloat16: 16,
             TensorDataType.float32: 32,

src/nncf/tensor/functions/numeric.py
Lines changed: 11 additions & 0 deletions

@@ -137,6 +137,17 @@ def astype(a: Tensor, dtype: TensorDataType) -> Tensor:
     """
 
 
+@tensor_dispatcher
+def view(a: Tensor, dtype: TensorDataType) -> Tensor:
+    """
+    Returns a view of the tensor with the specified data type.
+
+    :param a: The input tensor.
+    :param dtype: The desired data
+    :return: A view of the tensor with the specified data type.
+    """
+
+
 @tensor_dispatcher
 def dtype(a: Tensor) -> TensorDataType:
     """

src/nncf/tensor/functions/numpy_numeric.py
Lines changed: 7 additions & 1 deletion

@@ -27,7 +27,7 @@
 from nncf.tensor.tensor import TTensor
 
 T_NUMPY_ARRAY = NDArray[Any]
-T_NUMPY = Union[T_NUMPY_ARRAY, np.generic]  # type: ignore[type-arg]
+T_NUMPY = Union[T_NUMPY_ARRAY, np.generic]
 
 DTYPE_MAP: dict[TensorDataType, DTypeLike] = {
     TensorDataType.float16: np.dtype(np.float16),
@@ -38,6 +38,7 @@
     TensorDataType.int64: np.dtype(np.int64),
     TensorDataType.uint8: np.dtype(np.uint8),
     TensorDataType.uint16: np.dtype(np.uint16),
+    TensorDataType.uint32: np.dtype(np.uint32),
 }
 
 DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()}
@@ -98,6 +99,11 @@ def _(a: T_NUMPY, dtype: TensorDataType) -> T_NUMPY:
     return a.astype(DTYPE_MAP[dtype])
 
 
+@numeric.view.register
+def _(a: T_NUMPY, dtype: TensorDataType) -> T_NUMPY:
+    return a.view(DTYPE_MAP[dtype])
+
+
 @numeric.dtype.register
 def _(a: T_NUMPY) -> TensorDataType:
     return DTYPE_MAP_REV[np.dtype(a.dtype)]
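
To see this registration in action (illustrative sketch): calling view on a NumPy-backed Tensor dispatches to the function above, which delegates to ndarray.view and therefore shares memory with the source array.

import numpy as np

from nncf.tensor import Tensor, TensorDataType

a = Tensor(np.zeros(4, dtype=np.float32))
b = a.view(TensorDataType.uint32)

print(b.data.dtype)           # uint32
print(b.data.base is a.data)  # True: ndarray.view is zero-copy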

src/nncf/tensor/functions/openvino_numeric.py
Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@
     TensorDataType.int32: ov.Type.i32,
     TensorDataType.int64: ov.Type.i64,
     TensorDataType.uint16: ov.Type.u16,
+    TensorDataType.uint32: ov.Type.u32,
     TensorDataType.uint8: ov.Type.u8,
     TensorDataType.uint4: ov.Type.u4,
     TensorDataType.int4: ov.Type.i4,

src/nncf/tensor/functions/torch_numeric.py
Lines changed: 6 additions & 0 deletions

@@ -36,6 +36,7 @@
     TensorDataType.int64: torch.int64,
     TensorDataType.uint8: torch.uint8,
     TensorDataType.uint16: torch.uint16,
+    TensorDataType.uint32: torch.uint32,
 }
 
 DEVICE_MAP = {TensorDeviceType.CPU: "cpu", TensorDeviceType.GPU: "cuda"}
@@ -109,6 +110,11 @@ def _(a: torch.Tensor, dtype: TensorDataType) -> torch.Tensor:
     return a.type(DTYPE_MAP[dtype])
 
 
+@numeric.view.register
+def _(a: torch.Tensor, dtype: TensorDataType) -> torch.Tensor:
+    return a.view(DTYPE_MAP[dtype])
+
+
 @numeric.dtype.register
 def _(a: torch.Tensor) -> TensorDataType:
     return DTYPE_MAP_REV[a.dtype]
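
The torch registration mirrors the NumPy one via torch.Tensor.view(dtype), which bit-casts between dtypes of equal width. A sketch, assuming a torch build where torch.uint16 is available (2.3+):

import torch

from nncf.tensor import Tensor, TensorDataType

t = Tensor(torch.tensor([1.0, 2.5], dtype=torch.float16))
bits = t.view(TensorDataType.uint16)

# Same storage, reinterpreted per element: 0x3C00 (15360) and 0x4100 (16640)
print(bits.data)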
