Skip to content

Commit cef47e1

Browse files
irenabirenab
authored andcommitted
update setting node configs attributes
1 parent d2b40fe commit cef47e1

9 files changed

Lines changed: 304 additions & 220 deletions

File tree

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 85 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,13 @@ class BaseNodeQuantizationConfig(object):
4646
Base class for node quantization configuration
4747
"""
4848

49-
def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any,
50-
*args: List[Any], **kwargs: Dict[str, Any]):
49+
def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any):
5150
"""
5251
Changes a BaseNodeQuantizationConfig's parameter.
53-
Note that arg and kwargs are only to allow clean override in the child classes.
5452
5553
Args:
5654
config_parameter_name: parameter name to change.
5755
config_parameter_value: parameter value to change.
58-
args: A list of additional arguments.
59-
kwargs: A dictionary with additional key arguments.
60-
6156
"""
6257
if hasattr(self, config_parameter_name):
6358
setattr(self, config_parameter_name, config_parameter_value)
@@ -77,6 +72,12 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
7772
"""
7873
Attributes for configuring the quantization of the activations of a node.
7974
"""
75+
_no_cfg_modes = [
76+
ActivationQuantizationMode.NO_QUANT,
77+
ActivationQuantizationMode.FLN_NO_QUANT,
78+
ActivationQuantizationMode.PRESERVE_QUANT
79+
]
80+
8081
def __init__(self, op_cfg: OpQuantizationConfig):
8182
"""
8283
@@ -104,18 +105,20 @@ def __init__(self, op_cfg: OpQuantizationConfig):
104105
# Z-threshold is a global param from QuantizationConfig, however it can be overridden per node by NetworkEditor.
105106
# Since activation qparams are re-computed in several places, it's easier to keep it here and update it once.
106107
self.z_threshold = None
107-
# Z-threshold is a global param from QuantizationConfig, however it can be overridden per node by NetworkEditor.
108-
# Since activation qparams are re-computed in several places, it's easier to keep it here and update it once.
109-
self.z_threshold = None
110108

111109
def set_quant_mode(self, quant_mode: ActivationQuantizationMode):
112-
""" Sets quantization mode. If quantization is off, resets config attributes to None. """
110+
"""
111+
Set quantization mode. If no configuration is associated with the quant_mode, it's un-set.
112+
113+
Args:
114+
quant_mode: quantization mode to set.
115+
"""
116+
if quant_mode in [ActivationQuantizationMode.QUANT, ActivationQuantizationMode.FLN_QUANT]:
117+
if self.quant_mode in self._no_cfg_modes:
118+
raise ValueError(f'Cannot change quant_mode to {quant_mode.name} from {self.quant_mode.name}.')
113119
self._quant_mode = quant_mode
114-
if quant_mode in [ActivationQuantizationMode.NO_QUANT, ActivationQuantizationMode.FLN_NO_QUANT,
115-
ActivationQuantizationMode.PRESERVE_QUANT]:
120+
if quant_mode in self._no_cfg_modes:
116121
self._unset()
117-
else:
118-
assert None not in [self.activation_quantization_method, self.activation_n_bits, self.signedness]
119122

