Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def __init__(self, op_cfg: OpQuantizationConfig):
self.activation_quantization_params = {}
# TODO: computed by compute_activation_bias_correction. Probably shouldn't be here.
self.activation_bias_correction_term = None
# Z-threshold is a global param from QuantizationConfig, however it can be overridden per node by NetworkEditor.
# Since activation qparams are re-computed in several places, it's easier to keep it here and update it once.
self.z_threshold = None

@property
def enable_activation_quantization(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ def compute_activation_qparams(quant_cfg: QuantizationConfig,
node_activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())

# Extract and filter histogram data from the statistics container.
z_threshold = quant_cfg.z_threshold
if node_activation_quant_cfg.z_threshold is not None:
z_threshold = node_activation_quant_cfg.z_threshold
bins_values, bins_counts = _get_histogram_data(out_stats_container,
activation_error_method=quant_cfg.activation_error_method,
z_threshold=quant_cfg.z_threshold)
z_threshold=z_threshold)

# Retrieve the minimum and maximum values from the statistics container.
min_value, max_value = out_stats_container.get_min_max_values()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _create_activation_quant_cfg(self, quant_method, n_bits=8, signedness=Signed
op_cfg = OpQuantizationConfig(
default_weight_attr_config=AttributeQuantizationConfig(),
attr_weights_configs_mapping={},
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
activation_quantization_method=quant_method,
activation_n_bits=n_bits,
supported_input_activation_n_bits=n_bits,
enable_activation_quantization=True,
Expand All @@ -66,7 +66,6 @@ def _create_activation_quant_cfg(self, quant_method, n_bits=8, signedness=Signed
signedness=signedness
)
activation_quant_cfg = NodeActivationQuantizationConfig(op_cfg)
activation_quant_cfg.activation_quantization_method = quant_method
return activation_quant_cfg

def test_get_histogram_data_error_method(self):
Expand Down Expand Up @@ -190,3 +189,25 @@ def test_get_activations_qparams(self, quant_method, activation_quantization_par
activation_quant_cfg = self._create_activation_quant_cfg(quant_method, n_bits=2)
result = compute_activation_qparams(QuantizationConfig(), activation_quant_cfg, nodes_prior_info, stats)
assert result == expected_result

def test_overriden_z_thresh(self, mocker):
    """
    Check that the correct z-threshold is passed to _get_histogram_data.

    Two scenarios are verified against a patched _get_histogram_data:
      1. The node config leaves z_threshold as None -> the value from QuantizationConfig is forwarded.
      2. The node config sets z_threshold -> the node value is forwarded instead.
    """
    spy = mocker.patch('model_compression_toolkit.core.common.quantization.quantization_params_generation.'
                       'qparams_activations_computation._get_histogram_data',
                       return_value=(np.array([1, 2]), np.array([100])))
    stat_collector_mock = mocker.Mock(spec_set=StatsCollector, get_min_max_values=lambda: (-1, 1))
    quant_cfg = QuantizationConfig(z_threshold=100)

    acfg = self._create_activation_quant_cfg(QuantizationMethod.POWER_OF_TWO)
    assert acfg.z_threshold is None
    compute_activation_qparams(quant_cfg, acfg, NodePriorInfo(), stat_collector_mock)
    # z-threshold from quant config is used.
    # NOTE: the real Mock assertion helpers must be used here. `spy.called_once_with(...)` is just an
    # auto-created child mock whose return value is always truthy, so `assert spy.called_once_with(...)`
    # can never fail. The error method forwarded to _get_histogram_data is the one held by the
    # QuantizationConfig, not the node's activation quantization method.
    spy.assert_called_once_with(stat_collector_mock,
                                activation_error_method=quant_cfg.activation_error_method,
                                z_threshold=100)

    # z-threshold from the node should be used
    # Reset the recorded calls so the second check is independent of the first invocation
    # (reset_mock keeps the configured return_value).
    spy.reset_mock()
    acfg.z_threshold = 5
    compute_activation_qparams(quant_cfg, acfg, NodePriorInfo(), stat_collector_mock)
    spy.assert_called_once_with(stat_collector_mock,
                                activation_error_method=quant_cfg.activation_error_method,
                                z_threshold=5)