|
25 | 25 | from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod |
26 | 26 |
|
27 | 27 |
|
| 28 | +def compute_activation_qparams(activation_quant_cfg: NodeActivationQuantizationConfig, |
| 29 | + node_prior_info: NodePriorInfo, |
| 30 | + out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]: |
| 31 | + """ |
| 32 | + Compute the activations params for a given node in a graph according to a params function. |
| 33 | +
|
| 34 | + Args: |
| 35 | + activation_quant_cfg: node's activation quantization configuration. |
| 36 | + node_prior_info: Prior info collected for the node that is being quantized. |
| 37 | + out_stats_container: Tensor containing output statistics of the node. |
| 38 | +
|
| 39 | + Returns: |
| 40 | + The computed activation quantization params. |
| 41 | + """ |
| 42 | + activation_quantization_params_fn = _get_activation_quantization_params_fn( |
| 43 | + activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded()) |
| 44 | + |
| 45 | + # Extract and filter histogram data from the statistics container. |
| 46 | + bins_values, bins_counts = _get_histogram_data(activation_quant_cfg, out_stats_container) |
| 47 | + |
| 48 | + # Retrieve the minimum and maximum values from the statistics container. |
| 49 | + min_value, max_value = out_stats_container.get_min_max_values() |
| 50 | + |
| 51 | + # Determine if the activations should be considered signed. |
| 52 | + signed = _determine_signedness(activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts) |
| 53 | + |
| 54 | + # Compute and return the activation quantization parameters. |
| 55 | + return activation_quantization_params_fn( |
| 56 | + bins_values, |
| 57 | + bins_counts, |
| 58 | + activation_quant_cfg.l_p_value, |
| 59 | + activation_quant_cfg.activation_n_bits, |
| 60 | + min_value, |
| 61 | + max_value, |
| 62 | + min_threshold=activation_quant_cfg.min_threshold, |
| 63 | + quant_error_method=activation_quant_cfg.activation_error_method, |
| 64 | + is_signed=signed |
| 65 | + ) |
| 66 | + |
| 67 | + |
28 | 68 | def _get_histogram_data( |
29 | 69 | activation_quant_cfg: NodeActivationQuantizationConfig, |
30 | 70 | out_stats_container: BaseStatsCollector |
@@ -85,46 +125,6 @@ def _determine_signedness( |
85 | 125 | return np.any(bins_values[:-1][bins_counts > 0] < 0) |
86 | 126 |
|
87 | 127 |
|
88 | | -def get_activations_qparams(activation_quant_cfg: NodeActivationQuantizationConfig, |
89 | | - node_prior_info: NodePriorInfo, |
90 | | - out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]: |
91 | | - """ |
92 | | - Compute the activations params for a given node in a graph according to a params function. |
93 | | -
|
94 | | - Args: |
95 | | - activation_quant_cfg: node's activation quantization configuration. |
96 | | - node_prior_info: Prior info collected for the node that is being quantized. |
97 | | - out_stats_container: Tensor containing output statistics of the node. |
98 | | -
|
99 | | - Returns: |
100 | | - The computed activation quantization params. |
101 | | - """ |
102 | | - activation_quantization_params_fn = _get_activation_quantization_params_fn( |
103 | | - activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded()) |
104 | | - |
105 | | - # Extract and filter histogram data from the statistics container. |
106 | | - bins_values, bins_counts = _get_histogram_data(activation_quant_cfg, out_stats_container) |
107 | | - |
108 | | - # Retrieve the minimum and maximum values from the statistics container. |
109 | | - min_value, max_value = out_stats_container.get_min_max_values() |
110 | | - |
111 | | - # Determine if the activations should be considered signed. |
112 | | - signed = _determine_signedness(activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts) |
113 | | - |
114 | | - # Compute and return the activation quantization parameters. |
115 | | - return activation_quantization_params_fn( |
116 | | - bins_values, |
117 | | - bins_counts, |
118 | | - activation_quant_cfg.l_p_value, |
119 | | - activation_quant_cfg.activation_n_bits, |
120 | | - min_value, |
121 | | - max_value, |
122 | | - min_threshold=activation_quant_cfg.min_threshold, |
123 | | - quant_error_method=activation_quant_cfg.activation_error_method, |
124 | | - is_signed=signed |
125 | | - ) |
126 | | - |
127 | | - |
128 | 128 | _activation_quant_params_fns = { |
129 | 129 | QuantizationMethod.POWER_OF_TWO: qpg.power_of_two_selection_histogram, |
130 | 130 | QuantizationMethod.SYMMETRIC: qpg.symmetric_selection_histogram, |
|
0 commit comments