
Commit e9e7cd0

[OpenVINO] Optimized weight compression for FP4 mode (#3737)
### Changes

Added optimized weight compression through OpenVINO models for the FP4 compression mode. Results should be similar to MXFP4 (#3550).

### Reason for changes

Improving UX.

### Tests

Extended `tests/openvino/optimized_functions/test_compression_functions.py`.
1 parent 109bf0a commit e9e7cd0
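For context, using the new mode from the user side would look roughly like the sketch below. This is a hedged illustration, not part of the commit: `nncf.compress_weights` and `ov.save_model` are the public NNCF/OpenVINO APIs, while the model path and `group_size` value are placeholders, and FP4 acceptance by `compress_weights` is assumed from the modes touched in this diff.

```python
# Hedged usage sketch: compressing an OpenVINO model's weights to FP4.
import openvino as ov

import nncf

core = ov.Core()
model = core.read_model("model.xml")  # placeholder path

# FP4 uses the same f4e2m1 value grid as MXFP4 (16 levels in [-6, 6]);
# with this change, large weights take the optimized OpenVINO compression path.
compressed_model = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.FP4,
    group_size=128,  # arbitrary example value
)
ov.save_model(compressed_model, "model_fp4.xml", compress_to_fp16=False)
```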

File tree

6 files changed: +88 −35 lines changed


src/nncf/openvino/optimized_functions/functions.py

Lines changed: 5 additions & 6 deletions

@@ -108,16 +108,15 @@ def do_float_quantization(
 ) -> tuple[Tensor, Tensor, Tensor]:
     """
     Computes quantization scale if not provided, and performs corresponding float weight quantization.
-    NF4 format uses 16 levels in [-1, 1] range, while MXFP4 uses 16 levels in [-6, 6].
+    NF4 format uses 16 levels in [-1, 1] range, while FP4/MXFP4 uses 16 levels in [-6, 6].

     :param weight: Weight array to compress.
     :param config: Weight compression configuration.
     :param reduction_axes: Axes, along which to reduce (collect) different statistics.
     :param precomputed_scale: Optional precomputed scale.
-    :return: Returns quantized (for MXFP8_E4M3, FP4 and FP8_E4M3 normalized)
-        weight tensor and corresponding scale tensor.
+    :return: Returns quantized weight tensor and corresponding scale tensor.
     """
-    assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4]
+    assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]

     weight_shape = weight.shape
     scale_shape = None if precomputed_scale is None else precomputed_scale.shape
@@ -129,7 +128,7 @@ def do_float_quantization(
     if weight.backend == TensorBackend.ov:
         # Return ov tensors in target precision to seamlessly insert them into openvino model later
         ov_model_params.return_ov_tensors = True
-        weight_dtype = TensorDataType.f4e2m1 if config.mode == CompressWeightsMode.MXFP4 else TensorDataType.nf4
+        weight_dtype = TensorDataType.nf4 if config.mode == CompressWeightsMode.NF4 else TensorDataType.f4e2m1
         ov_model_params.output_dtypes.update({"compressed_weight": weight_dtype})

     model = get_float_quantization_model(
@@ -235,7 +234,7 @@ def float_quantize_dequantize_weight(
     :param return_compressed_weight: If True, besides decompressed weight will also return compressed weight and scale.
     :return: Dequantized weight tensor or a tuple containing the decompressed weight, compressed weight and scale.
     """
-    assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4]
+    assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]

     # When reduction axes are not provided, assuming that the weights are already reshaped
     if config.group_size != -1 and reduction_axes is not None:
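Conceptually, `do_float_quantization` normalizes each weight group by a scale and snaps the result onto the 16-level f4e2m1 grid. Below is a minimal numpy-only sketch of that idea, assuming per-row max-abs scaling and plain nearest-value rounding; the real code additionally breaks ties toward the even quantile and, for large tensors, runs through compiled OpenVINO models. All names in the sketch are illustrative.

```python
# Minimal numpy-only sketch of FP4 (f4e2m1) weight quantization; not the NNCF implementation.
import numpy as np

# 16 representable e2m1 values; the diff shows the extremes (-6.0, -4.0, ...).
F4E2M1_GRID = np.array(
    [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, -0.0, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
    dtype=np.float32,
)


def fp4_quantize(weight: np.ndarray, reduction_axis: int = -1):
    # Per-row scale: map the row's max magnitude onto the largest grid value, 6.0.
    scale = np.max(np.abs(weight), axis=reduction_axis, keepdims=True) / 6.0
    scale = np.maximum(scale, np.finfo(np.float32).eps)
    norm = weight / scale
    # Snap every normalized value to the nearest grid point (plain nearest rounding;
    # the real implementation breaks ties toward the even quantile).
    idx = np.abs(norm[..., None] - F4E2M1_GRID).argmin(axis=-1)
    return F4E2M1_GRID[idx], scale


w = np.array([[0.06, -0.4, 0.9, -1.2]], dtype=np.float32)
q, s = fp4_quantize(w)
print(q * s)  # [[ 0.1 -0.4  0.8 -1.2]] -- dequantized approximation of w
```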

src/nncf/openvino/optimized_functions/models.py

Lines changed: 11 additions & 5 deletions

@@ -286,8 +286,7 @@ def get_float_quantization_model(
     reduction_axes: Optional[ReductionAxes] = None,
 ) -> Union[ModelCallable, ModelAsNodes]:
     """
-    Get a model that compresses weights to float (currently nf4 or mxfp4) destination type using the given
-    configuration.
+    Get a model that compresses weights to float destination type using the given configuration.

     :param ov_model_params: OV model parameters.
     :param config: Compression configuration.
@@ -319,7 +318,7 @@ def get_float_quantize_dequantize_weight_model(
     return_compressed_weight: Optional[bool] = False,
 ) -> ModelCallable:
     """
-    Get a model that performs float (currently only nf4) compression and decompression of the given weight.
+    Get a model that performs float compression and decompression of the given weight.

     :param ov_model_params: OV model parameters.
     :param config: Compression configuration.
@@ -572,7 +571,7 @@ def _build_float_quantization_model(
     reduction_axes: Optional[ReductionAxes] = None,
     return_nodes: bool = False,
 ) -> Union[ModelCallable, ModelAsNodes]:
-    assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4]
+    assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]

     default_input_dtypes = {"scale": TensorDataType.float32}
     default_output_dtypes = {"compressed_weight": TensorDataType.float32, "scale": TensorDataType.float32}
@@ -626,8 +625,15 @@ def _build_float_quantization_model(
     eps = np.finfo(np.float32).eps
     scale = opset.select(opset.less(opset.abs(scale), eps), eps, scale)

+    # Equals 1.0 for NF4
+    FP_MAX_VALS = {
+        CompressWeightsMode.MXFP4: 6.0,
+        CompressWeightsMode.FP4: 6.0,
+    }
+    if config.mode in FP_MAX_VALS:
+        scale = divide_op(scale, opset.constant(FP_MAX_VALS[config.mode], ov.Type.f32))
+
     if config.mode == CompressWeightsMode.MXFP4:
-        scale = scale / opset.constant(6.0, ov.Type.f32)
         scale = opset.log(scale) / opset.log(opset.constant(2.0, ov.Type.f32))
         scale = opset.ceil(scale)
         scale = opset.clamp(scale, -127.0, 127.0)
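The last hunk is the core of the change: both FP4 and MXFP4 divide the scale by 6.0 (the largest f4e2m1 magnitude), but only MXFP4 then forces the scale to a power of two (ceil of log2, clamped to the shared-exponent range). A numpy sketch of that scale post-processing, assuming `raw_scale` is the per-group max absolute value; function names are illustrative, not the OpenVINO graph built above.

```python
# Numpy sketch of the FP4 vs. MXFP4 scale handling introduced above (illustrative only).
import numpy as np


def fp4_scale(raw_scale: np.ndarray) -> np.ndarray:
    # FP4: map the group's max magnitude onto the largest grid value, 6.0, keeping an exact float scale.
    scale = np.maximum(raw_scale, np.finfo(np.float32).eps)
    return scale / 6.0


def mxfp4_scale(raw_scale: np.ndarray) -> np.ndarray:
    # MXFP4: same division by 6.0, then round the scale up to a power of two
    # and clamp the exponent to [-127, 127] (e8m0-style shared scale).
    scale = fp4_scale(raw_scale)
    exponent = np.clip(np.ceil(np.log2(scale)), -127.0, 127.0)
    return 2.0 ** exponent


raw = np.array([1.2, 0.07], dtype=np.float32)
print(fp4_scale(raw))    # roughly [0.2, 0.0116667] -- arbitrary float scale
print(mxfp4_scale(raw))  # roughly [0.25, 0.015625] -- smallest power of two not below the FP4 scale
```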

src/nncf/quantization/algorithms/weight_compression/constants.py

Lines changed: 2 additions & 2 deletions

@@ -33,7 +33,7 @@
     dtype=np.float32,
 )

-MXFP4_QUANTILES = np.array(
+F4E2M1_QUANTILES = np.array(
     [
         -6.0,
         -4.0,
@@ -100,4 +100,4 @@
 )


-CENTER_OF_MXFP4_QUANTILES = (MXFP4_QUANTILES[1:] + MXFP4_QUANTILES[:-1]) / 2
+CENTER_OF_F4E2M1_QUANTILES = (F4E2M1_QUANTILES[1:] + F4E2M1_QUANTILES[:-1]) / 2
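Only the first entries of the renamed quantile table are visible in this hunk. For reference, the full set of e2m1 values and the midpoint computation would look like the sketch below; the 16 values are reconstructed from the f4e2m1 format definition (1 sign bit, 2 exponent bits, 1 mantissa bit), not copied from the file.

```python
# Reconstruction of the full f4e2m1 value grid; the diff only shows -6.0 and -4.0.
import numpy as np

F4E2M1_QUANTILES = np.array(
    [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, -0.0, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
    dtype=np.float32,
)

# Midpoints between adjacent quantiles, used as decision thresholds when
# assigning a normalized weight to its nearest representable value.
CENTER_OF_F4E2M1_QUANTILES = (F4E2M1_QUANTILES[1:] + F4E2M1_QUANTILES[:-1]) / 2
print(CENTER_OF_F4E2M1_QUANTILES)  # -5.0, -3.5, -2.5, -1.75, ..., 1.75, 2.5, 3.5, 5.0
```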

src/nncf/quantization/algorithms/weight_compression/weight_lowering.py

Lines changed: 26 additions & 14 deletions

@@ -19,9 +19,9 @@
 from nncf.errors import UnsupportedModelError
 from nncf.parameters import CompressWeightsMode
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
-from nncf.quantization.algorithms.weight_compression.constants import CENTER_OF_MXFP4_QUANTILES
+from nncf.quantization.algorithms.weight_compression.constants import CENTER_OF_F4E2M1_QUANTILES
 from nncf.quantization.algorithms.weight_compression.constants import CENTER_OF_NF4_QUANTILES
-from nncf.quantization.algorithms.weight_compression.constants import MXFP4_QUANTILES
+from nncf.quantization.algorithms.weight_compression.constants import F4E2M1_QUANTILES
 from nncf.quantization.algorithms.weight_compression.constants import NF4_QUANTILES
 from nncf.quantization.algorithms.weight_compression.parameters import CompressedWeight
 from nncf.quantization.fake_quantize import calculate_scale_zero_point
@@ -32,6 +32,16 @@

 ReductionAxes = Union[int, tuple[int, ...]]

+
+OPTIMIZED_COMPRESSION_COMPATIBLE_MODES = (
+    CompressWeightsMode.INT8_ASYM,
+    CompressWeightsMode.INT8_SYM,
+    CompressWeightsMode.INT4_ASYM,
+    CompressWeightsMode.INT4_SYM,
+    CompressWeightsMode.NF4,
+    CompressWeightsMode.MXFP4,
+    CompressWeightsMode.FP4,
+)
 MIN_INPUT_SIZE_FOR_OPTIMIZED_COMPRESSION = 10000


@@ -168,7 +178,7 @@ def do_float_quantization(
         weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, config.group_size)

     # Optimized implementation
-    if config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4] and _can_run_optimized(weight):
+    if _can_run_optimized(weight, config.mode):
         from nncf.openvino.optimized_functions import do_float_quantization as do_float_quantization_ov

         return do_float_quantization_ov(weight, config, reduction_axes, precomputed_scale)
@@ -183,7 +193,7 @@ def do_float_quantization(
     if scale is None:
         scale = calculate_float_quantization_params(weight, reduction_axes, config)
     norm_weight = _calculate_normalized_weight(weight, scale)
-    if config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4]:
+    if config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]:
         if original_weight_backend == TensorBackend.ov:
             # Can convert through OpenVINO and return OpenVINO-native nf4/f4e2m1 tensor
             target_dtype = TensorDataType.nf4 if config.mode == CompressWeightsMode.NF4 else TensorDataType.f4e2m1
@@ -209,7 +219,7 @@ def float_quantize_dequantize_weight(
 ) -> Union[Tensor, tuple[Tensor, Tensor, Tensor]]:
     """
     First quantizes the given weight tensor to float dtype and then dequantizes it back to obtain float32 values.
-    MXFP8_E4M3, FP8_E4M3 and FP4 modes currently are not supported.
+    MXFP8_E4M3 and FP8_E4M3 modes currently are not supported.

     :param weight: The weight tensor to quantize-dequantize.
     :param config: Compression configuration.
@@ -221,12 +231,13 @@ def float_quantize_dequantize_weight(
     assert config.mode in [
         CompressWeightsMode.NF4,
         CompressWeightsMode.MXFP4,
+        CompressWeightsMode.FP4,
         CompressWeightsMode.CODEBOOK,
         CompressWeightsMode.CB4_F8E4M3,
     ]

     # Optimized implementation
-    if config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4] and _can_run_optimized(weight):
+    if _can_run_optimized(weight, config.mode):
         from nncf.openvino.optimized_functions import (
             float_quantize_dequantize_weight as float_quantize_dequantize_weight_ov,
         )
@@ -302,7 +313,7 @@ def get_integer_quantization_error(
     :return: The quantity characterizing the error of integer quantization.
     """
     # Optimized implementation
-    if _can_run_optimized(weight):
+    if _can_run_optimized(weight, config.mode):
         from nncf.openvino.optimized_functions import (
             get_integer_quantization_error as get_integer_quantization_error_ov,
         )
@@ -439,7 +450,7 @@ def do_integer_quantization(
         weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, config.group_size)

     # Optimized implementation
-    if _can_run_optimized(weight):
+    if _can_run_optimized(weight, config.mode):
         from nncf.openvino.optimized_functions import do_integer_quantization as do_integer_quantization_ov

         return do_integer_quantization_ov(weight, config, reduction_axes, precomputed_scale, precomputed_zero_point)
@@ -488,7 +499,7 @@ def integer_quantize_dequantize_weight(
         (and zero point).
     """
     # Optimized implementation
-    if _can_run_optimized(weight):
+    if _can_run_optimized(weight, config.mode):
         from nncf.openvino.optimized_functions import (
             integer_quantize_dequantize_weight as integer_quantize_dequantize_weight_ov,
         )
@@ -520,14 +531,14 @@ def _calculate_float_quantized_weight(norm_weight: Tensor, mode: CompressWeights
     :param norm_weight: Normalized weight tensor to quantize.
     :return: Tensor with floating-point values, where each of them corresponds to 1 out of 16 quants.
     """
-    assert mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4]
-    quantiles_np = NF4_QUANTILES if mode == CompressWeightsMode.NF4 else MXFP4_QUANTILES
-    quantile_centers_np = CENTER_OF_NF4_QUANTILES if mode == CompressWeightsMode.NF4 else CENTER_OF_MXFP4_QUANTILES
+    assert mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]
+    quantiles_np = NF4_QUANTILES if mode == CompressWeightsMode.NF4 else F4E2M1_QUANTILES
+    quantile_centers_np = CENTER_OF_NF4_QUANTILES if mode == CompressWeightsMode.NF4 else CENTER_OF_F4E2M1_QUANTILES
     quantile_centers = fns.from_numpy(quantile_centers_np, backend=norm_weight.backend)
     indexes = fns.searchsorted(quantile_centers, norm_weight)
     quantiles = fns.from_numpy(quantiles_np, backend=indexes.backend)

-    if mode == CompressWeightsMode.MXFP4:
+    if mode in [CompressWeightsMode.MXFP4, CompressWeightsMode.FP4]:
         # If in-between two quantiles, round to the nearest even quantile.
         shifted_indexes = fns.clip(indexes + 1, 0, quantiles.size - 1)
         dist_left = fns.abs(norm_weight - quantiles[indexes])
@@ -639,11 +650,12 @@ def _calculate_integer_quantized_weight(
     return compressed_weights


-def _can_run_optimized(inp: Tensor) -> bool:
+def _can_run_optimized(inp: Tensor, mode: CompressWeightsMode) -> bool:
     if (
         inp.backend in [TensorBackend.ov, TensorBackend.numpy]
         and inp.size >= MIN_INPUT_SIZE_FOR_OPTIMIZED_COMPRESSION
         and os.environ.get("NNCF_DISABLE_OPTIMIZED_COMPRESSION") is None
+        and mode in OPTIMIZED_COMPRESSION_COMPATIBLE_MODES
     ):
         if is_openvino_available():
             from nncf.openvino.cpu_info import is_arm_cpu
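The main behavioural change in this file is that `_can_run_optimized` now receives the compression mode and checks it against an explicit whitelist, instead of each caller filtering modes by hand. A standalone sketch of that gating logic is below; the backend, size threshold, and environment variable names come from the diff, while `openvino_usable()` is a stub standing in for the availability and ARM-CPU checks of the real function, and the mode names are plain strings instead of `CompressWeightsMode` members.

```python
# Standalone sketch of the eligibility check for the optimized OpenVINO compression path.
import os

OPTIMIZED_COMPRESSION_COMPATIBLE_MODES = (
    "INT8_ASYM", "INT8_SYM", "INT4_ASYM", "INT4_SYM", "NF4", "MXFP4", "FP4",
)
MIN_INPUT_SIZE_FOR_OPTIMIZED_COMPRESSION = 10000


def openvino_usable() -> bool:
    # Placeholder for the is_openvino_available() / is_arm_cpu() checks in the real code.
    return True


def can_run_optimized(inp_size: int, inp_backend: str, mode: str) -> bool:
    return (
        inp_backend in ("ov", "numpy")
        and inp_size >= MIN_INPUT_SIZE_FOR_OPTIMIZED_COMPRESSION
        and os.environ.get("NNCF_DISABLE_OPTIMIZED_COMPRESSION") is None
        and mode in OPTIMIZED_COMPRESSION_COMPATIBLE_MODES
        and openvino_usable()
    )


# Expected results assuming NNCF_DISABLE_OPTIMIZED_COMPRESSION is unset:
print(can_run_optimized(4096 * 4096, "numpy", "FP4"))       # True  (large tensor, supported mode)
print(can_run_optimized(4096 * 4096, "numpy", "CODEBOOK"))  # False (mode not in whitelist)
print(can_run_optimized(100, "numpy", "FP4"))               # False (too small to benefit)
```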

tests/openvino/native/quantization/test_weights_compression.py

Lines changed: 33 additions & 1 deletion

@@ -1408,7 +1408,39 @@ def test_int_compressed_weighs_range(mode, data):
         "neg": [-8.0, -8.0, -6.0, -4.0, -4.0, -3.0, -2.0, -1.0, -0.0],
         "pos": [-0.0, 1.0, 2.0, 3.0, 4.0, 4.0, 6.0, 8.0, 8.0],
         "neg-pos": [-8.0, -8.0, -6.0, -4.0, -4.0, -3.0, -2.0, -1.0, -0.0, 1.0, 2.0, 3.0, 4.0, 4.0, 6.0, 8.0],
-    }
+    },
+    CompressWeightsMode.FP4: {
+        "neg": [
+            -8.0,
+            -8.0,
+            -5.333333492279053,
+            -5.333333492279053,
+            -4.0,
+            -2.6666667461395264,
+            -2.0,
+            -1.3333333730697632,
+            -0.0,
+        ],
+        "pos": [-0.0, 1.3333333730697632, 2.0, 2.6666667461395264, 4.0, 5.333333492279053, 5.333333492279053, 8.0, 8.0],
+        "neg-pos": [
+            -8.0,
+            -8.0,
+            -5.333333492279053,
+            -5.333333492279053,
+            -4.0,
+            -2.6666667461395264,
+            -2.0,
+            -1.3333333730697632,
+            -0.0,
+            1.3333333730697632,
+            2.0,
+            2.6666667461395264,
+            4.0,
+            5.333333492279053,
+            5.333333492279053,
+            8.0,
+        ],
+    },
 }
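The FP4 reference values added above differ from the MXFP4 ones only because the scale stays an exact float instead of being rounded to a power of two. A quick worked check of the "pos" row, assuming the test input's per-row maximum is 8.0 (an assumption; the actual test data is not shown in this hunk):

```python
# Worked check of the FP4 "pos" reference row, assuming the row maximum is 8.0.
import numpy as np

grid = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)
scale = np.float32(8.0) / 6.0  # FP4 keeps the exact float scale (no power-of-two rounding)
print(grid * scale)
# prints roughly [0.  0.6666667  1.3333334  2.  2.6666667  4.  5.3333335  8.],
# i.e. the 1.333..., 2.666..., 5.333... entries in the "pos"/"neg-pos" lists above,
# whereas MXFP4 rounds the scale up to 2.0 and produces the 1.0 / 3.0 / 6.0 values instead.
```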
tests/openvino/optimized_functions/test_compression_functions.py

Lines changed: 11 additions & 7 deletions

@@ -68,7 +68,9 @@ class QuantizationTask(Enum):


 FP4_COMPRESSION_CONFIGS = [
     WeightCompressionConfig(CompressWeightsMode.NF4),
+    WeightCompressionConfig(CompressWeightsMode.FP4),
     WeightCompressionConfig(CompressWeightsMode.NF4, group_size=2),
+    WeightCompressionConfig(CompressWeightsMode.FP4, group_size=2),
     WeightCompressionConfig(CompressWeightsMode.MXFP4, group_size=32),
 ]

@@ -360,14 +362,16 @@ def get_input_node_data(node: ov.Node, input_id: int) -> Tensor:
         or compression_kwargs.get("lora_correction")
     )

-    if config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM, CompressWeightsMode.MXFP4]:
-        if config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM] and weight_dtype in [
-            TensorDataType.f8e4m3,
-            TensorDataType.f8e5m2,
-        ]:
+    if is_data_aware and config.mode in [
+        CompressWeightsMode.INT8_ASYM,
+        CompressWeightsMode.INT8_SYM,
+        CompressWeightsMode.MXFP4,
+        CompressWeightsMode.FP4,
+    ]:
+        pytest.skip("Data-aware compression is not supported for INT8, MXFP4, FP4 modes.")
+    if config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]:
+        if weight_dtype in [TensorDataType.f8e4m3, TensorDataType.f8e5m2]:
             pytest.skip("INT8 compression is not supported for f8 dtypes.")
-        if is_data_aware:
-            pytest.skip("Data-aware compression is not supported for INT8 or MXFP4 modes.")
     else:
         compression_kwargs["all_layers"] = True