
Commit 3c9f296

Merge remote-tracking branch 'origin/main' into remove_default_output_axis_from_fw_info
2 parents 81ed69c + 64e3adb

34 files changed: 185 additions & 412 deletions


model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 22 additions & 87 deletions
@@ -18,7 +18,6 @@
 from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
 from model_compression_toolkit.logger import Logger
 
-from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
     AttributeQuantizationConfig, OpQuantizationConfig
@@ -41,6 +40,7 @@ class ActivationQuantizationMode(Enum):
     NO_QUANT = auto()
     FLN_NO_QUANT = auto()
 
+
 class BaseNodeQuantizationConfig(object):
     """
     Base class for node quantization configuration
@@ -59,12 +59,11 @@ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_val
             kwargs: A dictionary with additional key arguments.
 
         """
-
         if hasattr(self, config_parameter_name):
             setattr(self, config_parameter_name, config_parameter_value)
         else:
-            Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
-                           f"was not updated!")
+            raise AttributeError(
+                f"Parameter {config_parameter_name} could not be found in the node quantization config.")
 
     def __repr__(self) -> str:
         """
@@ -97,37 +96,9 @@ def __init__(self, op_cfg: OpQuantizationConfig):
         self.signedness = op_cfg.signedness
 
         self.activation_quantization_params = {}
-        # TODO irena: computed by compute_activation_bias_correction. shouldnt really be here
+        # TODO: computed by compute_activation_bias_correction. Probably shouldnt be here.
         self.activation_bias_correction_term = None
 
-        # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
-        self.activation_error_method = None
-        self.relu_bound_to_power_of_2 = None
-        self.activation_channel_equalization = None
-        self.input_scaling = None
-        self.min_threshold = None
-        self.l_p_value = None
-        self.shift_negative_activation_correction = None
-        self.z_threshold = None
-        self.shift_negative_ratio = None
-        self.shift_negative_threshold_recalculation = None
-        self.concat_threshold_update = None
-
-    def set_qc(self, qc: QuantizationConfig):
-        """ TODO irena: temporary keep all the attributes as before not to break all code at once.
-        Eventually all of them should be removed from here. """
-        self.activation_error_method = qc.activation_error_method
-        self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
-        self.activation_channel_equalization = qc.activation_channel_equalization
-        self.input_scaling = qc.input_scaling
-        self.min_threshold = qc.min_threshold
-        self.l_p_value = qc.l_p_value
-        self.shift_negative_activation_correction = qc.shift_negative_activation_correction
-        self.z_threshold = qc.z_threshold
-        self.shift_negative_ratio = qc.shift_negative_ratio
-        self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
-        self.concat_threshold_update = qc.concat_threshold_update
-
     @property
     def enable_activation_quantization(self):
         return self.quant_mode == ActivationQuantizationMode.QUANT
@@ -165,32 +136,16 @@ def __eq__(self, other: Any) -> bool:
         if not isinstance(other, NodeActivationQuantizationConfig):
             return False  # pragma: no cover
 
-        return self.activation_error_method == other.activation_error_method and \
-               self.activation_quantization_method == other.activation_quantization_method and \
+        return self.activation_quantization_method == other.activation_quantization_method and \
                self.activation_n_bits == other.activation_n_bits and \
                self.quant_mode == other.quant_mode and \
-               self.activation_channel_equalization == other.activation_channel_equalization and \
-               self.input_scaling == other.input_scaling and \
-               self.min_threshold == other.min_threshold and \
-               self.l_p_value == other.l_p_value and \
-               self.shift_negative_activation_correction == other.shift_negative_activation_correction and \
-               self.z_threshold == other.z_threshold and \
-               self.shift_negative_ratio == other.shift_negative_ratio and \
-               self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
+               self.signedness == other.signedness
 
     def __hash__(self):
-        return hash((self.activation_error_method,
-                     self.activation_quantization_method,
+        return hash((self.activation_quantization_method,
                      self.activation_n_bits,
                      self.quant_mode,
-                     self.activation_channel_equalization,
-                     self.input_scaling,
-                     self.min_threshold,
-                     self.l_p_value,
-                     self.shift_negative_activation_correction,
-                     self.z_threshold,
-                     self.shift_negative_ratio,
-                     self.shift_negative_threshold_recalculation))
+                     self.signedness))
 
 
 class WeightsAttrQuantizationConfig:
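With __eq__ and __hash__ trimmed to the four intrinsic fields (quantization method, bit-width, quant mode, signedness), two configs built from the same op config now compare equal without set_qc ever being called. A hedged sketch (op_cfg is a hypothetical OpQuantizationConfig instance; the constructor signature is taken from the hunk above):

    a = NodeActivationQuantizationConfig(op_cfg)
    b = NodeActivationQuantizationConfig(op_cfg)
    assert a == b and hash(a) == hash(b)  # e.g. candidate deduplication via set()/dict keys stays consistent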
@@ -211,16 +166,8 @@ def __init__(self,
         self.weights_n_bits = weights_attr_cfg.weights_n_bits
         self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold
         self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization
-        self.weights_quantization_params = {}
 
-        # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
-        self.weights_error_method = None
-        self.l_p_value = None
-
-    def set_qc(self, qc: QuantizationConfig):
-        # TODO irena: temporary keep the fields to not break everything at once.
-        self.weights_error_method = qc.weights_error_method
-        self.l_p_value = qc.l_p_value
+        self.weights_quantization_params = {}
 
     def set_weights_quantization_param(self,
                                        weights_params: dict):
@@ -252,18 +199,14 @@ def __eq__(self, other: Any) -> bool:
                self.weights_quantization_method == other.weights_quantization_method and \
                self.weights_n_bits == other.weights_n_bits and \
                self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
-               self.enable_weights_quantization == other.enable_weights_quantization and \
-               self.weights_error_method == other.weights_error_method and \
-               self.l_p_value == other.l_p_value
+               self.enable_weights_quantization == other.enable_weights_quantization
 
     def __hash__(self):
         return hash((self.weights_channels_axis,
-                     self.weights_error_method,
                      self.weights_quantization_method,
                      self.weights_n_bits,
                      self.weights_per_channel_threshold,
-                     self.enable_weights_quantization,
-                     self.l_p_value))
+                     self.enable_weights_quantization))
 
 
 class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
@@ -330,16 +273,14 @@ def __init__(self,
 
             self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
                                                                                   weights_channels_axis=weights_channels_axis)
-        # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
-        self.min_threshold = None
+        # TODO this is set by batch norm reconstruction substitution when folded batch norms are added back, to mark
+        #  the nodes that the correction should be applied to (for some nodes it gets disabled) and BNs removed.
+        #  The actual correction is only computed when it's applied in ptq, so it seems that both substitutions could
+        #  be unified, and no info need to pass between.
         self.weights_second_moment_correction = None
-        self.weights_bias_correction = None
-
-    def set_qc(self, qc: QuantizationConfig):
-        # TODO irena: temporary keep the fields to not break everything at once.
-        self.min_threshold = qc.min_threshold
-        self.weights_second_moment_correction = qc.weights_second_moment_correction
-        self.weights_bias_correction = qc.weights_bias_correction
+        # TODO: computed corrected bias is injected to the node config. Probably shouldn't be here. Also it can be
+        #  computed on the final config, instead of all candidates and then there is no need to save it at all.
+        self.bias_corrected = None
 
     def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
         """
@@ -476,8 +417,8 @@ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_val
             if hasattr(attr_cfg, config_parameter_name):
                 setattr(attr_cfg, config_parameter_name, config_parameter_value)
             else:
-                Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
-                               f"weights attribute {attr_name} and was not updated!")
+                raise AttributeError(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
+                                     f"weights attribute {attr_name}.")
         else:  # pragma: no cover
             Logger.critical(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
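As with the activation config, the weights-side setter now raises instead of warning. A hedged sketch (w_cfg and the 'kernel' attribute name are hypothetical; an attr_name keyword is implied by the method body above but its exact signature is not shown in this diff):

    w_cfg.set_quant_config_attr('weights_n_bits', 4, attr_name='kernel')  # known parameter: updated
    w_cfg.set_quant_config_attr('no_such_field', 1, attr_name='kernel')   # now raises AttributeError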

@@ -494,10 +435,7 @@ def __eq__(self, other: Any) -> bool:
         if not isinstance(other, NodeWeightsQuantizationConfig):
             return False  # pragma: no cover
 
-        return self.min_threshold == other.min_threshold and \
-               self.simd_size == other.simd_size and \
-               self.weights_second_moment_correction == other.weights_second_moment_correction and \
-               self.weights_bias_correction == other.weights_bias_correction and \
+        return self.simd_size == other.simd_size and \
                self.attributes_config_mapping.keys() == other.attributes_config_mapping.keys() and \
                all([self.attributes_config_mapping[k] == other.attributes_config_mapping[k]
                     for k in self.attributes_config_mapping.keys()]) and \
@@ -506,9 +444,6 @@ def __eq__(self, other: Any) -> bool:
                     for k in self.pos_attributes_config_mapping.keys()])
 
     def __hash__(self):
-        return hash((self.min_threshold,
-                     self.simd_size,
-                     self.weights_second_moment_correction,
-                     self.weights_bias_correction,
+        return hash((self.simd_size,
                      frozenset(self.attributes_config_mapping),
                      frozenset(self.pos_attributes_config_mapping)))

model_compression_toolkit/core/common/quantization/quantization_config.py

Lines changed: 0 additions & 1 deletion
@@ -90,7 +90,6 @@ class QuantizationConfig:
     shift_negative_activation_correction: bool = True
     activation_channel_equalization: bool = False
     z_threshold: float = math.inf
-    min_threshold: float = MIN_THRESHOLD
     l_p_value: int = 2
     linear_collapsing: bool = True
     residual_collapsing: bool = True
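Since min_threshold is no longer a QuantizationConfig field, code that read qc.min_threshold switches to the library constant, as the next file's diff does. A one-line sketch:

    from model_compression_toolkit.constants import MIN_THRESHOLD  # replaces the removed qc.min_threshold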

model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py

Lines changed: 23 additions & 17 deletions
@@ -18,63 +18,69 @@
 from mct_quantizers import QuantizationMethod
 
 import model_compression_toolkit.core.common.quantization.quantization_params_generation as qpg
+from model_compression_toolkit.constants import MIN_THRESHOLD
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
 from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
 from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
 from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
-from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod
+from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationErrorMethod, \
+    QuantizationConfig
 
 
-def compute_activation_qparams(activation_quant_cfg: NodeActivationQuantizationConfig,
+def compute_activation_qparams(quant_cfg: QuantizationConfig,
+                               node_activation_quant_cfg: NodeActivationQuantizationConfig,
                                node_prior_info: NodePriorInfo,
                                out_stats_container: BaseStatsCollector) -> Dict[str, Union[np.ndarray, float, bool]]:
     """
     Compute the activations params for a given node in a graph according to a params function.
 
     Args:
-        activation_quant_cfg: node's activation quantization configuration.
+        quant_cfg: quantization config.
+        node_activation_quant_cfg: node's activation quantization configuration.
         node_prior_info: Prior info collected for the node that is being quantized.
         out_stats_container: Tensor containing output statistics of the node.
 
     Returns:
         The computed activation quantization params.
     """
     activation_quantization_params_fn = _get_activation_quantization_params_fn(
-        activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())
+        node_activation_quant_cfg.activation_quantization_method, no_clipping=node_prior_info.is_output_bounded())
 
     # Extract and filter histogram data from the statistics container.
-    bins_values, bins_counts = _get_histogram_data(activation_quant_cfg, out_stats_container)
+    bins_values, bins_counts = _get_histogram_data(out_stats_container,
+                                                   activation_error_method=quant_cfg.activation_error_method,
+                                                   z_threshold=quant_cfg.z_threshold)
 
     # Retrieve the minimum and maximum values from the statistics container.
     min_value, max_value = out_stats_container.get_min_max_values()
 
     # Determine if the activations should be considered signed.
-    signed = _determine_signedness(activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts)
+    signed = _determine_signedness(node_activation_quant_cfg, node_prior_info, min_value, bins_values, bins_counts)
 
     # Compute and return the activation quantization parameters.
     return activation_quantization_params_fn(
         bins_values,
         bins_counts,
-        activation_quant_cfg.l_p_value,
-        activation_quant_cfg.activation_n_bits,
+        quant_cfg.l_p_value,
+        node_activation_quant_cfg.activation_n_bits,
         min_value,
         max_value,
-        min_threshold=activation_quant_cfg.min_threshold,
-        quant_error_method=activation_quant_cfg.activation_error_method,
+        min_threshold=MIN_THRESHOLD,
+        quant_error_method=quant_cfg.activation_error_method,
         is_signed=signed
     )
 
 
-def _get_histogram_data(
-        activation_quant_cfg: NodeActivationQuantizationConfig,
-        out_stats_container: BaseStatsCollector
-) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
+def _get_histogram_data(out_stats_container: BaseStatsCollector,
+                        activation_error_method: QuantizationErrorMethod,
+                        z_threshold: float) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
     """
     Extract and filter the histogram data from the statistics container.
 
     Args:
-        activation_quant_cfg: Node's activation quantization configuration.
         out_stats_container: Statistics container with histogram data.
+        activation_error_method: activation quantization error method.
+        z_threshold: z threshold for z-score filtering.
 
     Returns:
         A tuple containing the filtered bins_values and bins_counts.
@@ -83,12 +89,12 @@ def _get_histogram_data(
     # If the statistics container collected the histogram, we start by filtering outliers using z threshold
     # filtering, and then computing the threshold based on the filtered histogram.
     if out_stats_container.require_collection():
-        if activation_quant_cfg.activation_error_method == QuantizationErrorMethod.HMSE:
+        if activation_error_method == QuantizationErrorMethod.HMSE:
            bins_values, bins_counts = out_stats_container.weighted_hc.get_histogram()
         else:
             bins_values, bins_counts = out_stats_container.hc.get_histogram()
         bins_counts = qpg.z_score_filter(
-            activation_quant_cfg.z_threshold,
+            z_threshold,
             bins_values,
             bins_counts
         )
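A hedged sketch of a call site updated for the new compute_activation_qparams signature (qc, node_cfg, prior_info and stats are hypothetical, pre-existing objects; the parameter names come from the diff above):

    qparams = compute_activation_qparams(quant_cfg=qc,                        # global QuantizationConfig
                                         node_activation_quant_cfg=node_cfg,  # per-node activation config
                                         node_prior_info=prior_info,
                                         out_stats_container=stats)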
