Skip to content

Commit 3fdc043

Browse files
authored
restore an ability to set z-threshold per node using network editor (SonySemiconductorSolutions#1491)
1 parent fbd260e commit 3fdc043

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ def __init__(self, op_cfg: OpQuantizationConfig):
9898
self.activation_quantization_params = {}
9999
# TODO: computed by compute_activation_bias_correction. Probably shouldn't be here.
100100
self.activation_bias_correction_term = None
101+
# Z-threshold is a global param from QuantizationConfig, however it can be overridden per node by NetworkEditor.
102+
# Since activation qparams are re-computed in several places, it's easier to keep it here and update it once.
103+
self.z_threshold = None
101104

102105
@property
103106
def enable_activation_quantization(self):

model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,12 @@ def compute_activation_qparams(quant_cfg: QuantizationConfig,
4747
node_activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())
4848

4949
# Extract and filter histogram data from the statistics container.
50+
z_threshold = quant_cfg.z_threshold
51+
if node_activation_quant_cfg.z_threshold is not None:
52+
z_threshold = node_activation_quant_cfg.z_threshold
5053
bins_values, bins_counts = _get_histogram_data(out_stats_container,
5154
activation_error_method=quant_cfg.activation_error_method,
52-
z_threshold=quant_cfg.z_threshold)
55+
z_threshold=z_threshold)
5356

5457
# Retrieve the minimum and maximum values from the statistics container.
5558
min_value, max_value = out_stats_container.get_min_max_values()

tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_selection/test_qarams_activations_computation.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _create_activation_quant_cfg(self, quant_method, n_bits=8, signedness=Signed
5555
op_cfg = OpQuantizationConfig(
5656
default_weight_attr_config=AttributeQuantizationConfig(),
5757
attr_weights_configs_mapping={},
58-
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
58+
activation_quantization_method=quant_method,
5959
activation_n_bits=n_bits,
6060
supported_input_activation_n_bits=n_bits,
6161
enable_activation_quantization=True,
@@ -66,7 +66,6 @@ def _create_activation_quant_cfg(self, quant_method, n_bits=8, signedness=Signed
6666
signedness=signedness
6767
)
6868
activation_quant_cfg = NodeActivationQuantizationConfig(op_cfg)
69-
activation_quant_cfg.activation_quantization_method = quant_method
7069
return activation_quant_cfg
7170

7271
def test_get_histogram_data_error_method(self):
@@ -190,3 +189,25 @@ def test_get_activations_qparams(self, quant_method, activation_quantization_par
190189
activation_quant_cfg = self._create_activation_quant_cfg(quant_method, n_bits=2)
191190
result = compute_activation_qparams(QuantizationConfig(), activation_quant_cfg, nodes_prior_info, stats)
192191
assert result == expected_result
192+
193+
def test_overriden_z_thresh(self, mocker):
    """Check that the correct z-threshold is passed to _get_histogram_data.

    When the node's activation config has no z-threshold (None), the global
    value from QuantizationConfig must be forwarded. When the node config
    carries its own z-threshold, that value must take precedence.
    """
    spy = mocker.patch('model_compression_toolkit.core.common.quantization.quantization_params_generation.'
                       'qparams_activations_computation._get_histogram_data',
                       return_value=(np.array([1, 2]), np.array([100])))
    stat_collector_mock = mocker.Mock(spec_set=StatsCollector, get_min_max_values=lambda: (-1, 1))
    quant_cfg = QuantizationConfig(z_threshold=100)

    acfg = self._create_activation_quant_cfg(QuantizationMethod.POWER_OF_TWO)
    assert acfg.z_threshold is None
    compute_activation_qparams(quant_cfg, acfg, NodePriorInfo(), stat_collector_mock)
    # z-threshold from quant config is used.
    # NOTE: use the mock's assert_* methods. `assert spy.called_once_with(...)` only evaluates an
    # auto-created (always truthy) mock attribute and therefore never verifies the call arguments.
    # The error method forwarded to _get_histogram_data is the one from the quantization config.
    spy.assert_called_once_with(stat_collector_mock,
                                activation_error_method=quant_cfg.activation_error_method,
                                z_threshold=100)

    # z-threshold from the node should be used
    acfg.z_threshold = 5
    compute_activation_qparams(quant_cfg, acfg, NodePriorInfo(), stat_collector_mock)
    assert spy.call_count == 2
    # assert_called_with checks the most recent call only, i.e. the one made after the override.
    spy.assert_called_with(stat_collector_mock,
                           activation_error_method=quant_cfg.activation_error_method,
                           z_threshold=5)

0 commit comments

Comments
 (0)