@@ -55,7 +55,7 @@ def _create_activation_quant_cfg(self, quant_method, n_bits=8, signedness=Signed
5555 op_cfg = OpQuantizationConfig (
5656 default_weight_attr_config = AttributeQuantizationConfig (),
5757 attr_weights_configs_mapping = {},
58- activation_quantization_method = QuantizationMethod . POWER_OF_TWO ,
58+ activation_quantization_method = quant_method ,
5959 activation_n_bits = n_bits ,
6060 supported_input_activation_n_bits = n_bits ,
6161 enable_activation_quantization = True ,
@@ -66,7 +66,6 @@ def _create_activation_quant_cfg(self, quant_method, n_bits=8, signedness=Signed
6666 signedness = signedness
6767 )
6868 activation_quant_cfg = NodeActivationQuantizationConfig (op_cfg )
69- activation_quant_cfg .activation_quantization_method = quant_method
7069 return activation_quant_cfg
7170
7271 def test_get_histogram_data_error_method (self ):
@@ -190,3 +189,25 @@ def test_get_activations_qparams(self, quant_method, activation_quantization_par
190189 activation_quant_cfg = self ._create_activation_quant_cfg (quant_method , n_bits = 2 )
191190 result = compute_activation_qparams (QuantizationConfig (), activation_quant_cfg , nodes_prior_info , stats )
192191 assert result == expected_result
192+
193+ def test_overriden_z_thresh (self , mocker ):
194+ """ Check that correct z-threshold is passed to _get_histogram_data """
195+ spy = mocker .patch ('model_compression_toolkit.core.common.quantization.quantization_params_generation.'
196+ 'qparams_activations_computation._get_histogram_data' ,
197+ return_value = (np .array ([1 , 2 ]), np .array ([100 ])))
198+ stat_collector_mock = mocker .Mock (spec_set = StatsCollector , get_min_max_values = lambda : (- 1 , 1 ))
199+
200+ acfg = self ._create_activation_quant_cfg (QuantizationMethod .POWER_OF_TWO )
201+ assert acfg .z_threshold is None
202+ compute_activation_qparams (QuantizationConfig (z_threshold = 100 ), acfg , NodePriorInfo (), stat_collector_mock )
203+ # z-threshold from quant config is used
204+ assert spy .called_once_with (stat_collector_mock ,
205+ activation_error_method = QuantizationMethod .POWER_OF_TWO ,
206+ z_threshold = 100 )
207+
208+ # z-threshold from the node should be used
209+ acfg .z_threshold = 5
210+ compute_activation_qparams (QuantizationConfig (z_threshold = 100 ), acfg , NodePriorInfo (), stat_collector_mock )
211+ assert spy .called_with (stat_collector_mock ,
212+ activation_error_method = QuantizationMethod .POWER_OF_TWO ,
213+ z_threshold = 5 )
0 commit comments