1515from typing import Dict , List , Any , Tuple
1616
1717from mct_quantizers .common .base_inferable_quantizer import BaseInferableQuantizer
18- from mct_quantizers .common .constants import FOUND_TF , ACTIVATION_QUANTIZERS , WEIGHTS_QUANTIZERS , STEPS , LAYER , TRAINING
18+ from mct_quantizers .common .constants import FOUND_TF , WEIGHTS_QUANTIZERS , STEPS , LAYER , TRAINING
1919from mct_quantizers .logger import Logger
2020from mct_quantizers .common .get_all_subclasses import get_all_subclasses
2121
@@ -54,20 +54,17 @@ class KerasQuantizationWrapper(tf.keras.layers.Wrapper):
5454 def __init__ (self ,
5555 layer ,
5656 weights_quantizers : Dict [str , BaseInferableQuantizer ] = None ,
57- activation_quantizers : List [BaseInferableQuantizer ] = None ,
5857 ** kwargs ):
5958 """
6059 Keras Quantization Wrapper takes a keras layer and quantizers and infer a quantized layer.
6160
6261 Args:
6362 layer: A keras layer.
6463 weights_quantizers: A dictionary between a weight's name to its quantizer.
65- activation_quantizers: A list of activations quantization, one for each layer output.
6664 """
6765 super (KerasQuantizationWrapper , self ).__init__ (layer , ** kwargs )
6866 self ._track_trackable (layer , name = 'layer' )
6967 self .weights_quantizers = weights_quantizers if weights_quantizers is not None else dict ()
70- self .activation_quantizers = activation_quantizers if activation_quantizers is not None else list ()
7168
7269 def add_weights_quantizer (self , param_name : str , quantizer : BaseInferableQuantizer ):
7370 """
@@ -82,15 +79,6 @@ def add_weights_quantizer(self, param_name: str, quantizer: BaseInferableQuantiz
8279 """
8380 self .weights_quantizers .update ({param_name : quantizer })
8481
85- @property
86- def is_activation_quantization (self ) -> bool :
87- """
88- This function check activation quantizer exists in wrapper.
89- Returns: a boolean if activation quantizer exists
90-
91- """
92- return self .num_activation_quantizers > 0
93-
9482 @property
9583 def is_weights_quantization (self ) -> bool :
9684 """
@@ -108,22 +96,13 @@ def num_weights_quantizers(self) -> int:
10896 """
10997 return len (self .weights_quantizers )
11098
111- @property
112- def num_activation_quantizers (self ) -> int :
113- """
114- Returns: number of activations quantizers
115- """
116- return len (self .activation_quantizers )
117-
11899 def get_config (self ):
119100 """
120101 Returns: Configuration of KerasQuantizationWrapper.
121102
122103 """
123104 base_config = super (KerasQuantizationWrapper , self ).get_config ()
124- config = {
125- ACTIVATION_QUANTIZERS : [keras .utils .serialize_keras_object (act ) for act in self .activation_quantizers ],
126- WEIGHTS_QUANTIZERS : {k : keras .utils .serialize_keras_object (v ) for k , v in self .weights_quantizers .items ()}}
105+ config = {WEIGHTS_QUANTIZERS : {k : keras .utils .serialize_keras_object (v ) for k , v in self .weights_quantizers .items ()}}
127106 return dict (list (base_config .items ()) + list (config .items ()))
128107
129108 def _set_weights_vars (self , is_training : bool = True ):
@@ -143,17 +122,6 @@ def _set_weights_vars(self, is_training: bool = True):
143122 self ._weights_vars .append ((name , weight , quantizer ))
144123 self ._trainable_weights .append (weight ) # Must when inherit from tf.keras.layers.Wrapper in tf2.10 and below
145124
146- def _set_activations_vars (self ):
147- """
148- This function sets activations quantizers vars to the layer
149-
150- Returns: None
151- """
152- self ._activation_vars = []
153- for i , quantizer in enumerate (self .activation_quantizers ):
154- quantizer .initialize_quantization (None , self .layer .name + f'/out{ i } ' , self )
155- self ._activation_vars .append (quantizer )
156-
157125 @classmethod
158126 def from_config (cls , config ):
159127 """
@@ -167,14 +135,11 @@ def from_config(cls, config):
167135 config = config .copy ()
168136 qi_inferable_custom_objects = {subclass .__name__ : subclass for subclass in
169137 get_all_subclasses (BaseKerasInferableQuantizer )}
170- activation_quantizers = [keras .utils .deserialize_keras_object (act ,
171- module_objects = globals (),
172- custom_objects = None ) for act in config .pop (ACTIVATION_QUANTIZERS )]
173138 weights_quantizers = {k : keras .utils .deserialize_keras_object (v ,
174139 module_objects = globals (),
175140 custom_objects = qi_inferable_custom_objects ) for k , v in config .pop (WEIGHTS_QUANTIZERS ).items ()}
176141 layer = tf .keras .layers .deserialize (config .pop (LAYER ))
177- return cls (layer = layer , weights_quantizers = weights_quantizers , activation_quantizers = activation_quantizers , ** config )
142+ return cls (layer = layer , weights_quantizers = weights_quantizers , ** config )
178143
179144 def build (self , input_shape ):
180145 """
@@ -194,7 +159,6 @@ def build(self, input_shape):
194159 trainable = False )
195160
196161 self ._set_weights_vars ()
197- self ._set_activations_vars ()
198162
199163 def set_quantize_weights (self , quantized_weights : dict ):
200164 """
@@ -254,29 +218,6 @@ def call(self, inputs, training=None, **kwargs):
254218 else :
255219 outputs = self .layer .call (inputs , ** kwargs )
256220
257- # Quantize all activations if quantizers exist.
258- if self .is_activation_quantization :
259- num_outputs = len (outputs ) if isinstance (outputs , (list , tuple )) else 1
260- if self .num_activation_quantizers != num_outputs :
261- Logger .error ('Quantization wrapper output quantization error: '
262- f'number of outputs and quantizers mismatch ({ num_outputs } !='
263- f'{ self .num_activation_quantizers } ' )
264- if num_outputs == 1 :
265- outputs = [outputs ]
266-
267- _outputs = []
268- for _output , act_quant in zip (outputs , self .activation_quantizers ):
269- activation_quantizer_args_spec = tf_inspect .getfullargspec (act_quant .__call__ ).args
270- if TRAINING in activation_quantizer_args_spec :
271- _outputs .append (utils .smart_cond (
272- training ,
273- _make_quantizer_fn (act_quant , _output , True ),
274- _make_quantizer_fn (act_quant , _output , False )))
275- else :
276- # Keras activation inferable quantizer.
277- _outputs .append (act_quant (_output ))
278- outputs = _outputs [0 ] if num_outputs == 1 else _outputs
279-
280221 return outputs
281222
282223 def convert_to_inferable_quantizers (self ):
@@ -286,14 +227,6 @@ def convert_to_inferable_quantizers(self):
286227 Returns:
287228 None
288229 """
289- # Activations quantizers
290- inferable_activation_quantizers = []
291- if self .is_activation_quantization :
292- for quantizer in self .activation_quantizers :
293- if hasattr (quantizer , 'convert2inferable' ) and callable (quantizer .convert2inferable ):
294- inferable_activation_quantizers .append (quantizer .convert2inferable ())
295- self .activation_quantizers = inferable_activation_quantizers
296-
297230 # Weight quantizers
298231 inferable_weight_quantizers = {}
299232 if self .is_weights_quantization :
@@ -310,7 +243,7 @@ def convert_to_inferable_quantizers(self):
310243 layer_weights_list .append (getattr (self .layer , weight_attr )) # quantized weights
311244 layer_weights_list .extend (self .layer .get_weights ()) # non quantized weights
312245 inferable_quantizers_wrapper .layer .set_weights (layer_weights_list )
313- inferable_quantizers_wrapper . _set_activations_vars ()
246+
314247 # The wrapper inference is using the weights of the quantizers so it expectes to create them by running _set_weights_vars
315248 inferable_quantizers_wrapper ._set_weights_vars (False )
316249 return inferable_quantizers_wrapper
@@ -342,15 +275,13 @@ def get_quantized_weights(self) -> Dict[str, tf.Tensor]:
342275 class KerasQuantizationWrapper (object ):
343276 def __init__ (self ,
344277 layer ,
345- weights_quantizers : Dict [str , BaseInferableQuantizer ] = None ,
346- activation_quantizers : List [BaseInferableQuantizer ] = None ):
278+ weights_quantizers : Dict [str , BaseInferableQuantizer ] = None ):
347279 """
348280 Keras Quantization Wrapper takes a keras layer and quantizers and infer a quantized layer.
349281
350282 Args:
351283 layer: A keras layer.
352284 weights_quantizers: A dictionary between a weight's name to its quantizer.
353- activation_quantizers: A list of activations quantization, one for each layer output.
354285 """
355286 Logger .critical ('Installing tensorflow and tensorflow_model_optimization is mandatory '
356287 'when using KerasQuantizationWrapper. '
0 commit comments