diff --git a/model_compression_toolkit/core/common/graph/base_graph.py b/model_compression_toolkit/core/common/graph/base_graph.py index 17ee171e3..520394187 100644 --- a/model_compression_toolkit/core/common/graph/base_graph.py +++ b/model_compression_toolkit/core/common/graph/base_graph.py @@ -874,7 +874,7 @@ def override_fused_node_activation_quantization_candidates(self): if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization: def update(qc): qc.activation_quantization_cfg = NodeActivationQuantizationConfig(fusing_op_quantization_cfg) - qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT + qc.activation_quantization_cfg.set_quant_mode(ActivationQuantizationMode.FLN_QUANT) node.quantization_cfg.update_all(update, remove_duplicates=True) else: node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.FLN_NO_QUANT) diff --git a/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py b/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py index dbfc36ad4..244920e51 100644 --- a/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +++ b/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py @@ -80,8 +80,7 @@ def __init__(self, origin_node: BaseNode, kernel_attr: str): base_quantization_cfg=None, validate=False ) for c in self.quantization_cfg.candidates_quantization_cfg: - c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT - c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH + c.activation_quantization_cfg.set_quant_mode(ActivationQuantizationMode.NO_QUANT) class VirtualSplitActivationNode(VirtualSplitNode): diff --git a/model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py b/model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py index 
7765e9b06..5f8fe4fb4 100644 --- a/model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +++ b/model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py @@ -99,11 +99,9 @@ def _build_mp_model(self, graph, outputs, disable_activations: bool) -> Tuple[An # be added to the model). for n in evaluation_graph.get_topo_sorted_nodes(): if disable_activations or not n.has_configurable_activation(): - for c in n.candidates_quantization_cfg: - c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT + n.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.NO_QUANT) if not n.has_any_configurable_weight(): - for c in n.candidates_quantization_cfg: - c.weights_quantization_cfg.disable_all_weights_quantization() + n.quantization_cfg.disable_weights_quantization() model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph, mode=ModelBuilderMode.MIXEDPRECISION, diff --git a/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py index 929f8618f..67613200b 100644 --- a/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py @@ -46,6 +46,20 @@ class NodeQuantizationConfig: validate: InitVar[bool] = True + def __post_init__(self, validate=True): + if validate: + if not any(self.base_quantization_cfg == qc for qc in self.candidates_quantization_cfg): + raise ValueError('Candidates should contain the base config.') + self._validate_consistent_activation_quant_mode() + self._validate_consistent_weights_quant_mode() + + self.remove_duplicates() + + # TODO irena + # for now make sure they are separate objects so that one doesnt inadvertently modify the other + if any(self.base_quantization_cfg is 
qc for qc in self.candidates_quantization_cfg): + self.base_quantization_cfg = copy.deepcopy(self.base_quantization_cfg) + def update_all(self, update_fn: Callable[[CandidateNodeQuantizationConfig], None], remove_duplicates: bool = True): """ Apply update function on the base config and all candidates configs. @@ -69,7 +83,7 @@ def update_activation_quantization_mode(self, mode: ActivationQuantizationMode): mode: quantization mode. """ def fn(c): - c.activation_quantization_cfg.quant_mode = mode + c.activation_quantization_cfg.set_quant_mode(mode) self.update_all(fn) @@ -102,17 +116,6 @@ def remove_duplicates(self): uniq_qcs.append(qc) self.candidates_quantization_cfg = uniq_qcs - def __post_init__(self, validate=True): - if validate: - if not any(self.base_quantization_cfg == qc for qc in self.candidates_quantization_cfg): - raise ValueError('Candidates should contain the base config.') - self._validate_consistent_activation_quant_mode() - self._validate_consistent_weights_quant_mode() - # TODO irena - # for now make sure they are separate objects so that one doesnt inadvertently modify the other - if any(self.base_quantization_cfg is qc for qc in self.candidates_quantization_cfg): - self.base_quantization_cfg = copy.deepcopy(self.base_quantization_cfg) - def _validate_consistent_activation_quant_mode(self): """ Validate that base config and all candidates configs contain identical activation quantization mode. diff --git a/model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py b/model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py deleted file mode 100644 index 6bea53721..000000000 --- a/model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import copy -from typing import List - -from mct_quantizers import QuantizationMethod -from model_compression_toolkit.core.common import Graph, BaseNode -from model_compression_toolkit.constants import FLOAT_BITWIDTH -from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \ - CandidateNodeQuantizationConfig - -def filter_nodes_candidates(graph: Graph): - """ - Filters the graph's nodes candidates configuration list. - We apply this after mark activation operation to eliminate nodes that their activation are no longer being quantized - from the mixed-precision search. - Updating the lists is preformed inplace on the graph object. - - Args: - graph: Graph for which to add quantization info to each node. - """ - nodes = list(graph.nodes) - for n in nodes: - n.quantization_cfg.candidates_quantization_cfg = filter_node_candidates(node=n) - - return graph - - -def _filter_bit_method_dups(candidates: List[CandidateNodeQuantizationConfig], - kernel_attr: str = None) -> List[CandidateNodeQuantizationConfig]: - """ - Filters out duplications in candidates configuration list, based on similarity in - (weights_n_bits, weights_quantization_method, activation_n_bits, activation_quantization_method). - Weights quantization configuration considers only kernel attributes. - - Args: - candidates: A list of quantization configuration candidates. - kernel_attr: The name of the node's kernel attribute if such exists. 
- - Returns: A filtered list of quantization configuration candidates. - - """ - seen_bits_method_combinations = set() - final_candidates = [] - for c in candidates: - weight_n_bits = None if kernel_attr is None else ( - c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits) - weights_quantization_method = None if kernel_attr is None else ( - c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_quantization_method) - comb = (weight_n_bits, - weights_quantization_method, - c.activation_quantization_cfg.activation_n_bits, - c.activation_quantization_cfg.activation_quantization_method) - if comb not in seen_bits_method_combinations: - final_candidates.append(c) - seen_bits_method_combinations.add(comb) - - return final_candidates - - -def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConfig]: - """ - Updates a node's candidates configuration list. - If the node's weights quantization is disabled (or it only has activations to quantize), then the updated list - will have a candidate with any of the different original activation bitwidths candidates and a default value - for its weights bitwidth (that doesn't have any impact on the quantization or the mixed-precision search. - If the node's activation quantization is disabled, the same filtering applied for the weights bitwidth candidates. - - Args: - node: Node to set its quantization configurations. - - """ - - filtered_candidates = copy.deepcopy(node.candidates_quantization_cfg) - final_candidates = copy.deepcopy(node.candidates_quantization_cfg) - - if (node.kernel_attr is None or not node.is_weights_quantization_enabled(node.kernel_attr)) and node.is_no_quantization(): - # If activation quantization is disabled and the node doesn't have a kernel or doesn't quantize the kernel, - # but for some reason the node has multiple candidates then replace it with a single dummy candidate with - # default bit-width values. 
- single_dummy_candidate = filtered_candidates[0] - single_dummy_candidate.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH - single_dummy_candidate.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO - - if node.kernel_attr is not None: - kernel_config = single_dummy_candidate.weights_quantization_cfg.get_attr_config(node.kernel_attr) - kernel_config.weights_n_bits = FLOAT_BITWIDTH - kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO - - final_candidates = [single_dummy_candidate] - - elif node.is_no_quantization(): - # Remove candidates that have duplicated weights candidates for node with disabled activation quantization. - # Replacing the activation n_bits in the remained configurations with default value to prevent confusion. - # Set the config of the non-quantized FLN node to POWER_OF_TWO. - seen_candidates = set() - filtered_candidates = [candidate for candidate in filtered_candidates if - candidate.weights_quantization_cfg not in seen_candidates - and not seen_candidates.add(candidate.weights_quantization_cfg)] - - for c in filtered_candidates: - c.activation_quantization_cfg.activation_n_bits = FLOAT_BITWIDTH - c.activation_quantization_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO - - final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr) - - elif node.is_fln_no_quantization() or node.is_fln_quantization(): - # Remove candidates that have duplicated weights candidates for node with disabled activation quantization. 
- seen_candidates = set() - filtered_candidates = [candidate for candidate in filtered_candidates if - candidate.weights_quantization_cfg not in seen_candidates - and not seen_candidates.add(candidate.weights_quantization_cfg)] - final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr) - - elif node.kernel_attr is None or not node.is_weights_quantization_enabled(node.kernel_attr): - # TODO: - # To allow MP on positional weights we need to modify this to consider all weights not only kernel. - # Remove candidates that have duplicated activation candidates for node with disabled weights quantization. - # Replacing the weights n_bits in the remained configurations with default value to prevent confusion. - seen_candidates = set() - filtered_candidates = [candidate for candidate in filtered_candidates if - candidate.activation_quantization_cfg not in seen_candidates - and not seen_candidates.add(candidate.activation_quantization_cfg)] - - for c in filtered_candidates: - if node.kernel_attr is not None: - kernel_config = c.weights_quantization_cfg.get_attr_config(node.kernel_attr) - kernel_config.weights_n_bits = FLOAT_BITWIDTH - kernel_config.weights_quantization_method = QuantizationMethod.POWER_OF_TWO - - final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr) - - return final_candidates diff --git a/model_compression_toolkit/core/common/quantization/node_quantization_config.py b/model_compression_toolkit/core/common/quantization/node_quantization_config.py index d93eb5acb..3ed3a5d51 100644 --- a/model_compression_toolkit/core/common/quantization/node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/node_quantization_config.py @@ -46,18 +46,13 @@ class BaseNodeQuantizationConfig(object): Base class for node quantization configuration """ - def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any, - *args: List[Any], **kwargs: Dict[str, Any]): + def 
set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any): """ Changes a BaseNodeQuantizationConfig's parameter. - Note that arg and kwargs are only to allow clean override in the child classes. Args: config_parameter_name: parameter name to change. config_parameter_value: parameter value to change. - args: A list of additional arguments. - kwargs: A dictionary with additional key arguments. - """ if hasattr(self, config_parameter_name): setattr(self, config_parameter_name, config_parameter_value) @@ -77,6 +72,12 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig): """ Attributes for configuring the quantization of the activations of a node. """ + _no_cfg_modes = [ + ActivationQuantizationMode.NO_QUANT, + ActivationQuantizationMode.FLN_NO_QUANT, + ActivationQuantizationMode.PRESERVE_QUANT + ] + def __init__(self, op_cfg: OpQuantizationConfig): """ @@ -85,15 +86,18 @@ def __init__(self, op_cfg: OpQuantizationConfig): """ self.activation_quantization_method = op_cfg.activation_quantization_method self.activation_n_bits = op_cfg.activation_n_bits + self.signedness = op_cfg.signedness + if op_cfg.enable_activation_quantization and op_cfg.quantization_preserving: raise ValueError("An OpQuantizationConfig can't have both enable_activation_quantization and quantization_preserving enabled.") + + self._quant_mode = None if op_cfg.enable_activation_quantization: - self.quant_mode = ActivationQuantizationMode.QUANT + self.set_quant_mode(ActivationQuantizationMode.QUANT) elif op_cfg.quantization_preserving: - self.quant_mode = ActivationQuantizationMode.PRESERVE_QUANT + self.set_quant_mode(ActivationQuantizationMode.PRESERVE_QUANT) else: - self.quant_mode = ActivationQuantizationMode.NO_QUANT - self.signedness = op_cfg.signedness + self.set_quant_mode(ActivationQuantizationMode.NO_QUANT) self.activation_quantization_params = {} # TODO: computed by compute_activation_bias_correction. Probably shouldnt be here. 
@@ -102,6 +106,28 @@ def __init__(self, op_cfg: OpQuantizationConfig): # Since activation qparams are re-computed in several places, it's easier to keep it here and update it once. self.z_threshold = None + def set_quant_mode(self, quant_mode: ActivationQuantizationMode): + """ + Set quantization mode. If no configuration is associated with the quant_mode, it's un-set. + + Args: + quant_mode: quantization mode to set. + """ + if quant_mode in [ActivationQuantizationMode.QUANT, ActivationQuantizationMode.FLN_QUANT]: + if self.quant_mode in self._no_cfg_modes: + raise ValueError(f'Cannot change quant_mode to {quant_mode.name} from {self.quant_mode.name}.') + self._quant_mode = quant_mode + if quant_mode in self._no_cfg_modes: + self._unset() + + @property + def quant_mode(self): + return self._quant_mode + + @quant_mode.setter + def quant_mode(self, mode): + raise RuntimeError('quant_mode cannot be set directly. Use set_quant_mode.') + @property def enable_activation_quantization(self): return self.quant_mode == ActivationQuantizationMode.QUANT @@ -122,9 +148,29 @@ def set_activation_quantization_param(self, activation_params: Dictionary that contains weight quantization params. """ - assert self.quant_mode == ActivationQuantizationMode.QUANT or self.quant_mode == ActivationQuantizationMode.FLN_QUANT - for param_name, param_value in activation_params.items(): - self.activation_quantization_params[param_name] = param_value + assert self.quant_mode in [ActivationQuantizationMode.QUANT, ActivationQuantizationMode.FLN_QUANT] + self.activation_quantization_params = activation_params + + def set_quant_config_attr(self, attr_name: str, value: Any): + """ + Update config's attribute. + + Args: + attr_name: attribute to set. + value: value to set. 
+ """ + if attr_name == 'quant_mode': + self.set_quant_mode(value) + else: + if self.quant_mode in self._no_cfg_modes: + raise ValueError(f'Cannot set attribute {attr_name} for activation with disabled quantization.') + super().set_quant_config_attr(attr_name, value) + + def _unset(self): + """ Unset activation quantization fields to None. """ + self.activation_quantization_method = None + self.activation_n_bits = 0 + self.signedness = None def __eq__(self, other: Any) -> bool: """ @@ -164,14 +210,33 @@ def __init__(self, weights_attr_cfg: AttributeQuantizationConfig with parameters to use when creating the node's attribute quantization config. weights_channels_axis: Axis to quantize a node's attribute when quantizing per-channel (if not quantizing per-channel than expecting None). """ + if weights_attr_cfg.lut_values_bitwidth is not None: + raise ValueError('None-default lut_values_bitwidth in AttributeQuantizationConfig is not supported.') + self.weights_channels_axis = weights_channels_axis self.weights_quantization_method = weights_attr_cfg.weights_quantization_method self.weights_n_bits = weights_attr_cfg.weights_n_bits self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold - self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization + + self._enable_weights_quantization = weights_attr_cfg.enable_weights_quantization + if weights_attr_cfg.enable_weights_quantization is False: + self._unset() self.weights_quantization_params = {} + @property + def enable_weights_quantization(self): + return self._enable_weights_quantization + + @enable_weights_quantization.setter + def enable_weights_quantization(self, flag): + raise RuntimeError('enable_quantization should not be set directly, use disable_quantization() or ' + 'create a new instance.') + + def disable_quantization(self): + self._enable_weights_quantization = False + self._unset() + def set_weights_quantization_param(self, weights_params: dict): """ @@ -182,8 
+247,13 @@ def set_weights_quantization_param(self, """ assert self.enable_weights_quantization - for param_name, param_value in weights_params.items(): - self.weights_quantization_params[param_name] = param_value + self.weights_quantization_params = weights_params + + def _unset(self): + self.weights_channels_axis = None + self.weights_quantization_method = None + self.weights_n_bits = 0 + self.weights_per_channel_threshold = None def __eq__(self, other: Any) -> bool: """ @@ -229,6 +299,8 @@ def __init__(self, node_attrs_list: A list of the node's weights attributes names. """ + # TODO it makes no sense that the same weights_channels_axis is going to all attrs + self.simd_size = op_cfg.simd_size # Initialize a quantization configuration for each of the node's attributes @@ -320,19 +392,22 @@ def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationCo return attr_cfg - def set_attr_config(self, attr_name: 'WeightAttrT', attr_qc: WeightsAttrQuantizationConfig): + def set_attr_config(self, attr_name: 'WeightAttrT', attr_qc: WeightsAttrQuantizationConfig, force=False): """ Adding a new attribute with quantization configuration to the node's weights configurations mapping. Args: attr_name: The name of the attribute to set a quantization configuration to. attr_qc: The quantization configuration to set. - + force: if True, the attribute is set without checking if it exists. 
""" - if isinstance(attr_name, int): + if attr_name in self.pos_attributes_config_mapping or (force and isinstance(attr_name, int)): self.pos_attributes_config_mapping[attr_name] = attr_qc - else: + elif attr_name in self.attributes_config_mapping or force: + assert isinstance(attr_name, str) self.attributes_config_mapping[attr_name] = attr_qc + else: + raise ValueError(f'Unknown weights attr {attr_name}.') def has_attribute_config(self, attr_name: 'WeightAttrT') -> bool: """ @@ -345,13 +420,10 @@ def has_attribute_config(self, attr_name: 'WeightAttrT') -> bool: """ if isinstance(attr_name, int): - return self.pos_attributes_config_mapping.get(attr_name, False) - else: - saved_attr_name = self._extract_config_for_attributes_with_name(attr_name) - if len(saved_attr_name) >= 1: - return True + return attr_name in self.pos_attributes_config_mapping - return False + saved_attr_name = self._extract_config_for_attributes_with_name(attr_name) + return len(saved_attr_name) >= 1 @property def all_weight_attrs(self) -> List['WeightAttrT']: @@ -373,9 +445,9 @@ def get_all_weight_attrs_configs(self) -> Dict['WeightAttrT', AttributeQuantizat def disable_all_weights_quantization(self): """ Disable quantization for all weights. 
""" for w_cfg in self.pos_attributes_config_mapping.values(): - w_cfg.enable_weights_quantization = False + w_cfg.disable_quantization() for w_cfg in self.attributes_config_mapping.values(): - w_cfg.enable_weights_quantization = False + w_cfg.disable_quantization() def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, WeightsAttrQuantizationConfig]: """ @@ -396,7 +468,7 @@ def _extract_config_for_attributes_with_name(self, attr_name) -> Dict[str, Weigh return attrs_with_name def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any, - attr_name: 'WeightAttrT' = None, *args: List[Any], **kwargs: Dict[str, Any]): + attr_name: 'WeightAttrT' = None): """ This method overrides the parent class set_quant_config_attr to enable setting a specific weights attribute config parameter. @@ -405,25 +477,35 @@ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_val attr_name: attribute name to change. config_parameter_name: parameter name to change. config_parameter_value: parameter value to change. - args: A list of additional arguments. - kwargs: A dictionary with additional key arguments. 
- """ - if attr_name is None: super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(config_parameter_name, - config_parameter_value, - *args, **kwargs) - else: - if self.has_attribute_config(attr_name): - attr_cfg = self.get_attr_config(attr_name) - if hasattr(attr_cfg, config_parameter_name): - setattr(attr_cfg, config_parameter_name, config_parameter_value) - else: - raise AttributeError(f"Parameter {config_parameter_name} could not be found in the node quantization config of " - f"weights attribute {attr_name}.") - else: # pragma: no cover - Logger.critical(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.") + config_parameter_value) + return + + if not self.has_attribute_config(attr_name): + raise ValueError( + f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.") + + attr_cfg = self.get_attr_config(attr_name) + if config_parameter_name == 'enable_weights_quantization': + if config_parameter_value is False: + attr_cfg.disable_quantization() + elif attr_cfg.enable_weights_quantization is False: + raise ValueError(f'Cannot enable quantization for attr {attr_name} with disabled quantization.') + return + + if not hasattr(attr_cfg, config_parameter_name): + raise AttributeError( + f"Parameter {config_parameter_name} could not be found in the quantization config of " + f"weights attribute {attr_name}.") + + if attr_cfg.enable_weights_quantization is False: + # TODO we can add an option to reset the whole attr config for a specific attr, but this whole + # mechanism should be revised. Also attr cfg code should be moved to attr cfg. 
+ raise ValueError(f'Cannot set param {config_parameter_name} for attr {attr_name} with disabled quantization.') + + setattr(attr_cfg, config_parameter_name, config_parameter_value) def __eq__(self, other: Any) -> bool: """ diff --git a/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py index 60e94f094..fe40fb037 100644 --- a/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py @@ -71,7 +71,8 @@ def _apply_activation_bias_correction_to_node(node: BaseNode, node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS, WeightsAttrQuantizationConfig( AttributeQuantizationConfig( - enable_weights_quantization=False))) + enable_weights_quantization=False)), + force=True) else: # If the layer has bias, we subtract the correction from original bias node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction) diff --git a/model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py b/model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py index 962e92c76..b2c27adb6 100644 --- a/model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +++ b/model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py @@ -68,4 +68,5 @@ def _apply_bias_correction_to_node(node: BaseNode, node.framework_attr[fw_impl.constants.USE_BIAS] = True # Mark the use_bias attribute of the node. 
node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS, WeightsAttrQuantizationConfig(AttributeQuantizationConfig( - enable_weights_quantization=False))) + enable_weights_quantization=False)), + force=True) diff --git a/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py b/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py index cc465186a..d47e9816b 100644 --- a/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +++ b/model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py @@ -125,12 +125,12 @@ def set_second_moment_correction(qc): bn_node.quantization_cfg = copy.deepcopy(source_node.quantization_cfg) for qc in bn_node.candidates_quantization_cfg: - qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT + qc.activation_quantization_cfg.set_quant_mode(ActivationQuantizationMode.NO_QUANT) for attr in bn_node.get_node_weights_attributes(): if qc.weights_quantization_cfg.has_attribute_config(attr): # we only create a BN layer to collect statistics, so we don't need to quantize anything, # but we do need to add the BN attributes to the reconstructed node. - qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False + qc.weights_quantization_cfg.get_attr_config(attr).disable_quantization() else: # setting a "dummy" attribute configuration with disabled quantization. # TODO: once enabling BN attributes quantization, need to figure out if thie @@ -138,7 +138,8 @@ def set_second_moment_correction(qc): qc.weights_quantization_cfg.set_attr_config(attr, WeightsAttrQuantizationConfig( AttributeQuantizationConfig( - enable_weights_quantization=False))) + enable_weights_quantization=False)), + force=True) # Check if the source node was part of a fusion. 
If so, there are two cases: # either this is no longer a fusion, and the fusion info should be updated by removing diff --git a/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py b/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py index 4a2aed524..6da02f1c0 100644 --- a/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +++ b/model_compression_toolkit/core/common/substitutions/shift_negative_activation.py @@ -67,7 +67,8 @@ def op2d_bias_correction(op2d_node: BaseNode, for qc in op2d_node.candidates_quantization_cfg: qc.weights_quantization_cfg.set_attr_config(bias_flag_str, WeightsAttrQuantizationConfig(AttributeQuantizationConfig( - enable_weights_quantization=False))) + enable_weights_quantization=False)), + force=True) # Each node adds a different noise due to the shifting. It depends on the # dimensions of the kernel, thus the correction term is a function of @@ -424,7 +425,7 @@ def shift_negative_function(graph: Graph, fqc=graph.fqc) for candidate_qc in pad_node.candidates_quantization_cfg: - candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT + candidate_qc.activation_quantization_cfg.set_quant_mode(ActivationQuantizationMode.NO_QUANT) for attr in pad_node.get_node_weights_attributes(): candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False diff --git a/model_compression_toolkit/core/graph_prep_runner.py b/model_compression_toolkit/core/graph_prep_runner.py index fe02eee18..58f64b92a 100644 --- a/model_compression_toolkit/core/graph_prep_runner.py +++ b/model_compression_toolkit/core/graph_prep_runner.py @@ -19,7 +19,6 @@ from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.graph.base_graph import Graph from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig -from 
model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG, \ QuantizationErrorMethod from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig @@ -167,11 +166,6 @@ def get_finalized_graph(initial_graph: Graph, if tb_w is not None: tb_w.add_graph(transformed_graph, 'after_graph_marking') - ###################################### - # Filter nodes' candidates - ###################################### - transformed_graph = filter_nodes_candidates(transformed_graph) - if tb_w is not None: tb_w.add_graph(transformed_graph, 'after_candidates_filtering') diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/node_filter_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/node_filter_test.py index e21f2ffca..a64092e89 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/node_filter_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/network_editor/node_filter_test.py @@ -44,9 +44,10 @@ class ScopeFilterTest(BaseKerasFeatureNetworkTest): - Check attribute changes ''' - def __init__(self, unit_test, activation_n_bits: int = 3, weights_n_bits: int = 3): - self.activation_n_bits = activation_n_bits - self.weights_n_bits = weights_n_bits + def __init__(self, unit_test): + self.activation_n_bits = 5 + self.weights_n_bits = 3 + self.weights_n_bits2 = 2 self.kernel = 3 self.num_conv_channels = 4 self.scope = 'scope' @@ -73,12 +74,9 @@ def get_debug_config(self): EditRule(filter=NodeNameScopeFilter(self.scope), action=ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL, weights_n_bits=self.weights_n_bits)), - EditRule(filter=NodeNameScopeFilter('change_2'), - action=ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL, - enable_weights_quantization=True)), 
EditRule(filter=NodeNameScopeFilter('change_2') or NodeNameScopeFilter('does_not_exist'), action=ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL, - enable_weights_quantization=False)) + weights_n_bits=self.weights_n_bits2)) ] return mct.core.DebugConfig(network_editor=network_editor) @@ -107,10 +105,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info= self.unit_test.assertTrue( len(np.unique(conv_layers[1].get_quantized_weights()['kernel'].numpy())) in [2 ** (self.weights_n_bits) - 1, 2 ** (self.weights_n_bits)]) + self.unit_test.assertTrue( + len(np.unique(conv_layers[2].get_quantized_weights()['kernel'].numpy())) in [2 ** (self.weights_n_bits2) - 1, + 2 ** (self.weights_n_bits2)]) + # check that this conv's weights did not change self.unit_test.assertTrue(np.all(conv_layers[0].get_quantized_weights()['kernel'].numpy() == self.conv_w)) - # check that this conv's weights did not change - self.unit_test.assertTrue(np.all(conv_layers[2].kernel == self.conv_w)) holder_layers = get_layers_from_model_by_type(quantized_model, KerasActivationQuantizationHolder) self.unit_test.assertTrue(holder_layers[1].activation_holder_quantizer.get_config()['num_bits'] == 16) self.unit_test.assertTrue(holder_layers[2].activation_holder_quantizer.get_config()['num_bits'] == self.activation_n_bits) diff --git a/tests/keras_tests/function_tests/test_activation_weights_composition_substitution.py b/tests/keras_tests/function_tests/test_activation_weights_composition_substitution.py index 49da66d7c..674809992 100644 --- a/tests/keras_tests/function_tests/test_activation_weights_composition_substitution.py +++ b/tests/keras_tests/function_tests/test_activation_weights_composition_substitution.py @@ -36,7 +36,6 @@ from model_compression_toolkit.core.common.graph.virtual_activation_weights_node import VirtualSplitActivationNode, \ VirtualActivationWeightsNode, VirtualSplitWeightsNode -from 
model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates from model_compression_toolkit.core.keras.graph_substitutions.substitutions.virtual_activation_weights_composition import \ VirtualActivationWeightsComposition from model_compression_toolkit.core.keras.graph_substitutions.substitutions.weights_activation_split import \ @@ -113,8 +112,6 @@ def prepare_graph(in_model, keras_impl, mixed_precision_candidates_list, base_co fqc = attach2keras.attach(tpc, qc.custom_tpc_opset_to_layer) graph = load_fqc_configuration(graph, fqc) - graph = filter_nodes_candidates(graph) - return graph @@ -217,7 +214,7 @@ def test_two_conv_net_compose_after_split_activation_only(self): graph.skip_validation_check = False - self._verify_two_conv_with_split_test(graph, v_graph, 9, 3) + self._verify_two_conv_with_split_test(graph, v_graph, 3, 3) def test_all_weights_layers_composition(self): in_model = multiple_weights_nodes_model() diff --git a/tests/keras_tests/function_tests/test_cfg_candidates_filter.py b/tests/keras_tests/function_tests/test_cfg_candidates_filter.py deleted file mode 100644 index 8d507178d..000000000 --- a/tests/keras_tests/function_tests/test_cfg_candidates_filter.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2022 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -import keras -import unittest -from tensorflow.keras.layers import Conv2D, ReLU, Input, InputLayer - -from model_compression_toolkit.constants import FLOAT_BITWIDTH -from model_compression_toolkit.core import CustomOpsetLayers -from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates -from model_compression_toolkit.core.keras.constants import KERNEL -from model_compression_toolkit.core.common.framework_info import set_fw_info -from model_compression_toolkit.core.keras.default_framework_info import KerasInfo -from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation -from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration -from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \ - AttachTpcToKeras -from tests.common_tests.helpers.generate_test_tpc import generate_test_attr_configs, generate_test_op_qc -from tests.keras_tests.tpc_keras import get_tpc_with_activation_mp_keras - - -def get_full_bitwidth_candidates(): - return [(4, 8), (4, 4), (4, 2), - (8, 8), (8, 4), (8, 2), - (2, 8), (2, 4), (2, 2)] - - -def prepare_graph(in_model, base_config, default_config, bitwidth_candidates): - tpc = get_tpc_with_activation_mp_keras(base_config=base_config, - mp_bitwidth_candidates_list=bitwidth_candidates, - name="candidates_filter_test", - default_config=default_config) - - keras_impl = KerasImplementation() - graph = keras_impl.model_reader(in_model, None) # model reading - - attach2keras = AttachTpcToKeras() - fqc = attach2keras.attach(tpc, custom_opset2layer={"Input": CustomOpsetLayers([InputLayer])}) - - graph = load_fqc_configuration(graph, fqc) - - return graph - - -def create_model_conv2d_only(input_shape): - inputs = Input(shape=input_shape) - x = Conv2D(2, 3)(inputs) - outputs = Conv2D(2, 3)(x) - return 
keras.Model(inputs=inputs, outputs=outputs) - - -def create_model_single_conv2d(input_shape): - inputs = Input(shape=input_shape) - outputs = Conv2D(2, 3)(inputs) - return keras.Model(inputs=inputs, outputs=outputs) - - -def create_model_conv2d_relu(input_shape): - inputs = Input(shape=input_shape) - x = Conv2D(2, 3)(inputs) - outputs = ReLU()(x) - return keras.Model(inputs=inputs, outputs=outputs) - - -class TestCfgCandidatesFilter(unittest.TestCase): - def setUp(self): - set_fw_info(KerasInfo) - - def test_cfg_filter_activation_only_nodes(self): - input_shape = (8, 8, 3) - in_model = create_model_conv2d_relu(input_shape) - - base_config = generate_test_op_qc(**generate_test_attr_configs()) - default_config = base_config.clone_and_edit(attr_weights_configs_mapping={}) - - graph = prepare_graph(in_model, - base_config=base_config, - bitwidth_candidates=get_full_bitwidth_candidates(), - default_config=default_config) - - # Filtering nodes; candidates - filtered_graph = filter_nodes_candidates(graph) - - filtered_configurable_nodes = filtered_graph.get_configurable_sorted_nodes() - - # checking that layers with activation only (input and relu) have filtered configurations list, - # that they have a configuration for each of the original bitwidth options - input_candidates = filtered_configurable_nodes[0].candidates_quantization_cfg - self.assertTrue(len(input_candidates) == 3, - f"Expects 3 input layer candidates, number of candidates is {len(input_candidates)}") - self.assertTrue([c.activation_quantization_cfg.activation_n_bits for c in input_candidates] == [8, 4, 2]) - - relu_candidates = filtered_configurable_nodes[2].candidates_quantization_cfg - self.assertTrue(len(relu_candidates) == 3, - f"Expects 3 input layer candidates, number of candidates is {len(relu_candidates)}") - self.assertTrue([c.activation_quantization_cfg.activation_n_bits for c in relu_candidates] == [8, 4, 2]) - - def test_cfg_filter_weights_disabled(self): - input_shape = (8, 8, 3) - in_model 
= create_model_conv2d_only(input_shape) - - base_config = generate_test_op_qc(**generate_test_attr_configs(enable_kernel_weights_quantization=False)) - default_config = base_config.clone_and_edit(attr_weights_configs_mapping={}) - - graph = prepare_graph(in_model, - base_config=base_config, - bitwidth_candidates=get_full_bitwidth_candidates(), - default_config=default_config) - - # Filtering nodes; candidates - filtered_graph = filter_nodes_candidates(graph) - - filtered_configurable_nodes = filtered_graph.get_configurable_sorted_nodes() - - # checking that layers with weights (conv2d) have filtered activation configurations list - # when weights quantization is disabled - conv2d_1_candidates = filtered_configurable_nodes[1].candidates_quantization_cfg - self.assertTrue(len(conv2d_1_candidates) == 3, - f"Expects 3 Conv layer candidates, number of candidates is {len(conv2d_1_candidates)}") - self.assertTrue([c.activation_quantization_cfg.activation_n_bits for c in conv2d_1_candidates] == [8, 4, 2]) - conv2d_2_candidates = filtered_configurable_nodes[1].candidates_quantization_cfg - self.assertTrue(len(conv2d_2_candidates) == 3, - f"Expects 3 Conv layer candidates, number of candidates is {len(conv2d_2_candidates)}") - self.assertTrue([c.activation_quantization_cfg.activation_n_bits for c in conv2d_2_candidates] == [8, 4, 2]) - - def test_cfg_filter_activation_disabled(self): - input_shape = (8, 8, 3) - in_model = create_model_conv2d_relu(input_shape) - - base_config = generate_test_op_qc(enable_activation_quantization=False, - **generate_test_attr_configs()) - default_config = base_config.clone_and_edit(attr_weights_configs_mapping={}) - - graph = prepare_graph(in_model, - base_config=base_config, - bitwidth_candidates=get_full_bitwidth_candidates(), - default_config=default_config) - - # Filtering nodes; candidates - filtered_graph = filter_nodes_candidates(graph) - - filtered_configurable_nodes = filtered_graph.get_configurable_sorted_nodes() - - # checking that 
layers with weights (conv2d) have filtered weights configurations list - # when activation quantization is disabled - conv2d_kernel_candidates = filtered_configurable_nodes[0].get_all_weights_attr_candidates(KERNEL) - self.assertTrue(len(conv2d_kernel_candidates) == 3, - f"Expects 3 Conv layer kernel candidates, number of candidates is {len(conv2d_kernel_candidates)}") - self.assertTrue([c.weights_n_bits for c in conv2d_kernel_candidates] == [8, 4, 2]) - - def test_cfg_filter_multiple_candidates_weights_disabled(self): - input_shape = (8, 8, 3) - in_model = create_model_single_conv2d(input_shape) - - base_config = generate_test_op_qc(**generate_test_attr_configs(enable_kernel_weights_quantization=False)) - default_config = base_config.clone_and_edit(attr_weights_configs_mapping={}) - - graph = prepare_graph(in_model, - base_config=base_config, - bitwidth_candidates=[(8, 8), (4, 8), (2, 8)], - default_config=default_config) - - # Filtering nodes; candidates - filtered_graph = filter_nodes_candidates(graph) - - filtered_graph_nodes = filtered_graph.get_topo_sorted_nodes() - - # checking that layers with weights (conv2d) have filtered weights configurations list - # when activation quantization is disabled - conv2d_candidates = filtered_graph_nodes[1].candidates_quantization_cfg - self.assertTrue(len(conv2d_candidates) == 1, - f"Expects 1 Conv layer candidates, number of candidates is {len(conv2d_candidates)}") - candidate = conv2d_candidates[0] - self.assertTrue((candidate.weights_quantization_cfg.get_attr_config(KERNEL).weights_n_bits, - candidate.activation_quantization_cfg.activation_n_bits) == (FLOAT_BITWIDTH, 8)) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/keras_tests/function_tests/test_node_quantization_configurations.py b/tests/keras_tests/function_tests/test_node_quantization_configurations.py deleted file mode 100644 index 681bbac6b..000000000 --- a/tests/keras_tests/function_tests/test_node_quantization_configurations.py +++ 
/dev/null @@ -1,117 +0,0 @@ -# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import copy -import unittest -from mct_quantizers import QuantizationMethod - -from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping -from model_compression_toolkit.core.common.quantization.node_quantization_config import \ - NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig, WeightsAttrQuantizationConfig -from model_compression_toolkit.core.keras.constants import KERNEL, BIAS -from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig -from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import get_op_quantization_configs - - -class TestNodeQuantizationConfigurations(unittest.TestCase): - - def test_activation_set_quant_config_attribute(self): - op_cfg, _, _ = get_op_quantization_configs() - - nac = NodeActivationQuantizationConfig(op_cfg) - og_nac = copy.deepcopy(nac) - - self.assertTrue(nac.activation_n_bits == 8) - nac.set_quant_config_attr("activation_n_bits", 4) - self.assertTrue(nac.activation_n_bits == 4, "Expects set_quant_config_attr to be successful, " - "new activation_n_bits should be 4.") - self.assertFalse(nac == og_nac) - - with self.assertRaises(AttributeError) as e: - 
nac.set_quant_config_attr("activation_M_bits", 8) - self.assertEqual(str(e.exception), - "Parameter activation_M_bits could not be found in the node quantization config.") - - def test_weights_set_quant_config_attribute(self): - op_cfg, _, _ = get_op_quantization_configs() - - nwc = NodeWeightsQuantizationConfig(op_cfg, - weights_channels_axis=ChannelAxisMapping(1, -1), - node_attrs_list=[KERNEL, 0]) - og_nwc = copy.deepcopy(nwc) - - # Updating an attribute parameter - self.assertTrue(nwc.get_attr_config(KERNEL).weights_n_bits, 8) - nwc.set_quant_config_attr("weights_n_bits", 4, attr_name=KERNEL) - self.assertFalse(nwc.get_attr_config(KERNEL).weights_n_bits == 8, - f"Expects set_quant_config_attr to update {KERNEL} attribute weights_n_bits to 4.") - self.assertFalse(nwc == og_nwc) - - nwc = copy.deepcopy(og_nwc) - self.assertTrue(nwc.get_attr_config(0).weights_n_bits, 8) - nwc.set_quant_config_attr("weights_n_bits", 4, attr_name=0) - self.assertFalse(nwc.get_attr_config(0).weights_n_bits == 8, - f"Expects set_quant_config_attr to update positional attribute weights_n_bits to 4.") - self.assertFalse(nwc == og_nwc) - - with self.assertRaises(AttributeError) as e: - nwc.set_quant_config_attr("weights_M_bits", 4, attr_name=KERNEL) - self.assertEqual(str(e.exception), f"Parameter weights_M_bits could not be found in the node quantization " - f"config of weights attribute {KERNEL}.") - - def test_get_weights_attr_config(self): - op_cfg, _, _ = get_op_quantization_configs() - - # Init a config with regular and positional attributes, and attributes with overlapping names, since in the - # implementation we look for existence of a string to retrieve attribute - nwc = NodeWeightsQuantizationConfig(op_cfg, - weights_channels_axis=ChannelAxisMapping(1, -1), - node_attrs_list=[KERNEL, 0, BIAS, f"{BIAS}-2"]) - - kernel_attr = nwc.get_attr_config(KERNEL) - self.assertTrue(kernel_attr.weights_n_bits == 8) # sanity - - pos_attr = nwc.get_attr_config(0) - 
self.assertTrue(pos_attr.weights_quantization_method == QuantizationMethod.POWER_OF_TWO) # sanity (should use default config) - - bias_attr = nwc.get_attr_config(BIAS) - self.assertTrue(bias_attr.weights_n_bits == 8) # checking successful retrival - - bias2_attr = nwc.get_attr_config(f"{BIAS}-2") - self.assertTrue(bias_attr.weights_n_bits == 8) # checking successful retrival - - self.assertFalse(bias_attr is bias2_attr) # this is "is" on purpose, to compare addresses - - def test_set_weights_attr_config(self): - op_cfg, _, _ = get_op_quantization_configs() - - nwc = NodeWeightsQuantizationConfig(op_cfg, - weights_channels_axis=ChannelAxisMapping(1, -1), - node_attrs_list=[KERNEL, 0]) - - new_cfg = WeightsAttrQuantizationConfig(weights_attr_cfg=AttributeQuantizationConfig(weights_n_bits=4)) - - kernel_attr = copy.deepcopy(nwc.get_attr_config(KERNEL)) - nwc.set_attr_config(KERNEL, new_cfg) - self.assertTrue(kernel_attr.weights_n_bits == 8) - self.assertTrue(nwc.get_attr_config(KERNEL).weights_n_bits == 4) - - pos_attr = copy.deepcopy(nwc.get_attr_config(0)) - nwc.set_attr_config(0, new_cfg) - self.assertTrue(pos_attr.weights_n_bits == 8) - self.assertTrue(nwc.get_attr_config(0).weights_n_bits == 4) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests_pytest/_fw_tests_common_base/fusing/base_graph_with_fusing_metadata_test.py b/tests_pytest/_fw_tests_common_base/fusing/base_graph_with_fusing_metadata_test.py index 7e685b270..730313a66 100644 --- a/tests_pytest/_fw_tests_common_base/fusing/base_graph_with_fusing_metadata_test.py +++ b/tests_pytest/_fw_tests_common_base/fusing/base_graph_with_fusing_metadata_test.py @@ -113,10 +113,6 @@ def test_disable_act_quantization(self, graph_with_fusion_metadata: Graph): """Tests that the correct nodes have activation quantization disabled after calling _disable_nodes_activation_quantization. 
""" - for node in graph_with_fusion_metadata.nodes: - for qc in node.candidates_quantization_cfg: - qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.QUANT - graph_with_fusion_metadata.override_fused_node_activation_quantization_candidates() disabled_nodes = [ node.name for node in graph_with_fusion_metadata.nodes diff --git a/tests_pytest/_test_util/graph_builder_utils.py b/tests_pytest/_test_util/graph_builder_utils.py index 215b33c46..9f77ca589 100644 --- a/tests_pytest/_test_util/graph_builder_utils.py +++ b/tests_pytest/_test_util/graph_builder_utils.py @@ -130,14 +130,16 @@ def build_nbits_qc(a_nbits=8, a_enable=True, w_attr=None, pos_attr=(32, False, ( # we generate q configs via constructors to follow the real code as closely as reasonably possible. # verify that we actually got the configurations we want - assert qc.activation_quantization_cfg.activation_n_bits == a_nbits assert qc.activation_quantization_cfg.enable_activation_quantization is a_enable + if a_enable: + assert qc.activation_quantization_cfg.activation_n_bits == a_nbits for k, v in w_attr.items(): # get_attr_config accepts canonical attr names - assert qc.weights_quantization_cfg.get_attr_config(k).weights_n_bits == v[0] assert qc.weights_quantization_cfg.get_attr_config(k).enable_weights_quantization == v[1] + if v[1]: + assert qc.weights_quantization_cfg.get_attr_config(k).weights_n_bits == v[0] for pos in pos_attr[2]: - assert qc.weights_quantization_cfg.get_attr_config(pos).weights_n_bits == pos_attr[0] assert qc.weights_quantization_cfg.get_attr_config(pos).enable_weights_quantization == pos_attr[1] - + if pos_attr[1]: + assert qc.weights_quantization_cfg.get_attr_config(pos).weights_n_bits == pos_attr[0] return qc diff --git a/tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py b/tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py index 67fd35ac7..76cf4bfe7 100644 --- a/tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py 
+++ b/tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py @@ -20,7 +20,8 @@ from mct_quantizers import QuantizationMethod from model_compression_toolkit.core.common import Graph from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo -from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness +from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness, \ + OpQuantizationConfig from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode, NodeActivationQuantizationConfig from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \ CandidateNodeQuantizationConfig, NodeQuantizationConfig @@ -58,12 +59,12 @@ def build_mock_node(name, layer_class, w_cfgs): """ node = build_node(name, layer_class=layer_class) - def eq(self_, other): - return self_.activation_n_bits == other.activation_n_bits and self_._quant_mode == other.quant_mode - a_cfgs = [Mock(spec=NodeActivationQuantizationConfig, - quant_mode=Mock(), - activation_n_bits=b, - __eq__=eq) for b in [5, 6]] + a_cfgs = [NodeActivationQuantizationConfig(Mock(spec=OpQuantizationConfig, + activation_quantization_method=QuantizationMethod.POWER_OF_TWO, + activation_n_bits=b, + enable_activation_quantization=True, + quantization_preserving=False, + signedness=Signedness.AUTO)) for b in [5, 6]] qcs = [CandidateNodeQuantizationConfig(a_cfg, w_cfg) for a_cfg, w_cfg in itertools.product(a_cfgs, w_cfgs)] diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/__init__.py b/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py 
b/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py deleted file mode 100644 index db48e2bbd..000000000 --- a/tests_pytest/common_tests/unit_tests/core/quantization/node_quantization_config/test_node_weights_quantization_config.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from typing import List -from unittest.mock import Mock - -import pytest - -from mct_quantizers import QuantizationMethod -from model_compression_toolkit.core.common.quantization.node_quantization_config import \ - NodeWeightsQuantizationConfig -from model_compression_toolkit.target_platform_capabilities import Signedness, OpQuantizationConfig -from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR -from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig - - -class TestPositionalWeightsAttrQuantizationConfig: - def _create_weights_attr_quantization_config(self, weights_n_bits: int) -> AttributeQuantizationConfig: - """ - Helper method to create a weights attribute quantization configuration. - - Args: - weights_n_bits (int): Number of bits to use for quantizing weights. 
- - Returns: - AttributeQuantizationConfig: Holds the quantization configuration of a weight attribute of a layer. - """ - weights_attr_config = AttributeQuantizationConfig( - weights_quantization_method=QuantizationMethod.POWER_OF_TWO, - weights_n_bits=weights_n_bits, - weights_per_channel_threshold=False, - enable_weights_quantization=True, - lut_values_bitwidth=None) - return weights_attr_config - - def _create_node_weights_op_cfg( - self, - pos_weight_attr: List[str], - pos_weight_attr_config: List[AttributeQuantizationConfig], - def_weight_attr_config: AttributeQuantizationConfig) -> OpQuantizationConfig: - """ - Helper method to create a Node Weights OpQuantizationConfig with a default weights - attribute config and a specific weight attribute. - - Args: - pos_weight_attr (List[str]): List of names for specific weight attributes. - pos_weight_attr_config (List[AttributeQuantizationConfig]): Corresponding list of quantization configs - for the specific attributes. - def_weight_attr_config (AttributeQuantizationConfig): Default quantization config for the weights. - - Returns: - OpQuantizationConfig: Class to configure the quantization parameters of an operator. - """ - attr_weights_configs_mapping = dict(zip(pos_weight_attr, pos_weight_attr_config)) - - op_cfg = OpQuantizationConfig( - default_weight_attr_config=def_weight_attr_config, - attr_weights_configs_mapping=attr_weights_configs_mapping, - activation_quantization_method=QuantizationMethod.POWER_OF_TWO, - activation_n_bits=8, - supported_input_activation_n_bits=8, - enable_activation_quantization=True, - quantization_preserving=True, - fixed_scale=None, - fixed_zero_point=None, - simd_size=None, - signedness=Signedness.AUTO - ) - return op_cfg - - def test_node_weights_quantization_config_op_cfg_mapping(self): - """ - Test case for verifying that the positional weight attribute is correctly mapped and - configured in the NodeWeightsQuantizationConfig. 
- """ - positional_weight_attr = 0 - weights_n_bits = 8 - pos_weights_n_bits = 16 - - def_weight_attr_config = self._create_weights_attr_quantization_config(weights_n_bits) - pos_weight_attr_config = self._create_weights_attr_quantization_config(pos_weights_n_bits) - - # Ensure the configs have different weights bit widths. - assert def_weight_attr_config.weights_n_bits != pos_weight_attr_config.weights_n_bits - - op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[POSITIONAL_ATTR], - pos_weight_attr_config=[pos_weight_attr_config], - def_weight_attr_config=def_weight_attr_config) - - # Check that positional weights attribute config differs from default config. - assert op_cfg.default_weight_attr_config.weights_n_bits != op_cfg.attr_weights_configs_mapping[ - POSITIONAL_ATTR].weights_n_bits - - weights_quant_cfg = NodeWeightsQuantizationConfig(op_cfg=op_cfg, - weights_channels_axis=Mock(), - node_attrs_list=[positional_weight_attr]) - - # Check if the positional weight attribute was properly assigned in the positional attributes configuration - # mapping. - assert weights_quant_cfg.pos_attributes_config_mapping[ - positional_weight_attr].weights_n_bits == pos_weight_attr_config.weights_n_bits - - # Test using the positional attribute as the key rather than POS_ATTR; this mismatch should cause - # NodeWeightsQuantizationConfig to fall back to the default weights attribute configuration instead of - # applying the specific one. - op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[str(positional_weight_attr)], - pos_weight_attr_config=[pos_weight_attr_config], - def_weight_attr_config=def_weight_attr_config) - - # Check that positional weights attribute config differs from default config. 
- assert op_cfg.default_weight_attr_config.weights_n_bits != op_cfg.attr_weights_configs_mapping[ - str(positional_weight_attr)].weights_n_bits - - weights_quant_cfg = NodeWeightsQuantizationConfig(op_cfg=op_cfg, - weights_channels_axis=Mock(), - node_attrs_list=[positional_weight_attr]) - - # Check if the positional weight attribute was properly assigned in the positional attributes configuration - # mapping. - assert weights_quant_cfg.pos_attributes_config_mapping[ - positional_weight_attr].weights_n_bits == def_weight_attr_config.weights_n_bits - - # Add a second positional attribute with a different config. - second_positional_weight_attr = POSITIONAL_ATTR + '_1' - second_pos_weights_n_bits = 32 - second_pos_weight_attr_config = self._create_weights_attr_quantization_config(second_pos_weights_n_bits) - - # Confirm all three configs have different bit widths. - assert pos_weight_attr_config.weights_n_bits != second_pos_weight_attr_config.weights_n_bits - - # Create op config with two positional attribute keys and their respective configs. - op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[POSITIONAL_ATTR, second_positional_weight_attr], - pos_weight_attr_config=[pos_weight_attr_config, - second_pos_weight_attr_config], - def_weight_attr_config=def_weight_attr_config) - - # Check the configs are correctly set and distinct from each other and from the default. - assert op_cfg.default_weight_attr_config.weights_n_bits != op_cfg.attr_weights_configs_mapping[ - str(POSITIONAL_ATTR)].weights_n_bits - assert op_cfg.default_weight_attr_config.weights_n_bits != op_cfg.attr_weights_configs_mapping[ - str(second_positional_weight_attr)].weights_n_bits - assert op_cfg.attr_weights_configs_mapping[ - str(POSITIONAL_ATTR)].weights_n_bits != op_cfg.attr_weights_configs_mapping[ - str(second_positional_weight_attr)].weights_n_bits - - # Expect ValueError: multiple matching keys found for positional weights attribute. 
- with pytest.raises(ValueError, match='Found multiple attribute in FQC OpConfig that are contained in the ' - 'attribute name \'0\'.Please fix the FQC attribute names mapping such ' - 'that each operator\'s attribute would have a unique matching name.'): - NodeWeightsQuantizationConfig(op_cfg=op_cfg, weights_channels_axis=Mock(), - node_attrs_list=[positional_weight_attr]) diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_configurations/__init__.py b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_configurations/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_configurations/test_node_quantization_config.py b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_configurations/test_node_quantization_config.py deleted file mode 100644 index 5d885b6ba..000000000 --- a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_configurations/test_node_quantization_config.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -from unittest.mock import Mock - -import pytest -from mct_quantizers import QuantizationMethod - -from model_compression_toolkit.core.common.quantization.node_quantization_config import \ - NodeActivationQuantizationConfig, ActivationQuantizationMode -from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig - - -class TestActivationQParams: - - def _get_op_config(self, qe, qp): - return Mock(spec=OpQuantizationConfig, - default_weight_attr_config=None, - attr_weights_configs_mapping=None, - activation_quantization_method=QuantizationMethod.POWER_OF_TWO, - activation_n_bits=8, - supported_input_activation_n_bits=[8], - enable_activation_quantization=qe, - quantization_preserving=qp, - fixed_scale=None, - fixed_zero_point=None, - simd_size=32, - signedness=None) - - def test_quantization_mode(self): - with pytest.raises(ValueError): - NodeActivationQuantizationConfig(self._get_op_config(True, True)) - assert (NodeActivationQuantizationConfig(self._get_op_config(False, False)). - quant_mode == ActivationQuantizationMode.NO_QUANT) - assert (NodeActivationQuantizationConfig(self._get_op_config(True, False)). - quant_mode == ActivationQuantizationMode.QUANT) - assert (NodeActivationQuantizationConfig(self._get_op_config(False, True)). 
- quant_mode == ActivationQuantizationMode.PRESERVE_QUANT) diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py index 0a0292f67..e3f6c3ede 100644 --- a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py +++ b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py @@ -58,7 +58,7 @@ def build_node(self, name='node', q_mode=ActivationQuantizationMode.QUANT): node.is_fln_quantization.return_value = False activation_quantization_cfg = NodeActivationQuantizationConfig(op_cfg=self.build_op_cfg()) - activation_quantization_cfg.quant_mode = q_mode + activation_quantization_cfg.set_quant_mode(q_mode) candidate_quantization_config = Mock(spec=CandidateNodeQuantizationConfig) candidate_quantization_config.activation_quantization_cfg = activation_quantization_cfg diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/test_filter_nodes_candidates.py b/tests_pytest/common_tests/unit_tests/core/quantization/test_filter_nodes_candidates.py deleted file mode 100644 index d8fc64765..000000000 --- a/tests_pytest/common_tests/unit_tests/core/quantization/test_filter_nodes_candidates.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from copy import deepcopy - -import pytest -from unittest.mock import Mock - -from mct_quantizers import QuantizationMethod -from model_compression_toolkit.core.common import Graph -from model_compression_toolkit.core.common.graph.base_node import BaseNode -from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode, NodeActivationQuantizationConfig -from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig -from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_node_candidates -from model_compression_toolkit.constants import FLOAT_BITWIDTH -from mct_quantizers import QuantizationMethod - -def build_mock_node(name, layer_class, idx): - """ - Creates mock nodes representing a simple neural network structure. 
- """ - node = Mock(spec=BaseNode) - node.name = name - node.layer_class = layer_class - node.kernel_attr = "Dmy" - - if idx == 0: - node.is_no_quantization.return_value = True - node.is_weights_quantization_enabled.return_value = False - elif idx == 1: - node.is_no_quantization.return_value = True - node.is_weights_quantization_enabled.return_value = True - else: - node.is_no_quantization.return_value = False - node.is_weights_quantization_enabled.return_value = True - node.is_fln_no_quantization.return_value = True - - activation_quantization_cfg = Mock(spec=NodeActivationQuantizationConfig) - activation_quantization_cfg.quant_mode = Mock() - candidate_quantization_config = Mock(spec=CandidateNodeQuantizationConfig) - candidate_quantization_config.activation_quantization_cfg = activation_quantization_cfg - candidate_quantization_config.weights_quantization_cfg = Mock() - activation_quantization_cfg.activation_n_bits = 16 - activation_quantization_cfg.activation_quantization_method = QuantizationMethod.SYMMETRIC - - node.candidates_quantization_cfg = [candidate_quantization_config] - - return node - -@pytest.mark.parametrize(("idx"), [ - 0, - 1, - 2, -]) -def test_filter_node_candidates(idx): - """ - Test the filter_node_candidates function for a graph with multiple nodes and configurations. - """ - ### Create Test Nodes - mock_nodes = [] - mock_nodes.append(build_mock_node(name='conv', layer_class='Conv2d', idx=idx)) - ### Create a mock graph - ### Note: Generate the graph first because fusing_info cannot be set without it. - ### In the following Mock, use wraps to mock everything except fusing_info. 
- real_graph = Graph("dummy", [], [], [], []) - - graph = Mock(spec=Graph, wraps=real_graph) - graph.nodes = mock_nodes - ### call override_fused_node_activation_quantization_candidates - graph.override_fused_node_activation_quantization_candidates() - - output_candidates = filter_node_candidates(graph.nodes[0]) - - if idx == 0 or idx == 1: - assert output_candidates[0].activation_quantization_cfg.activation_n_bits == FLOAT_BITWIDTH - assert output_candidates[0].activation_quantization_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO - else: - assert output_candidates[0].activation_quantization_cfg.activation_n_bits == 16 - assert output_candidates[0].activation_quantization_cfg.activation_quantization_method == QuantizationMethod.SYMMETRIC - \ No newline at end of file diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/test_node_activation_quantization_config.py b/tests_pytest/common_tests/unit_tests/core/quantization/test_node_activation_quantization_config.py new file mode 100644 index 000000000..656d0bb44 --- /dev/null +++ b/tests_pytest/common_tests/unit_tests/core/quantization/test_node_activation_quantization_config.py @@ -0,0 +1,116 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from unittest.mock import Mock + +import pytest + +from model_compression_toolkit.core.common.quantization.node_quantization_config import \ + NodeActivationQuantizationConfig, ActivationQuantizationMode +from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig + + +class TestNodeActivationConfig: + + def _get_op_config(self, qe, qp): + return Mock(spec=OpQuantizationConfig, + activation_quantization_method=Mock(), + activation_n_bits=5, + enable_activation_quantization=qe, + quantization_preserving=qp, + signedness=Mock()) + + def test_config(self): + with pytest.raises(ValueError, + match="can't have both enable_activation_quantization and quantization_preserving enabled"): + NodeActivationQuantizationConfig(self._get_op_config(True, True)) + + cfg = NodeActivationQuantizationConfig(self._get_op_config(False, False)) + assert cfg.quant_mode == ActivationQuantizationMode.NO_QUANT + self._assert_unset_acfg(cfg) + + op_cfg = self._get_op_config(True, False) + cfg = NodeActivationQuantizationConfig(op_cfg) + assert cfg.quant_mode == ActivationQuantizationMode.QUANT + assert cfg.activation_n_bits == 5 + assert cfg.activation_quantization_method == op_cfg.activation_quantization_method + assert cfg.signedness == op_cfg.signedness + + cfg = NodeActivationQuantizationConfig(self._get_op_config(False, True)) + assert cfg.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT + self._assert_unset_acfg(cfg) + + @pytest.mark.parametrize('mode', [ActivationQuantizationMode.NO_QUANT, + ActivationQuantizationMode.PRESERVE_QUANT, + ActivationQuantizationMode.FLN_NO_QUANT]) + def test_set_quant_mode(self, mode): + cfg = NodeActivationQuantizationConfig(self._get_op_config(True, False)) + cfg.set_quant_mode(mode) + assert cfg.quant_mode == mode + # lose irrelevant config + self._assert_unset_acfg(cfg) + + # after losing the config cannot set quant back + for qmode in 
[ActivationQuantizationMode.QUANT, ActivationQuantizationMode.FLN_QUANT]: + with pytest.raises(ValueError, match=f'Cannot change quant_mode to {qmode.name} from {mode.name}'): + cfg.set_quant_mode(qmode) + + with pytest.raises(RuntimeError, match='quant_mode cannot be set directly'): + cfg.quant_mode = ActivationQuantizationMode.NO_QUANT + + def test_set_quant_config_attribute(self): + cfg = NodeActivationQuantizationConfig(self._get_op_config(True, False)) + + assert cfg.activation_n_bits == 5 + cfg.set_quant_config_attr('activation_n_bits', 4) + assert cfg.activation_n_bits == 4 + + with pytest.raises(AttributeError, + match='Parameter activation_M_bits could not be found in the node quantization config.'): + cfg.set_quant_config_attr('activation_M_bits', 8) + + # quant_mode has a special handling + cfg.set_quant_config_attr('quant_mode', ActivationQuantizationMode.FLN_QUANT) + assert cfg.quant_mode == ActivationQuantizationMode.FLN_QUANT + + cfg.set_quant_config_attr('quant_mode', ActivationQuantizationMode.PRESERVE_QUANT) + self._assert_unset_acfg(cfg) + + cfg.set_quant_config_attr('quant_mode', ActivationQuantizationMode.NO_QUANT) + self._assert_unset_acfg(cfg) + + with pytest.raises(ValueError, match=f'Cannot change quant_mode to QUANT from NO_QUANT.'): + cfg.set_quant_config_attr('quant_mode', ActivationQuantizationMode.QUANT) + + with pytest.raises(ValueError, match='Cannot set attribute activation_n_bits for activation with disabled ' + 'quantization'): + cfg.set_quant_config_attr('activation_n_bits', 5) + + @pytest.mark.parametrize('mode', [ActivationQuantizationMode.QUANT, ActivationQuantizationMode.FLN_QUANT]) + def test_set_quantization_params(self, mode): + cfg = NodeActivationQuantizationConfig(self._get_op_config(True, False)) + cfg.set_quant_mode(mode) + + params1 = {'foo': 5, 'bar': 10} + cfg.set_activation_quantization_param(params1) + assert cfg.activation_quantization_params == params1 + + params2 = {'baz': 42} + 
cfg.set_activation_quantization_param(params2) + assert cfg.activation_quantization_params == params2 + + def _assert_unset_acfg(self, cfg: NodeActivationQuantizationConfig): + assert cfg.activation_n_bits == 0 + assert cfg.activation_quantization_method is None + assert cfg.signedness is None diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/test_node_weights_quantization_config.py b/tests_pytest/common_tests/unit_tests/core/quantization/test_node_weights_quantization_config.py new file mode 100644 index 000000000..fb79f39e7 --- /dev/null +++ b/tests_pytest/common_tests/unit_tests/core/quantization/test_node_weights_quantization_config.py @@ -0,0 +1,347 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import copy +from typing import List +from unittest.mock import Mock + +import pytest + +from mct_quantizers import QuantizationMethod + +from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping +from model_compression_toolkit.core.common.quantization.node_quantization_config import \ + NodeWeightsQuantizationConfig, WeightsAttrQuantizationConfig +from model_compression_toolkit.target_platform_capabilities import Signedness, OpQuantizationConfig +from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR +from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig + + +class TestNodeWeightsAttrConfig: + @pytest.mark.parametrize('method, nbits, per_channel, enabled', [ + (QuantizationMethod.POWER_OF_TWO, 5, True, True), + (QuantizationMethod.SYMMETRIC, 7, False, True), + (QuantizationMethod.SYMMETRIC, 7, False, False), + ]) + def test_config(self, method, nbits, per_channel, enabled): + input_cfg = AttributeQuantizationConfig( + weights_quantization_method=method, + weights_n_bits=nbits, + weights_per_channel_threshold=per_channel, + enable_weights_quantization=enabled, + lut_values_bitwidth=None) + + cfg = WeightsAttrQuantizationConfig(input_cfg, weights_channels_axis=ChannelAxisMapping(2, 3)) + assert cfg.enable_weights_quantization == enabled + if enabled: + assert cfg.weights_quantization_method == method + assert cfg.weights_n_bits == nbits + assert cfg.weights_per_channel_threshold == per_channel + assert cfg.weights_channels_axis == ChannelAxisMapping(2, 3) + else: + assert_unset_attr_config(cfg) + + # disable quantization + cfg.disable_quantization() + assert cfg.enable_weights_quantization is False + assert_unset_attr_config(cfg) + + with pytest.raises(RuntimeError, match='enable_quantization should not be set directly'): + cfg.enable_weights_quantization = False + + def 
test_set_quantization_param(self): + input_cfg = AttributeQuantizationConfig(enable_weights_quantization=True) + cfg = WeightsAttrQuantizationConfig(input_cfg) + params1 = {'foo': 5, 'bar': 10} + cfg.set_weights_quantization_param(params1) + assert cfg.weights_quantization_params == params1 + + params2 = {'baz': 42} + cfg.set_weights_quantization_param(params2) + assert cfg.weights_quantization_params == params2 + + def test_unsupported_lut(self): + input_cfg = AttributeQuantizationConfig(enable_weights_quantization=True, lut_values_bitwidth=5) + with pytest.raises(ValueError, match='None-default lut_values_bitwidth in AttributeQuantizationConfig ' + 'is not supported.'): + WeightsAttrQuantizationConfig(input_cfg) + + +def assert_unset_attr_config(cfg: WeightsAttrQuantizationConfig): + assert cfg.weights_quantization_method is None + assert cfg.weights_n_bits == 0 + assert cfg.weights_per_channel_threshold is None + assert cfg.weights_channels_axis is None + + +class TestWeightsQuantizationConfig: + def _create_weights_attr_quantization_config(self, weights_n_bits: int) -> AttributeQuantizationConfig: + """ + Helper method to create a weights attribute quantization configuration. + + Args: + weights_n_bits (int): Number of bits to use for quantizing weights. + + Returns: + AttributeQuantizationConfig: Holds the quantization configuration of a weight attribute of a layer. 
+ """ + weights_attr_config = AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.POWER_OF_TWO, + weights_n_bits=weights_n_bits, + weights_per_channel_threshold=False, + enable_weights_quantization=True, + lut_values_bitwidth=None) + return weights_attr_config + + def _create_node_weights_op_cfg( + self, + pos_weight_attr: List[str], + pos_weight_attr_config: List[AttributeQuantizationConfig], + def_weight_attr_config: AttributeQuantizationConfig) -> OpQuantizationConfig: + """ + Helper method to create a Node Weights OpQuantizationConfig with a default weights + attribute config and a specific weight attribute. + + Args: + pos_weight_attr (List[str]): List of names for specific weight attributes. + pos_weight_attr_config (List[AttributeQuantizationConfig]): Corresponding list of quantization configs + for the specific attributes. + def_weight_attr_config (AttributeQuantizationConfig): Default quantization config for the weights. + + Returns: + OpQuantizationConfig: Class to configure the quantization parameters of an operator. + """ + attr_weights_configs_mapping = dict(zip(pos_weight_attr, pos_weight_attr_config)) + + op_cfg = OpQuantizationConfig( + default_weight_attr_config=def_weight_attr_config, + attr_weights_configs_mapping=attr_weights_configs_mapping, + activation_quantization_method=QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + supported_input_activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=True, + fixed_scale=None, + fixed_zero_point=None, + simd_size=None, + signedness=Signedness.AUTO + ) + return op_cfg + + def test_node_weights_quantization_config_op_cfg_mapping(self): + """ + Test case for verifying that the positional weight attribute is correctly mapped and + configured in the NodeWeightsQuantizationConfig. 
+ """ + positional_weight_attr = 0 + weights_n_bits = 8 + pos_weights_n_bits = 16 + + def_weight_attr_config = self._create_weights_attr_quantization_config(weights_n_bits) + pos_weight_attr_config = self._create_weights_attr_quantization_config(pos_weights_n_bits) + + # Ensure the configs have different weights bit widths. + assert def_weight_attr_config.weights_n_bits != pos_weight_attr_config.weights_n_bits + + op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[POSITIONAL_ATTR], + pos_weight_attr_config=[pos_weight_attr_config], + def_weight_attr_config=def_weight_attr_config) + + # Check that positional weights attribute config differs from default config. + assert op_cfg.default_weight_attr_config.weights_n_bits != op_cfg.attr_weights_configs_mapping[ + POSITIONAL_ATTR].weights_n_bits + + weights_quant_cfg = NodeWeightsQuantizationConfig(op_cfg=op_cfg, + weights_channels_axis=Mock(), + node_attrs_list=[positional_weight_attr]) + + # Check if the positional weight attribute was properly assigned in the positional attributes configuration + # mapping. + assert weights_quant_cfg.pos_attributes_config_mapping[ + positional_weight_attr].weights_n_bits == pos_weight_attr_config.weights_n_bits + + # Test using the positional attribute as the key rather than POS_ATTR; this mismatch should cause + # NodeWeightsQuantizationConfig to fall back to the default weights attribute configuration instead of + # applying the specific one. + op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[str(positional_weight_attr)], + pos_weight_attr_config=[pos_weight_attr_config], + def_weight_attr_config=def_weight_attr_config) + + # Check that positional weights attribute config differs from default config. 
+ assert op_cfg.default_weight_attr_config.weights_n_bits != op_cfg.attr_weights_configs_mapping[ + str(positional_weight_attr)].weights_n_bits + + weights_quant_cfg = NodeWeightsQuantizationConfig(op_cfg=op_cfg, + weights_channels_axis=Mock(), + node_attrs_list=[positional_weight_attr]) + + # Check if the positional weight attribute was properly assigned in the positional attributes configuration + # mapping. + assert weights_quant_cfg.pos_attributes_config_mapping[ + positional_weight_attr].weights_n_bits == def_weight_attr_config.weights_n_bits + + # Add a second positional attribute with a different config. + second_positional_weight_attr = POSITIONAL_ATTR + '_1' + second_pos_weights_n_bits = 32 + second_pos_weight_attr_config = self._create_weights_attr_quantization_config(second_pos_weights_n_bits) + + # Confirm all three configs have different bit widths. + assert pos_weight_attr_config.weights_n_bits != second_pos_weight_attr_config.weights_n_bits + + # Create op config with two positional attribute keys and their respective configs. + op_cfg = self._create_node_weights_op_cfg(pos_weight_attr=[POSITIONAL_ATTR, second_positional_weight_attr], + pos_weight_attr_config=[pos_weight_attr_config, + second_pos_weight_attr_config], + def_weight_attr_config=def_weight_attr_config) + + # Check the configs are correctly set and distinct from each other and from the default. + assert op_cfg.default_weight_attr_config.weights_n_bits != op_cfg.attr_weights_configs_mapping[ + str(POSITIONAL_ATTR)].weights_n_bits + assert op_cfg.default_weight_attr_config.weights_n_bits != op_cfg.attr_weights_configs_mapping[ + str(second_positional_weight_attr)].weights_n_bits + assert op_cfg.attr_weights_configs_mapping[ + str(POSITIONAL_ATTR)].weights_n_bits != op_cfg.attr_weights_configs_mapping[ + str(second_positional_weight_attr)].weights_n_bits + + # Expect ValueError: multiple matching keys found for positional weights attribute. 
+ with pytest.raises(ValueError, match='Found multiple attribute in FQC OpConfig that are contained in the ' + 'attribute name \'0\'.Please fix the FQC attribute names mapping such ' + 'that each operator\'s attribute would have a unique matching name.'): + NodeWeightsQuantizationConfig(op_cfg=op_cfg, weights_channels_axis=Mock(), + node_attrs_list=[positional_weight_attr]) + + def _create_wcfg(self): + # include enabled and disabled attrs + # include a name identical to config keys, and an extended name + attr_weights_configs_mapping = {'foo': AttributeQuantizationConfig(enable_weights_quantization=True, + weights_n_bits=7), + 'bar': AttributeQuantizationConfig(enable_weights_quantization=False)} + default_weight_attr_config = AttributeQuantizationConfig(enable_weights_quantization=True, + weights_n_bits=5) + node_attrs_list = ['afooz', 'bar', 0, 1] + wcfg = NodeWeightsQuantizationConfig(Mock(spec=OpQuantizationConfig, + attr_weights_configs_mapping=attr_weights_configs_mapping, + default_weight_attr_config=default_weight_attr_config, + simd_size=None), + weights_channels_axis=ChannelAxisMapping(1, 2), + node_attrs_list=node_attrs_list) + return wcfg, node_attrs_list + + def test_has_get_set_weights_attr_config(self): + """ Test has_attr_config, get_attr_config and set_attr_config """ + wcfg, node_attrs_list = self._create_wcfg() + + for attr in node_attrs_list: + assert wcfg.has_attribute_config(attr) is True + assert wcfg.has_attribute_config('baz') is False + assert wcfg.has_attribute_config(2) is False + + assert wcfg.get_attr_config('foo').weights_n_bits == 7 + # get config should work by both long and short name + assert wcfg.get_attr_config('afooz') == wcfg.get_attr_config('foo') + assert wcfg.get_attr_config('bar').enable_weights_quantization is False + assert wcfg.get_attr_config(0).weights_n_bits == 5 + assert wcfg.get_attr_config(1).weights_n_bits == 5 + + new_cfg = Mock() + wcfg.set_attr_config('afooz', new_cfg) + assert wcfg.get_attr_config('foo') == 
new_cfg
+
+        assert wcfg.get_attr_config('bar') != new_cfg
+        wcfg.set_attr_config('bar', new_cfg)
+        assert wcfg.get_attr_config('bar') == new_cfg
+
+        assert wcfg.get_attr_config(1) != new_cfg
+        wcfg.set_attr_config(1, new_cfg)
+        assert wcfg.get_attr_config(1) == new_cfg
+
+        # non-existing attrs
+        with pytest.raises(ValueError, match='Unknown weights attr foo'):
+            # set attr expects the full name
+            wcfg.set_attr_config('foo', new_cfg)
+        with pytest.raises(ValueError, match='Unknown weights attr 2'):
+            wcfg.set_attr_config(2, new_cfg)
+
+        # non-existing attrs with force=True
+        wcfg.set_attr_config('baz', new_cfg, force=True)
+        assert wcfg.get_attr_config('baz') == new_cfg
+
+        wcfg.set_attr_config(2, new_cfg, force=True)
+        assert wcfg.get_attr_config(2) == new_cfg
+
+    def test_set_quant_config_wcfg_level(self):
+        """ Test set_quant_config for attributes at the weight config level. """
+        wcfg, _ = self._create_wcfg()
+
+        assert wcfg.simd_size is None
+        wcfg.set_quant_config_attr('simd_size', 5)
+        assert wcfg.simd_size == 5
+
+        with pytest.raises(AttributeError):
+            wcfg.set_quant_config_attr('no_such_attr', 5)
+
+    def test_set_quant_config_attr_level(self):
+        """ Test set_quant_config for attributes of weights attrs configs.
""" + wcfg, _ = self._create_wcfg() + + wcfg.set_quant_config_attr('weights_n_bits', 4, attr_name='afooz') + assert wcfg.get_attr_config('afooz').weights_n_bits == 4 + + assert wcfg.get_attr_config(0).weights_n_bits == 5 + wcfg.set_quant_config_attr('weights_n_bits', 7, attr_name=1) + assert wcfg.get_attr_config(1).weights_n_bits == 7 + # 0 is not affected + assert wcfg.get_attr_config(0).weights_n_bits == 5 + + # enable_weights_quantization has a special handling: + foo_cfg = copy.deepcopy(wcfg.get_attr_config('afooz')) + # True with already enabled quantization has no effect (but doesn't fail) + wcfg.set_quant_config_attr('enable_weights_quantization', True, attr_name='afooz') + assert wcfg.get_attr_config('afooz') == foo_cfg + # False should reset all attrs + wcfg.set_quant_config_attr('enable_weights_quantization', False, attr_name='afooz') + assert_unset_attr_config(wcfg.get_attr_config('afooz')) + # False can be set again (check that doesn't crash) + wcfg.set_quant_config_attr('enable_weights_quantization', False, attr_name='afooz') + + def test_set_quant_config_attr_level_errors(self): + """ Test set_quant_config for attributes of weights attrs configs. 
""" + wcfg, _ = self._create_wcfg() + + for attr in ['baz', 2]: + with pytest.raises(ValueError, match=f'Weights attribute {attr} could not be found'): + wcfg.set_quant_config_attr('weights_n_bits', 5, attr_name=attr) + + with pytest.raises(AttributeError, match='Parameter no_such_attr could not be found in the quantization config ' + 'of weights attribute 1'): + wcfg.set_quant_config_attr('no_such_attr', 5, attr_name=1) + + # disabled quantization cannot be turned on (enable_weights_quantization has a special handling) + with pytest.raises(ValueError, match=f'Cannot enable quantization for attr bar with disabled quantization.'): + wcfg.set_quant_config_attr('enable_weights_quantization', True, attr_name='bar') + # no other attr can be set for disabled quantization + with pytest.raises(ValueError, match=f'Cannot set param weights_n_bits for attr bar with disabled quantization.'): + wcfg.set_quant_config_attr('weights_n_bits', 5, attr_name='bar') + + def test_disable_all(self): + wcfg, node_attrs_list = self._create_wcfg() + wcfg.disable_all_weights_quantization() + for attr in node_attrs_list: + assert_unset_attr_config(wcfg.get_attr_config(attr)) + + def test_get_all(self): + wcfg, node_attrs_list = self._create_wcfg() + assert sorted(wcfg.all_weight_attrs, key=lambda v: str(v)) == sorted(node_attrs_list, key=lambda v: str(v)) + cfgs = wcfg.get_all_weight_attrs_configs() + assert cfgs == {attr: wcfg.get_attr_config(attr) for attr in node_attrs_list}