 from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
 from model_compression_toolkit.logger import Logger
 
-from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
     AttributeQuantizationConfig, OpQuantizationConfig
@@ -41,6 +40,7 @@ class ActivationQuantizationMode(Enum):
     NO_QUANT = auto()
     FLN_NO_QUANT = auto()
 
+
 class BaseNodeQuantizationConfig(object):
     """
     Base class for node quantization configuration
@@ -59,12 +59,11 @@ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_val
             kwargs: A dictionary with additional key arguments.
 
         """
-
         if hasattr(self, config_parameter_name):
             setattr(self, config_parameter_name, config_parameter_value)
         else:
-            Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
-                           f"was not updated!")
+            raise AttributeError(
+                f"Parameter {config_parameter_name} could not be found in the node quantization config.")
 
     def __repr__(self) -> str:
         """
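[Note on the hunk above: switching from Logger.warning to AttributeError means an unknown parameter name now fails loudly instead of being silently skipped. A minimal sketch of the new contract; `cfg` is an assumed fixture (building a real NodeActivationQuantizationConfig requires an OpQuantizationConfig, not shown here):

    # `cfg` is assumed to be a node quantization config instance, e.g. a
    # NodeActivationQuantizationConfig built from an OpQuantizationConfig.
    cfg.set_quant_config_attr('activation_n_bits', 4)  # existing attribute: updated in place
    try:
        cfg.set_quant_config_attr('no_such_parameter', 1)
    except AttributeError:
        pass  # previously only a warning was logged and the update was skipped
]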
@@ -97,37 +96,9 @@ def __init__(self, op_cfg: OpQuantizationConfig):
         self.signedness = op_cfg.signedness
 
         self.activation_quantization_params = {}
-        # TODO irena: computed by compute_activation_bias_correction. shouldnt really be here
+        # TODO: computed by compute_activation_bias_correction. Probably shouldn't be here.
         self.activation_bias_correction_term = None
 
-        # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
-        self.activation_error_method = None
-        self.relu_bound_to_power_of_2 = None
-        self.activation_channel_equalization = None
-        self.input_scaling = None
-        self.min_threshold = None
-        self.l_p_value = None
-        self.shift_negative_activation_correction = None
-        self.z_threshold = None
-        self.shift_negative_ratio = None
-        self.shift_negative_threshold_recalculation = None
-        self.concat_threshold_update = None
-
-    def set_qc(self, qc: QuantizationConfig):
-        """ TODO irena: temporary keep all the attributes as before not to break all code at once.
-        Eventually all of them should be removed from here. """
-        self.activation_error_method = qc.activation_error_method
-        self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
-        self.activation_channel_equalization = qc.activation_channel_equalization
-        self.input_scaling = qc.input_scaling
-        self.min_threshold = qc.min_threshold
-        self.l_p_value = qc.l_p_value
-        self.shift_negative_activation_correction = qc.shift_negative_activation_correction
-        self.z_threshold = qc.z_threshold
-        self.shift_negative_ratio = qc.shift_negative_ratio
-        self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
-        self.concat_threshold_update = qc.concat_threshold_update
-
     @property
     def enable_activation_quantization(self):
         return self.quant_mode == ActivationQuantizationMode.QUANT
@@ -165,32 +136,16 @@ def __eq__(self, other: Any) -> bool:
         if not isinstance(other, NodeActivationQuantizationConfig):
             return False  # pragma: no cover
 
-        return self.activation_error_method == other.activation_error_method and \
-               self.activation_quantization_method == other.activation_quantization_method and \
+        return self.activation_quantization_method == other.activation_quantization_method and \
                self.activation_n_bits == other.activation_n_bits and \
                self.quant_mode == other.quant_mode and \
-               self.activation_channel_equalization == other.activation_channel_equalization and \
-               self.input_scaling == other.input_scaling and \
-               self.min_threshold == other.min_threshold and \
-               self.l_p_value == other.l_p_value and \
-               self.shift_negative_activation_correction == other.shift_negative_activation_correction and \
-               self.z_threshold == other.z_threshold and \
-               self.shift_negative_ratio == other.shift_negative_ratio and \
-               self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
+               self.signedness == other.signedness
 
     def __hash__(self):
-        return hash((self.activation_error_method,
-                     self.activation_quantization_method,
+        return hash((self.activation_quantization_method,
                      self.activation_n_bits,
                      self.quant_mode,
-                     self.activation_channel_equalization,
-                     self.input_scaling,
-                     self.min_threshold,
-                     self.l_p_value,
-                     self.shift_negative_activation_correction,
-                     self.z_threshold,
-                     self.shift_negative_ratio,
-                     self.shift_negative_threshold_recalculation))
+                     self.signedness))
 
 
 class WeightsAttrQuantizationConfig:
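[With the slimmed-down `__eq__`/`__hash__` above, activation configs are compared only on quantization method, bit-width, quant mode, and signedness; the dropped search parameters (activation_error_method, z_threshold, etc.) no longer exist on the config at all. A hedged sketch of the resulting semantics, assuming `cfg_a` and `cfg_b` are configs built from the same OpQuantizationConfig:

    # Equal on the four compared fields, so equal overall; also safe as dict/set
    # keys, since __hash__ is derived from the same fields as __eq__.
    assert cfg_a == cfg_b
    assert hash(cfg_a) == hash(cfg_b)
]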
@@ -211,16 +166,8 @@ def __init__(self,
         self.weights_n_bits = weights_attr_cfg.weights_n_bits
         self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold
         self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization
-        self.weights_quantization_params = {}
 
-        # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
-        self.weights_error_method = None
-        self.l_p_value = None
-
-    def set_qc(self, qc: QuantizationConfig):
-        # TODO irena: temporary keep the fields to not break everything at once.
-        self.weights_error_method = qc.weights_error_method
-        self.l_p_value = qc.l_p_value
+        self.weights_quantization_params = {}
 
     def set_weights_quantization_param(self,
                                        weights_params: dict):
@@ -252,18 +199,14 @@ def __eq__(self, other: Any) -> bool:
                self.weights_quantization_method == other.weights_quantization_method and \
                self.weights_n_bits == other.weights_n_bits and \
                self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
-               self.enable_weights_quantization == other.enable_weights_quantization and \
-               self.weights_error_method == other.weights_error_method and \
-               self.l_p_value == other.l_p_value
+               self.enable_weights_quantization == other.enable_weights_quantization
 
     def __hash__(self):
         return hash((self.weights_channels_axis,
-                     self.weights_error_method,
                      self.weights_quantization_method,
                      self.weights_n_bits,
                      self.weights_per_channel_threshold,
-                     self.enable_weights_quantization,
-                     self.l_p_value))
+                     self.enable_weights_quantization))
 
 
 class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
@@ -330,16 +273,14 @@ def __init__(self,
 
         self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
                                                                              weights_channels_axis=weights_channels_axis)
-        # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
-        self.min_threshold = None
+        # TODO: this is set by the batch-norm reconstruction substitution, when folded batch norms are added back,
+        # to mark the nodes that the correction should be applied to (for some nodes it gets disabled) before the
+        # BNs are removed. The actual correction is only computed when it is applied in PTQ, so it seems the two
+        # substitutions could be unified, with no info passed between them.
         self.weights_second_moment_correction = None
-        self.weights_bias_correction = None
-
-    def set_qc(self, qc: QuantizationConfig):
-        # TODO irena: temporary keep the fields to not break everything at once.
-        self.min_threshold = qc.min_threshold
-        self.weights_second_moment_correction = qc.weights_second_moment_correction
-        self.weights_bias_correction = qc.weights_bias_correction
+        # TODO: the computed corrected bias is injected into the node config; it probably shouldn't live here. It
+        # could also be computed on the final config instead of on all candidates, and then it would not need to
+        # be saved at all.
+        self.bias_corrected = None
 
     def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
         """
@@ -476,8 +417,8 @@ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_val
             if hasattr(attr_cfg, config_parameter_name):
                 setattr(attr_cfg, config_parameter_name, config_parameter_value)
             else:
-                Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
-                               f"weights attribute {attr_name} and was not updated!")
+                raise AttributeError(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
+                                     f"weights attribute {attr_name}.")
         else:  # pragma: no cover
             Logger.critical(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
 
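[The weights-side change mirrors the activation one: an unknown parameter on an existing attribute now raises AttributeError, while a missing attribute name still goes through Logger.critical as before. A sketch under assumptions: `w_cfg` is a NodeWeightsQuantizationConfig with a 'kernel' attribute, and the `attr_name` keyword (visible in the messages above) selects the attribute config:

    w_cfg.set_quant_config_attr('weights_n_bits', 4, attr_name='kernel')  # existing parameter: updated
    try:
        w_cfg.set_quant_config_attr('no_such_parameter', 1, attr_name='kernel')
    except AttributeError:
        pass  # unknown parameter on an existing attribute now raises instead of warning
]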
@@ -494,10 +435,7 @@ def __eq__(self, other: Any) -> bool:
         if not isinstance(other, NodeWeightsQuantizationConfig):
             return False  # pragma: no cover
 
-        return self.min_threshold == other.min_threshold and \
-               self.simd_size == other.simd_size and \
-               self.weights_second_moment_correction == other.weights_second_moment_correction and \
-               self.weights_bias_correction == other.weights_bias_correction and \
+        return self.simd_size == other.simd_size and \
                self.attributes_config_mapping.keys() == other.attributes_config_mapping.keys() and \
                all([self.attributes_config_mapping[k] == other.attributes_config_mapping[k]
                     for k in self.attributes_config_mapping.keys()]) and \
@@ -506,9 +444,6 @@ def __eq__(self, other: Any) -> bool:
                     for k in self.pos_attributes_config_mapping.keys()])
 
     def __hash__(self):
-        return hash((self.min_threshold,
-                     self.simd_size,
-                     self.weights_second_moment_correction,
-                     self.weights_bias_correction,
+        return hash((self.simd_size,
                      frozenset(self.attributes_config_mapping),
                      frozenset(self.pos_attributes_config_mapping)))