Skip to content

Commit c9ff040

Browse files
Propagate OV*QuantizationConfig kwargs to nncf calls (#1179)
* Initial commit
* Add docs
* Remove explicit init_kwargs
* Update docs/source/openvino/optimization.mdx
  Co-authored-by: Alexander Kozlov <[email protected]>
* init_kwargs -> kwargs
* Fix test after merging
---------
Co-authored-by: Alexander Kozlov <[email protected]>
1 parent 727b6ce commit c9ff040

File tree

4 files changed

+171
-31
lines changed

4 files changed

+171
-31
lines changed

docs/source/openvino/optimization.mdx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ quantization_config = OVWeightQuantizationConfig(
8787
)
8888
```
8989

90+
Note: `OVWeightQuantizationConfig` also accepts keyword arguments that are not listed in its constructor. In this case such arguments will be passed directly to `nncf.compress_weights()` call. This is useful for passing additional parameters to the quantization algorithm.
91+
9092
By default the quantization scheme will be [asymmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#asymmetric-quantization), to make it [symmetric](https://github.com/openvinotoolkit/nncf/blob/develop/docs/usage/training_time_compression/other_algorithms/LegacyQuantization.md#symmetric-quantization) you can add `sym=True`.
9193

9294
For 4-bit quantization you can also specify the following arguments in the quantization configuration :

optimum/intel/openvino/configuration.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def __init__(
294294
dataset: Optional[Union[str, List[str]]] = None,
295295
tokenizer: Optional[str] = None,
296296
processor: Optional[str] = None,
297-
trust_remote_code: bool = False,
297+
trust_remote_code: Optional[bool] = False,
298298
**kwargs,
299299
):
300300
"""
@@ -323,6 +323,7 @@ def __init__(
323323
if isinstance(ignored_scope, nncf.IgnoredScope):
324324
ignored_scope = ignored_scope.__dict__
325325
self.ignored_scope = ignored_scope
326+
self.kwargs = kwargs
326327

327328
def post_init(self):
328329
try:
@@ -342,6 +343,12 @@ def get_ignored_scope_instance(self) -> "nncf.IgnoredScope":
342343
def clone(self):
343344
return copy.deepcopy(self)
344345

346+
def to_dict(self) -> Dict[str, Any]:
347+
# Unpack kwargs dict
348+
result = super().to_dict()
349+
result = result | result.pop("kwargs", {})
350+
return result
351+
345352

346353
@dataclass
347354
class OVWeightQuantizationConfig(OVQuantizationConfigBase):
@@ -427,6 +434,7 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
427434
retained in their original precision without any quantization.
428435
- "int8_sym" stands for 8-bit integer symmetric quantization without zero point.
429436
- "int8_asym" stands for 8-bit integer asymmetric quantization with zero points per each quantization group.
437+
kwargs: Additional parameters for nncf.compress_weights() call.
430438
"""
431439

432440
def __init__(
@@ -451,13 +459,21 @@ def __init__(
451459
backup_precision: Optional[str] = None,
452460
**kwargs,
453461
):
462+
weight_format = kwargs.pop("weight_format", None)
463+
if weight_format is not None:
464+
logger.warning(
465+
"The `weight_format` parameter is deprecated and will be removed in optimum-intel v1.24.0. "
466+
"Please use `dtype` instead."
467+
)
468+
dtype = weight_format
454469
super().__init__(
455470
ignored_scope=ignored_scope,
456471
num_samples=num_samples,
457472
dataset=dataset,
458473
tokenizer=tokenizer,
459474
processor=processor,
460475
trust_remote_code=trust_remote_code,
476+
**kwargs,
461477
)
462478
self.bits = bits
463479
self.sym = sym
@@ -470,12 +486,6 @@ def __init__(
470486
self.gptq = gptq
471487
self.lora_correction = lora_correction
472488
self.backup_precision = backup_precision
473-
if kwargs.get("weight_format") is not None:
474-
logger.warning(
475-
"The `weight_format` parameter is deprecated and will be removed in optimum-intel v1.24.0. "
476-
"Please use `dtype` instead."
477-
)
478-
dtype = kwargs.get("weight_format")
479489
self.dtype = dtype
480490
self.post_init()
481491

@@ -624,6 +634,7 @@ def to_nncf_dict(self) -> Dict[str, Any]:
624634
"gptq": self.gptq,
625635
"lora_correction": self.lora_correction,
626636
"backup_mode": backup_mode,
637+
**self.kwargs,
627638
}
628639
return result
629640

@@ -712,27 +723,30 @@ def __init__(
712723
reduces quantization error.
713724
dtype (`str`, defaults to "int8"):
714725
Data type activations are compressed to. Possible values: ['int8', 'f8e4m3', 'f8e5m2'].
726+
kwargs: Additional parameters for nncf.quantize() call.
715727
"""
728+
activation_format = kwargs.pop("activation_format", None)
729+
if activation_format is not None:
730+
logger.warning(
731+
"The `activation_format` parameter is deprecated and will be removed in optimum-intel v1.24.0. "
732+
"Please use `dtype` instead."
733+
)
734+
dtype = activation_format
716735
super().__init__(
717736
ignored_scope=ignored_scope,
718737
num_samples=num_samples,
719738
dataset=dataset,
720739
tokenizer=tokenizer,
721740
processor=processor,
722741
trust_remote_code=trust_remote_code,
742+
**kwargs,
723743
)
724744
self.bits = bits
725745
self.sym = sym
726746
self.model_type = model_type
727747
self.fast_bias_correction = fast_bias_correction
728748
self.overflow_fix = overflow_fix
729749
self.smooth_quant_alpha = smooth_quant_alpha
730-
if kwargs.get("activation_format") is not None:
731-
logger.warning(
732-
"The `activation_format` parameter is deprecated and will be removed in optimum-intel v1.24.0. "
733-
"Please use `dtype` instead."
734-
)
735-
dtype = kwargs.get("activation_format")
736750
self.dtype = dtype
737751

738752
f8_dtypes = ["f8e4m3", "f8e5m2"]
@@ -769,23 +783,19 @@ def to_nncf_dict(self) -> Dict[str, Any]:
769783
Returns a dictionary with the variables that are ready to use for nncf.compress_weights() call.
770784
"""
771785

772-
preset = "performance" if self.sym else "mixed"
773-
advanced_parameters_dict = {"overflow_fix": self.overflow_fix}
786+
# Merge advanced parameters from kwargs if they were provided
787+
kwargs_copy = copy.deepcopy(self.kwargs)
788+
advanced_parameters = kwargs_copy.pop("advanced_parameters", nncf.AdvancedQuantizationParameters())
789+
advanced_parameters.overflow_fix = nncf.OverflowFix(self.overflow_fix)
774790
if self.smooth_quant_alpha:
775-
advanced_parameters_dict["smooth_quant_alphas"] = {"matmul": self.smooth_quant_alpha}
791+
advanced_parameters.smooth_quant_alphas.matmul = self.smooth_quant_alpha
776792

777793
mode_map = {"f8e4m3": "fp8_e4m3", "f8e5m2": "fp8_e5m2"}
778794
mode = mode_map.get(self.dtype)
779795

796+
preset = "performance" if self.sym else "mixed"
780797
preset = nncf.QuantizationPreset(preset)
781798
model_type = nncf.ModelType(self.model_type)
782-
advanced_parameters = nncf.AdvancedQuantizationParameters(
783-
overflow_fix=advanced_parameters_dict["overflow_fix"],
784-
)
785-
if "smooth_quant_alphas" in advanced_parameters_dict:
786-
advanced_parameters.smooth_quant_alphas = nncf.AdvancedSmoothQuantParameters(
787-
**advanced_parameters_dict["smooth_quant_alphas"]
788-
)
789799

790800
return {
791801
"mode": mode,
@@ -795,6 +805,7 @@ def to_nncf_dict(self) -> Dict[str, Any]:
795805
"model_type": model_type,
796806
"ignored_scope": self.get_ignored_scope_instance(),
797807
"advanced_parameters": advanced_parameters,
808+
**kwargs_copy,
798809
}
799810

800811

@@ -965,6 +976,7 @@ def __init__(
965976
tokenizer=tokenizer,
966977
processor=processor,
967978
trust_remote_code=trust_remote_code,
979+
**kwargs,
968980
)
969981

970982
self.post_init()

optimum/intel/openvino/quantization.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,8 +1021,18 @@ def _weight_only_quantization(
10211021
else:
10221022
dataset = nncf.Dataset(calibration_dataset)
10231023

1024-
wc_kwargs = copy.deepcopy(kwargs)
1025-
wc_kwargs.update(config.to_nncf_dict())
1024+
wc_kwargs = config.to_nncf_dict()
1025+
1026+
# Arguments provided in kwargs override the ones from the config
1027+
kwargs_intersection = set(wc_kwargs.keys()) & set(kwargs.keys())
1028+
if kwargs_intersection:
1029+
logger.warning(
1030+
f"The following nncf.compress_weights() arguments from the OVWeightQuantizationConfig will be overridden "
1031+
f"by the ones given in _weight_only_quantization call kwargs: {kwargs_intersection}."
1032+
)
1033+
wc_kwargs.update(kwargs)
1034+
wc_kwargs.pop("weight_only", None)
1035+
10261036
compressed_model = nncf.compress_weights(
10271037
model,
10281038
dataset=dataset,
@@ -1048,8 +1058,19 @@ def _full_quantization(
10481058

10491059
if verify_not_optimized:
10501060
_verify_not_optimized(model)
1051-
q_kwargs = copy.deepcopy(kwargs)
1052-
q_kwargs.update(quantization_config.to_nncf_dict())
1061+
1062+
q_kwargs = quantization_config.to_nncf_dict()
1063+
1064+
# Arguments provided in kwargs override the ones from the config
1065+
kwargs_intersection = set(q_kwargs.keys()) & set(kwargs.keys())
1066+
if kwargs_intersection:
1067+
logger.warning(
1068+
f"The following nncf.quantize() arguments from the OVQuantizationConfig will be overridden "
1069+
f"by the ones given in _full_quantization call kwargs: {kwargs_intersection}."
1070+
)
1071+
q_kwargs.update(kwargs)
1072+
q_kwargs.pop("weight_only", None)
1073+
10531074
quantized_model = nncf.quantize(model, calibration_dataset=calibration_dataset, **q_kwargs)
10541075

10551076
_remove_f16_kv_cache_precision_flag(quantized_model)

tests/openvino/test_quantization.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,18 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import dataclasses
1415
import inspect
1516

1617
# ruff: noqa
1718

1819
import itertools
1920
import logging
2021
import unittest
21-
from collections import defaultdict
22+
from collections import defaultdict, Iterable
2223
from enum import Enum
2324
from functools import partial
24-
from typing import Union
25+
from typing import Union, Type
2526

2627
import openvino as ov
2728
import pytest
@@ -77,7 +78,7 @@
7778
from optimum.intel.openvino.utils import TemporaryDirectory
7879
from copy import deepcopy
7980

80-
from optimum.intel.openvino.quantization import InferRequestWrapper
81+
from optimum.intel.openvino.quantization import InferRequestWrapper, _weight_only_quantization, _full_quantization
8182
from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version
8283
from utils_tests import (
8384
MODEL_NAMES,
@@ -1211,7 +1212,6 @@ class OVQuantizationConfigTest(unittest.TestCase):
12111212
),
12121213
),
12131214
(OVQuantizationConfig(ignored_scope=nncf.IgnoredScope(names=["op_name"])),),
1214-
(OVDynamicQuantizationConfig(bits=8, sym=True),),
12151215
)
12161216

12171217
QUANTIZATION_CONFIG_DICTS = (
@@ -1276,6 +1276,60 @@ class OVQuantizationConfigTest(unittest.TestCase):
12761276
(dict(bits=8, fast_bias_correction=True, weight_only=False), OVQuantizationConfig, None),
12771277
)
12781278

1279+
QUANTIZATION_CONFIGS_WITH_KWARGS = (
1280+
(
1281+
OVWeightQuantizationConfig,
1282+
{
1283+
"advanced_parameters": nncf.AdvancedCompressionParameters(statistics_path="statistics_path"),
1284+
"some_arg": "some_value",
1285+
},
1286+
{
1287+
"advanced_parameters": nncf.AdvancedCompressionParameters(statistics_path="statistics_path"),
1288+
"some_arg": "some_value",
1289+
},
1290+
),
1291+
(
1292+
OVQuantizationConfig,
1293+
{
1294+
"advanced_parameters": nncf.AdvancedQuantizationParameters(disable_channel_alignment=True),
1295+
"some_arg": "some_value",
1296+
},
1297+
{
1298+
"advanced_parameters": nncf.AdvancedQuantizationParameters(
1299+
overflow_fix=nncf.OverflowFix.DISABLE,
1300+
disable_channel_alignment=True,
1301+
),
1302+
"some_arg": "some_value",
1303+
},
1304+
),
1305+
(
1306+
OVQuantizationConfig,
1307+
{
1308+
"advanced_parameters": nncf.AdvancedQuantizationParameters(overflow_fix=nncf.OverflowFix.ENABLE),
1309+
},
1310+
{
1311+
"advanced_parameters": nncf.AdvancedQuantizationParameters(
1312+
overflow_fix=nncf.OverflowFix.DISABLE,
1313+
),
1314+
},
1315+
),
1316+
(
1317+
OVQuantizationConfig,
1318+
{
1319+
"smooth_quant_alpha": 0.5,
1320+
"advanced_parameters": nncf.AdvancedQuantizationParameters(
1321+
smooth_quant_alphas=nncf.AdvancedSmoothQuantParameters(matmul=0.7, convolution=0.7),
1322+
),
1323+
},
1324+
{
1325+
"advanced_parameters": nncf.AdvancedQuantizationParameters(
1326+
overflow_fix=nncf.OverflowFix.DISABLE,
1327+
smooth_quant_alphas=nncf.AdvancedSmoothQuantParameters(matmul=0.5, convolution=0.7),
1328+
),
1329+
},
1330+
),
1331+
)
1332+
12791333
def get_default_configurations() -> dict:
12801334
default_configurations = deepcopy(_DEFAULT_4BIT_CONFIGS)
12811335
default_configurations.update({"default": _DEFAULT_4BIT_CONFIG})
@@ -1327,6 +1381,57 @@ def test_for_no_short_id_duplicates(self):
13271381
assert short_id not in short_ids
13281382
short_ids.add(short_id)
13291383

1384+
@parameterized.expand(QUANTIZATION_CONFIGS_WITH_KWARGS)
1385+
def test_config_init_kwargs(
1386+
self,
1387+
config_type: Type[Union[OVWeightQuantizationConfig, OVQuantizationConfig]],
1388+
config_kwargs: dict,
1389+
ref_nncf_dict: dict,
1390+
):
1391+
nncf_dict = config_type(**config_kwargs).to_nncf_dict()
1392+
ref_nncf_dict = config_type().to_nncf_dict() | ref_nncf_dict
1393+
self.assertTrue(self.compare_objects(nncf_dict, ref_nncf_dict))
1394+
1395+
@parameterized.expand(
1396+
[
1397+
("nncf.compress_weights", "_weight_only_quantization", "dataset", OVWeightQuantizationConfig),
1398+
("nncf.quantize", "_full_quantization", "calibration_dataset", OVQuantizationConfig),
1399+
]
1400+
)
1401+
def test_quantization_kwargs_override(self, mock_method_name, quantization_function, dataset_key, config_type):
1402+
with unittest.mock.patch(mock_method_name) as mock_method:
1403+
mock_model = unittest.mock.Mock([])
1404+
mock_model.get_rt_info = unittest.mock.Mock(return_value={})
1405+
1406+
mock_quantization_config = unittest.mock.Mock(config_type)
1407+
mock_quantization_config.to_nncf_dict.return_value = {"param1": "value1", "param2": "value2"}
1408+
1409+
additional_kwargs = {"param2": "new_value2", "param3": "value3"}
1410+
1411+
quantization_function = globals()[quantization_function]
1412+
quantization_function(mock_model, mock_quantization_config, None, **additional_kwargs)
1413+
1414+
expected_kwargs = {"param1": "value1", "param2": "new_value2", "param3": "value3", dataset_key: None}
1415+
1416+
mock_method.assert_called_once_with(mock_model, **expected_kwargs)
1417+
1418+
@staticmethod
1419+
def compare_objects(o1, o2) -> bool:
1420+
if dataclasses.is_dataclass(o1) and dataclasses.is_dataclass(o2):
1421+
o1 = o1.__dict__
1422+
o2 = o2.__dict__
1423+
if isinstance(o1, dict) and isinstance(o2, dict):
1424+
for k in set(o1.keys()) | set(o2.keys()):
1425+
if not OVQuantizationConfigTest.compare_objects(o1[k], o2[k]):
1426+
return False
1427+
return True
1428+
if isinstance(o1, Iterable) and isinstance(o2, Iterable) and not (isinstance(o1, str) or isinstance(o2, str)):
1429+
for it1, it2 in zip(o1, o2):
1430+
if not OVQuantizationConfigTest.compare_objects(it1, it2):
1431+
return False
1432+
return True
1433+
return o1 == o2
1434+
13301435

13311436
class InferRequestWrapperTest(unittest.TestCase):
13321437
MODEL_NAME = ("whisper",)

0 commit comments

Comments (0)