Skip to content

Commit 766d41f

Browse files
committed
Add tests for positional weights quantization config suuport
1 parent 8952be0 commit 766d41f

File tree

2 files changed

+2
-5
lines changed

2 files changed

+2
-5
lines changed

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def __init__(self, qc: QuantizationConfig,
398398
if isinstance(attr, int):
399399
# this is a positional attribute, so it needs to be handled separately.
400400
# Search for any keys in the op config's attribute weight config mapping that contain the
401-
# POS_ATTR string.If none are found, it indicates that no specific quantization config is defined for
401+
# POS_ATTR string. If none are found, it indicates that no specific quantization config is defined for
402402
# positional weights, so the default config will be used instead.
403403
attrs_included_in_name = {k: v for k, v in op_cfg.attr_weights_configs_mapping.items() if
404404
POS_ATTR in k}

tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,7 @@ def _create_node_weights_op_cfg(
6262
Returns:
6363
OpQuantizationConfig: Class to configure the quantization parameters of an operator.
6464
"""
65-
attr_weights_configs_mapping = {}
66-
67-
for attr, attr_config in zip(pos_weight_attr, pos_weight_attr_config):
68-
attr_weights_configs_mapping[attr] = attr_config
65+
attr_weights_configs_mapping = dict(zip(pos_weight_attr, pos_weight_attr_config))
6966

7067
op_cfg = OpQuantizationConfig(
7168
default_weight_attr_config=def_weight_attr_config,

0 commit comments

Comments
 (0)