diff --git a/model_compression_toolkit/core/common/quantization/node_quantization_config.py b/model_compression_toolkit/core/common/quantization/node_quantization_config.py index e1df6e4e8..d93eb5acb 100644 --- a/model_compression_toolkit/core/common/quantization/node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/node_quantization_config.py @@ -98,6 +98,9 @@ def __init__(self, op_cfg: OpQuantizationConfig): self.activation_quantization_params = {} # TODO: computed by compute_activation_bias_correction. Probably shouldnt be here. self.activation_bias_correction_term = None + # Z-threshold is a global param from QuantizationConfig, however it can be overridden per node by NetworkEditor. + # Since activation qparams are re-computed in several places, it's easier to keep it here and update it once. + self.z_threshold = None @property def enable_activation_quantization(self): diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py index 94a71f834..3ed8b7b70 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py @@ -47,9 +47,12 @@ def compute_activation_qparams(quant_cfg: QuantizationConfig, node_activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded()) # Extract and filter histogram data from the statistics container. + z_threshold = quant_cfg.z_threshold + if node_activation_quant_cfg.z_threshold is not None: + z_threshold = node_activation_quant_cfg.z_threshold bins_values, bins_counts = _get_histogram_data(out_stats_container, activation_error_method=quant_cfg.activation_error_method, - z_threshold=quant_cfg.z_threshold) + z_threshold=z_threshold) # Retrieve the minimum and maximum values from the statistics container. min_value, max_value = out_stats_container.get_min_max_values() diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_selection/test_qarams_activations_computation.py b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_selection/test_qarams_activations_computation.py index 3accd5e47..4c84cf091 100644 --- a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_selection/test_qarams_activations_computation.py +++ b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_selection/test_qarams_activations_computation.py @@ -55,7 +55,7 @@ def _create_activation_quant_cfg(self, quant_method, n_bits=8, signedness=Signed op_cfg = OpQuantizationConfig( default_weight_attr_config=AttributeQuantizationConfig(), attr_weights_configs_mapping={}, - activation_quantization_method=QuantizationMethod.POWER_OF_TWO, + activation_quantization_method=quant_method, activation_n_bits=n_bits, supported_input_activation_n_bits=n_bits, enable_activation_quantization=True, @@ -66,7 +66,6 @@ def _create_activation_quant_cfg(self, quant_method, n_bits=8, signedness=Signed signedness=signedness ) activation_quant_cfg = NodeActivationQuantizationConfig(op_cfg) - activation_quant_cfg.activation_quantization_method = quant_method return activation_quant_cfg def test_get_histogram_data_error_method(self): @@ -190,3 +189,25 @@ def test_get_activations_qparams(self, quant_method, activation_quantization_par activation_quant_cfg = self._create_activation_quant_cfg(quant_method, n_bits=2) result = compute_activation_qparams(QuantizationConfig(), activation_quant_cfg, nodes_prior_info, stats) assert result == expected_result + + def test_overriden_z_thresh(self, mocker): + """ Check that correct z-threshold is passed to _get_histogram_data """ + spy = mocker.patch('model_compression_toolkit.core.common.quantization.quantization_params_generation.' + 'qparams_activations_computation._get_histogram_data', + return_value=(np.array([1, 2]), np.array([100]))) + stat_collector_mock = mocker.Mock(spec_set=StatsCollector, get_min_max_values=lambda: (-1, 1)) + + acfg = self._create_activation_quant_cfg(QuantizationMethod.POWER_OF_TWO) + assert acfg.z_threshold is None + compute_activation_qparams(QuantizationConfig(z_threshold=100), acfg, NodePriorInfo(), stat_collector_mock) + # z-threshold from quant config is used + assert spy.called_once_with(stat_collector_mock, + activation_error_method=QuantizationMethod.POWER_OF_TWO, + z_threshold=100) + + # z-threshold from the node should be used + acfg.z_threshold = 5 + compute_activation_qparams(QuantizationConfig(z_threshold=100), acfg, NodePriorInfo(), stat_collector_mock) + assert spy.called_with(stat_collector_mock, + activation_error_method=QuantizationMethod.POWER_OF_TWO, + z_threshold=5)