Commit bd421c7: Fix tests
Parent: b7d8a7b

5 files changed, +67 -29 lines


src/nncf/quantization/algorithms/weight_compression/fp8_conversion.py

Lines changed: 30 additions & 27 deletions

@@ -11,28 +11,29 @@

 import numpy as np

-
+# fmt: off
 F8E4M3_LUT = np.array(
     [
-        0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875,
-        0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875,
-        0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375,
-        0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875,
-        0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375,
-        0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,
-        0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375,
-        1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875,
-        2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75,
-        4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
-        8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
-        16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0,
-        32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0,
-        64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0,
-        128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0,
-        256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, np.nan,
+        0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875,  # noqa
+        0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875,  # noqa
+        0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375,  # noqa
+        0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875,  # noqa
+        0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375,  # noqa
+        0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,  # noqa
+        0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375,  # noqa
+        1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875,  # noqa
+        2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75,  # noqa
+        4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,  # noqa
+        8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,  # noqa
+        16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0,  # noqa
+        32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0,  # noqa
+        64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0,  # noqa
+        128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0,  # noqa
+        256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, np.nan,  # noqa
     ],
     dtype=np.float32,
 )
+# fmt: on


 def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:

@@ -46,7 +47,6 @@ def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:
     f16_m_size = 10

     # f8 e4m3 layout
-    f8e4m3_s_mask = 0x80
     f8e4m3_e_size = 4
     f8e4m3_e_mask = 0x78
     f8e4m3_e_bias = 7

@@ -57,9 +57,9 @@ def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:
     byte_shift = 8

     # f8 masks in uint16 domain
-    f8_e_mask = f8e4m3_e_mask << byte_shift # 0x7800
-    f8_m_mask = f8e4m3_m_mask << byte_shift # 0x0700
-    f8_m_hidden_one_mask = 0x0800 # hidden 1 for subnormals
+    f8_e_mask = f8e4m3_e_mask << byte_shift  # 0x7800
+    f8_m_mask = f8e4m3_m_mask << byte_shift  # 0x0700
+    f8_m_hidden_one_mask = 0x0800  # hidden 1 for subnormals

     # rounding constants (same as C++)
     round_half = 0x01FF

@@ -79,7 +79,7 @@ def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:

     if f16_e_field == f16_e_mask:
         # f16 NaN / Inf -> f8 NaN (no Inf)
-        f8_bits |= (f8e4m3_e_mask | f8e4m3_m_mask)
+        f8_bits |= f8e4m3_e_mask | f8e4m3_m_mask
     elif f16_e_field != 0:
         # normalized f16
         f8_biased_exp = (f16_e_field >> f16_m_size) - (f16_e_bias - f8e4m3_e_bias)

@@ -97,12 +97,12 @@ def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:
         # now set exponent & mantissa
         if f8_biased_exp > f8e4m3_e_max:
             # overflow -> NaN (no Inf)
-            f8_bits |= (f8e4m3_e_mask | f8e4m3_m_mask)
+            f8_bits |= f8e4m3_e_mask | f8e4m3_m_mask
         elif f8_biased_exp > 0:
             # normalized f8
             exp_field = (f8_biased_exp & (f8e4m3_e_mask >> f8e4m3_m_size)) << f8e4m3_m_size
             f8_bits |= exp_field
-            f8_bits |= (fractional >> byte_shift)
+            f8_bits |= fractional >> byte_shift
         else:
             # subnormal f8
             fractional = f8_m_hidden_one_mask | ((inp & f16_m_mask) << (f16_e_size - f8e4m3_e_size))

@@ -113,11 +113,14 @@ def _f16_to_f8e4m3_bits_scalar(h_bits: int) -> int:

             fractional = 0 if f8_exp < f8_e_subnormal_min else (fractional >> (1 - f8_biased_exp))

-            if (((fractional & round_half) == round_odd and sticky == 0) or
-                    (fractional & round_norm) != 0 or sticky != 0):
+            if (
+                ((fractional & round_half) == round_odd and sticky == 0)
+                or (fractional & round_norm) != 0
+                or sticky != 0
+            ):
                 fractional += round_even

-            f8_bits |= (fractional >> byte_shift)
+            f8_bits |= fractional >> byte_shift
     else:
         # f16 zero / subnormal -> sign + zero exponent/mantissa
         # (f8_bits already contains the sign)
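Note (context, not part of the commit): F8E4M3_LUT holds exactly the 128 non-negative e4m3 code points in code order. A minimal sketch that regenerates the table, assuming the standard e4m3 layout of 4 exponent bits with bias 7 and 3 mantissa bits, where the all-ones code is the single NaN and there is no Inf:

import numpy as np

def e4m3_value(code: int) -> float:
    """Decode a 7-bit non-negative f8e4m3 code (sign bit dropped)."""
    e = (code >> 3) & 0xF  # 4-bit exponent field
    m = code & 0x7         # 3-bit mantissa field
    if e == 0xF and m == 0x7:
        return float("nan")                   # single NaN code; e4m3 has no Inf
    if e == 0:
        return (m / 8.0) * 2.0**-6            # subnormal: no hidden one
    return (1.0 + m / 8.0) * 2.0 ** (e - 7)   # normal: hidden one, bias 7

lut = np.array([e4m3_value(c) for c in range(128)], dtype=np.float32)
# lut[1] == 0.001953125 (smallest subnormal), lut[126] == 448.0 (max finite),
# and lut[127] is NaN, matching the last row of F8E4M3_LUT above.

This also explains why the converter writes f8e4m3_e_mask | f8e4m3_m_mask (the all-ones pattern) for f16 Inf/NaN and for overflow: with no Inf in e4m3, NaN is the only escape value.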

src/nncf/tensor/functions/numpy_numeric.py

Lines changed: 1 addition & 0 deletions

@@ -37,6 +37,7 @@
     TensorDataType.int32: np.dtype(np.int32),
     TensorDataType.int64: np.dtype(np.int64),
     TensorDataType.uint8: np.dtype(np.uint8),
+    TensorDataType.uint16: np.dtype(np.uint16),
 }

 DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()}

src/nncf/tensor/functions/torch_numeric.py

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@
     TensorDataType.int32: torch.int32,
     TensorDataType.int64: torch.int64,
     TensorDataType.uint8: torch.uint8,
+    TensorDataType.uint16: torch.uint16,
 }

 DEVICE_MAP = {TensorDeviceType.CPU: "cpu", TensorDeviceType.GPU: "cuda"}
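A note on why TensorDataType.uint16 is added to both backends' dtype maps (my reading, not stated in the commit): the f16-to-f8 conversion above operates on raw float16 bit patterns, and uint16 is the natural carrier type for those 16-bit codes. A small NumPy illustration:

import numpy as np

f16 = np.array([0.1, 1.5, -448.0], dtype=np.float16)
bits = f16.view(np.uint16)  # reinterpret the same bytes as uint16 codes
# Each element of `bits` is the IEEE half-precision encoding, e.g.
# bits[1] == 0x3E00 for 1.5; values like these are the `h_bits`
# consumed by _f16_to_f8e4m3_bits_scalar above.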

tests/cross_fw/test_templates/template_test_nncf_tensor.py

Lines changed: 1 addition & 0 deletions

@@ -2168,6 +2168,7 @@ def test_fn_eye(self, n, m, ref):
         in [
             TensorDataType.int4,
             TensorDataType.uint4,
+            TensorDataType.uint16,
             TensorDataType.nf4,
             TensorDataType.f4e2m1,
             TensorDataType.f8e8m0,

tests/openvino/native/quantization/test_weights_compression.py

Lines changed: 34 additions & 2 deletions

@@ -1441,6 +1441,38 @@ def test_int_compressed_weighs_range(mode, data):
             8.0,
         ],
     },
+    CompressWeightsMode.FP8_E4M3: {
+        "neg": [
+            -8.0,
+            -6.857143402099609,
+            -5.714285850524902,
+            -5.142857551574707,
+            -4.0,
+            -2.857142925262451,
+            -2.0,
+            -1.0,
+            0.0,
+        ],
+        "pos": [0.0, 1.0, 2.0, 2.857142925262451, 4.0, 5.142857551574707, 5.714285850524902, 6.857143402099609, 8.0],
+        "neg-pos": [
+            -8.0,
+            -6.857143402099609,
+            -5.714285850524902,
+            -5.142857551574707,
+            -4.0,
+            -2.857142925262451,
+            -2.0,
+            -1.0,
+            0.0,
+            1.0,
+            2.0,
+            2.857142925262451,
+            4.0,
+            5.142857551574707,
+            5.714285850524902,
+            6.857143402099609,
+        ],
+    },
 }


@@ -1999,7 +2031,7 @@ def test_nf4_quantization_mid_quant(weight, scale):
     scale = Tensor(scale)
     # norm_weight equals -0.8480964 (one bit away from the first NF4 quantile center)
     norm_weight = _calculate_normalized_weight(weight, scale)
-    nf4_quant = _calculate_float_quantized_weight(norm_weight, CompressWeightsMode.NF4)
+    nf4_quant = _calculate_float_quantized_weight(norm_weight, TensorDataType.nf4)

     norm_weight_ov_backend = Tensor(ov.Tensor(norm_weight.data, norm_weight.shape, ov.Type.f32))
     ref_nf4_quant = norm_weight_ov_backend.astype(TensorDataType.nf4).as_numpy_tensor()

@@ -2027,7 +2059,7 @@ def test_nf4_quantization_mid_quant(weight, scale):
 )
 def test_mxfp4_quantization_edge_cases(input_val, expected_val, description):
     norm_weight = Tensor(np.array([input_val], dtype=np.float32))
-    result = _calculate_float_quantized_weight(norm_weight, CompressWeightsMode.MXFP4)
+    result = _calculate_float_quantized_weight(norm_weight, TensorDataType.f4e2m1)

     assert result.data[0] == expected_val, (
         f"{description}: Expected {expected_val}, got {result.data[0]} for input value {input_val}"