120123
@property
121124
def quant_mode(self):
@@ -146,9 +149,28 @@ def set_activation_quantization_param(self,
146149
147150
"""
148151
assert self.quant_mode in [ActivationQuantizationMode.QUANT, ActivationQuantizationMode.FLN_QUANT]
149-
# TODO shouldn't the whole self.activation_quantization_params be reset instead of updated?
150-
for param_name, param_value in activation_params.items():
151-
self.activation_quantization_params[param_name] = param_value
152+
self.activation_quantization_params = activation_params
153+
154+
def set_quant_config_attr(self, attr_name: str, value: Any):
155+
"""
156+
Update config's attribute.
157+
158+
Args:
159+
attr_name: attribute to set.
160+
value: value to set.
161+
"""
162+
if attr_name == 'quant_mode':
163+
self.set_quant_mode(value)
164+
else:
165+
if self.quant_mode in self._no_cfg_modes:
166+
raise ValueError(f'Cannot set attribute {attr_name} for activation with disabled quantization.')
167+
super().set_quant_config_attr(attr_name, value)
168+
169+
def _unset(self):
170+
""" Unset activation quantization fields to None. """
171+
self.activation_quantization_method = None
172+
self.activation_n_bits = 0
173+
self.signedness = None
152174

153175
def __eq__(self, other: Any) -> bool:
154176
"""
@@ -225,9 +247,13 @@ def set_weights_quantization_param(self,
225247
226248
"""
227249
assert self.enable_weights_quantization
228-
# TODO shouldn't the whole self.weights_quantization_params be reset instead of updated?
229-
for param_name, param_value in weights_params.items():
230-
self.weights_quantization_params[param_name] = param_value
250+
self.weights_quantization_params = weights_params
251+
252+
def _unset(self):
253+
self.weights_channels_axis = None
254+
self.weights_quantization_method = None
255+
self.weights_n_bits = 0
256+
self.weights_per_channel_threshold = None
231257

232258
def __eq__(self, other: Any) -> bool:
233259
"""
@@ -273,6 +299,8 @@ def __init__(self,
273299
node_attrs_list: A list of the node's weights attributes names.
274300
275301
"""
302+
# TODO it makes no sense that the same weights_channels_axis is going to all attrs
303+
276304
self.simd_size = op_cfg.simd_size
277305

278306
# Initialize a quantization configuration for each of the node's attributes
@@ -364,19 +392,22 @@ def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationCo
364392

365393
return attr_cfg
366394

367-
def set_attr_config(self, attr_name: 'WeightAttrT', attr_qc: WeightsAttrQuantizationConfig):
395+
def set_attr_config(self, attr_name: 'WeightAttrT', attr_qc: WeightsAttrQuantizationConfig, force=False):
368396
"""
369397
Adding a new attribute with quantization configuration to the node's weights configurations mapping.
370398
371399
Args:
372400
attr_name: The name of the attribute to set a quantization configuration to.
373401
attr_qc: The quantization configuration to set.
374-
402+
force: if True, the attribute is set without checking if it exists.
375403
"""
376-
if isinstance(attr_name, int):
404+
if attr_name in self.pos_attributes_config_mapping or (force and isinstance(attr_name, int)):
377405
self.pos_attributes_config_mapping[attr_name] = attr_qc
378-
else:
406+
elif attr_name in self.attributes_config_mapping or force:
407+
assert isinstance(attr_name, str)
379408
self.attributes_config_mapping[attr_name] = attr_qc
409+
else:
410+
raise ValueError(f'Unknown weights attr {attr_name}.')
380411

381412
def has_attribute_config(self, attr_name: 'WeightAttrT') -> bool:
382413
"""
@@ -389,13 +420,10 @@ def has_attribute_config(self, attr_name: 'WeightAttrT') -> bool:
389420
390421
"""
391422
if isinstance(attr_name, int):
392-
return self.pos_attributes_config_mapping.get(attr_name, False)
393-
else:
394-
saved_attr_name = self._extract_config_for_attributes_with_name(attr_name)
395-
if len(saved_attr_name) >= 1:
396-
return True
423+
return attr_name in self.pos_attributes_config_mapping
397424

398-
return False
425+
saved_attr_name = self._extract_config_for_attributes_with_name(attr_name)
426+
return len(saved_attr_name) >= 1
399427

400428
@property
401429
def all_weight_attrs(self) -> List['WeightAttrT']:
@@ -440,7 +468,7 @@ def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, Weigh
440468
return attrs_with_name
441469

442470
def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any,
443-
attr_name: 'WeightAttrT' = None, *args: List[Any], **kwargs: Dict[str, Any]):
471+
attr_name: 'WeightAttrT' = None):
444472
"""
445473
This method overrides the parent class set_quant_config_attr to enable setting a specific weights
446474
attribute config parameter.
@@ -449,25 +477,35 @@ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_val
449477
attr_name: attribute name to change.
450478
config_parameter_name: parameter name to change.
451479
config_parameter_value: parameter value to change.
452-
args: A list of additional arguments.
453-
kwargs: A dictionary with additional key arguments.
454-
455480
"""
456-
457481
if attr_name is None:
458482
super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(config_parameter_name,
459-
config_parameter_value,
460-
*args, **kwargs)
461-
else:
462-
if self.has_attribute_config(attr_name):
463-
attr_cfg = self.get_attr_config(attr_name)
464-
if hasattr(attr_cfg, config_parameter_name):
465-
setattr(attr_cfg, config_parameter_name, config_parameter_value)
466-
else:
467-
raise AttributeError(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
468-
f"weights attribute {attr_name}.")
469-
else: # pragma: no cover
470-
Logger.critical(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
483+
config_parameter_value)
484+
return
485+
486+
if not self.has_attribute_config(attr_name):
487+
raise ValueError(
488+
f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
489+
490+
attr_cfg = self.get_attr_config(attr_name)
491+
if config_parameter_name == 'enable_weights_quantization':
492+
if config_parameter_value is False:
493+
attr_cfg.disable_quantization()
494+
elif attr_cfg.enable_weights_quantization is False:
495+
raise ValueError(f'Cannot enable quantization for attr {attr_name} with disabled quantization.')
496+
return
497+
498+
if not hasattr(attr_cfg, config_parameter_name):
499+
raise AttributeError(
500+
f"Parameter {config_parameter_name} could not be found in the quantization config of "
501+
f"weights attribute {attr_name}.")
502+
503+
if attr_cfg.enable_weights_quantization is False:
504+
# TODO we can add an option to reset the whole attr config for a specific attr, but this whole
505+
# mechanism should be revised. Also attr cfg code should be moved to attr cfg.
506+
raise ValueError(f'Cannot set param {config_parameter_name} for attr {attr_name} with disabled quantization.')
507+
508+
setattr(attr_cfg, config_parameter_name, config_parameter_value)
471509

472510
def __eq__(self, other: Any) -> bool:
473511
"""

model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ def _apply_activation_bias_correction_to_node(node: BaseNode,
7171
node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
7272
WeightsAttrQuantizationConfig(
7373
AttributeQuantizationConfig(
74-
enable_weights_quantization=False)))
74+
enable_weights_quantization=False)),
75+
force=True)
7576
else:
7677
# If the layer has bias, we subtract the correction from original bias
7778
node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction)

model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,5 @@ def _apply_bias_correction_to_node(node: BaseNode,
6868
node.framework_attr[fw_impl.constants.USE_BIAS] = True # Mark the use_bias attribute of the node.
6969
node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
7070
WeightsAttrQuantizationConfig(AttributeQuantizationConfig(
71-
enable_weights_quantization=False)))
71+
enable_weights_quantization=False)),
72+
force=True)

model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def set_second_moment_correction(qc):
138138
qc.weights_quantization_cfg.set_attr_config(attr,
139139
WeightsAttrQuantizationConfig(
140140
AttributeQuantizationConfig(
141-
enable_weights_quantization=False)))
141+
enable_weights_quantization=False)),
142+
force=True)
142143

143144
# Check if the source node was part of a fusion. If so, there are two cases:
144145
# either this is no longer a fusion, and the fusion info should be updated by removing

model_compression_toolkit/core/common/substitutions/shift_negative_activation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def op2d_bias_correction(op2d_node: BaseNode,
6767
for qc in op2d_node.candidates_quantization_cfg:
6868
qc.weights_quantization_cfg.set_attr_config(bias_flag_str,
6969
WeightsAttrQuantizationConfig(AttributeQuantizationConfig(
70-
enable_weights_quantization=False)))
70+
enable_weights_quantization=False)),
71+
force=True)
7172

7273
# Each node adds a different noise due to the shifting. It depends on the
7374
# dimensions of the kernel, thus the correction term is a function of

tests/keras_tests/function_tests/test_node_quantization_configurations.py

Lines changed: 0 additions & 101 deletions
This file was deleted.

tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def build_node(self, name='node', q_mode=ActivationQuantizationMode.QUANT):
5858
node.is_fln_quantization.return_value = False
5959

6060
activation_quantization_cfg = NodeActivationQuantizationConfig(op_cfg=self.build_op_cfg())
61-
activation_quantization_cfg.quant_mode = q_mode
61+
activation_quantization_cfg.set_quant_mode(q_mode)
6262

6363
candidate_quantization_config = Mock(spec=CandidateNodeQuantizationConfig)
6464
candidate_quantization_config.activation_quantization_cfg = activation_quantization_cfg

0 commit comments

Comments
 (0)