Skip to content

Commit 20c9c4d

Browse files
reuvenperetzreuvenp
andauthored
Remove activation quantizers from Keras wrapper (#25)
* Remove activation quantizers from Keras wrapper * Remove call to _set_activation_vars in Keras wrapper --------- Co-authored-by: reuvenp <reuvenp@altair-semi.com>
1 parent 9c2413f commit 20c9c4d

File tree

2 files changed

+5
-102
lines changed

2 files changed

+5
-102
lines changed

mct_quantizers/keras/quantize_wrapper.py

Lines changed: 5 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import Dict, List, Any, Tuple
1616

1717
from mct_quantizers.common.base_inferable_quantizer import BaseInferableQuantizer
18-
from mct_quantizers.common.constants import FOUND_TF, ACTIVATION_QUANTIZERS, WEIGHTS_QUANTIZERS, STEPS, LAYER, TRAINING
18+
from mct_quantizers.common.constants import FOUND_TF, WEIGHTS_QUANTIZERS, STEPS, LAYER, TRAINING
1919
from mct_quantizers.logger import Logger
2020
from mct_quantizers.common.get_all_subclasses import get_all_subclasses
2121

@@ -54,20 +54,17 @@ class KerasQuantizationWrapper(tf.keras.layers.Wrapper):
5454
def __init__(self,
5555
layer,
5656
weights_quantizers: Dict[str, BaseInferableQuantizer] = None,
57-
activation_quantizers: List[BaseInferableQuantizer] = None,
5857
**kwargs):
5958
"""
6059
Keras Quantization Wrapper takes a keras layer and quantizers and infer a quantized layer.
6160
6261
Args:
6362
layer: A keras layer.
6463
weights_quantizers: A dictionary between a weight's name to its quantizer.
65-
activation_quantizers: A list of activations quantization, one for each layer output.
6664
"""
6765
super(KerasQuantizationWrapper, self).__init__(layer, **kwargs)
6866
self._track_trackable(layer, name='layer')
6967
self.weights_quantizers = weights_quantizers if weights_quantizers is not None else dict()
70-
self.activation_quantizers = activation_quantizers if activation_quantizers is not None else list()
7168

7269
def add_weights_quantizer(self, param_name: str, quantizer: BaseInferableQuantizer):
7370
"""
@@ -82,15 +79,6 @@ def add_weights_quantizer(self, param_name: str, quantizer: BaseInferableQuantiz
8279
"""
8380
self.weights_quantizers.update({param_name: quantizer})
8481

85-
@property
86-
def is_activation_quantization(self) -> bool:
87-
"""
88-
This function check activation quantizer exists in wrapper.
89-
Returns: a boolean if activation quantizer exists
90-
91-
"""
92-
return self.num_activation_quantizers > 0
93-
9482
@property
9583
def is_weights_quantization(self) -> bool:
9684
"""
@@ -108,22 +96,13 @@ def num_weights_quantizers(self) -> int:
10896
"""
10997
return len(self.weights_quantizers)
11098

111-
@property
112-
def num_activation_quantizers(self) -> int:
113-
"""
114-
Returns: number of activations quantizers
115-
"""
116-
return len(self.activation_quantizers)
117-
11899
def get_config(self):
119100
"""
120101
Returns: Configuration of KerasQuantizationWrapper.
121102
122103
"""
123104
base_config = super(KerasQuantizationWrapper, self).get_config()
124-
config = {
125-
ACTIVATION_QUANTIZERS: [keras.utils.serialize_keras_object(act) for act in self.activation_quantizers],
126-
WEIGHTS_QUANTIZERS: {k: keras.utils.serialize_keras_object(v) for k, v in self.weights_quantizers.items()}}
105+
config = {WEIGHTS_QUANTIZERS: {k: keras.utils.serialize_keras_object(v) for k, v in self.weights_quantizers.items()}}
127106
return dict(list(base_config.items()) + list(config.items()))
128107

129108
def _set_weights_vars(self, is_training: bool = True):
@@ -143,17 +122,6 @@ def _set_weights_vars(self, is_training: bool = True):
143122
self._weights_vars.append((name, weight, quantizer))
144123
self._trainable_weights.append(weight) # Must when inherit from tf.keras.layers.Wrapper in tf2.10 and below
145124

146-
def _set_activations_vars(self):
147-
"""
148-
This function sets activations quantizers vars to the layer
149-
150-
Returns: None
151-
"""
152-
self._activation_vars = []
153-
for i, quantizer in enumerate(self.activation_quantizers):
154-
quantizer.initialize_quantization(None, self.layer.name + f'/out{i}', self)
155-
self._activation_vars.append(quantizer)
156-
157125
@classmethod
158126
def from_config(cls, config):
159127
"""
@@ -167,14 +135,11 @@ def from_config(cls, config):
167135
config = config.copy()
168136
qi_inferable_custom_objects = {subclass.__name__: subclass for subclass in
169137
get_all_subclasses(BaseKerasInferableQuantizer)}
170-
activation_quantizers = [keras.utils.deserialize_keras_object(act,
171-
module_objects=globals(),
172-
custom_objects=None) for act in config.pop(ACTIVATION_QUANTIZERS)]
173138
weights_quantizers = {k: keras.utils.deserialize_keras_object(v,
174139
module_objects=globals(),
175140
custom_objects=qi_inferable_custom_objects) for k, v in config.pop(WEIGHTS_QUANTIZERS).items()}
176141
layer = tf.keras.layers.deserialize(config.pop(LAYER))
177-
return cls(layer=layer, weights_quantizers=weights_quantizers, activation_quantizers=activation_quantizers, **config)
142+
return cls(layer=layer, weights_quantizers=weights_quantizers, **config)
178143

179144
def build(self, input_shape):
180145
"""
@@ -194,7 +159,6 @@ def build(self, input_shape):
194159
trainable=False)
195160

196161
self._set_weights_vars()
197-
self._set_activations_vars()
198162

199163
def set_quantize_weights(self, quantized_weights: dict):
200164
"""
@@ -254,29 +218,6 @@ def call(self, inputs, training=None, **kwargs):
254218
else:
255219
outputs = self.layer.call(inputs, **kwargs)
256220

257-
# Quantize all activations if quantizers exist.
258-
if self.is_activation_quantization:
259-
num_outputs = len(outputs) if isinstance(outputs, (list, tuple)) else 1
260-
if self.num_activation_quantizers != num_outputs:
261-
Logger.error('Quantization wrapper output quantization error: '
262-
f'number of outputs and quantizers mismatch ({num_outputs}!='
263-
f'{self.num_activation_quantizers}')
264-
if num_outputs == 1:
265-
outputs = [outputs]
266-
267-
_outputs = []
268-
for _output, act_quant in zip(outputs, self.activation_quantizers):
269-
activation_quantizer_args_spec = tf_inspect.getfullargspec(act_quant.__call__).args
270-
if TRAINING in activation_quantizer_args_spec:
271-
_outputs.append(utils.smart_cond(
272-
training,
273-
_make_quantizer_fn(act_quant, _output, True),
274-
_make_quantizer_fn(act_quant, _output, False)))
275-
else:
276-
# Keras activation inferable quantizer.
277-
_outputs.append(act_quant(_output))
278-
outputs = _outputs[0] if num_outputs == 1 else _outputs
279-
280221
return outputs
281222

282223
def convert_to_inferable_quantizers(self):
@@ -286,14 +227,6 @@ def convert_to_inferable_quantizers(self):
286227
Returns:
287228
None
288229
"""
289-
# Activations quantizers
290-
inferable_activation_quantizers = []
291-
if self.is_activation_quantization:
292-
for quantizer in self.activation_quantizers:
293-
if hasattr(quantizer, 'convert2inferable') and callable(quantizer.convert2inferable):
294-
inferable_activation_quantizers.append(quantizer.convert2inferable())
295-
self.activation_quantizers = inferable_activation_quantizers
296-
297230
# Weight quantizers
298231
inferable_weight_quantizers = {}
299232
if self.is_weights_quantization:
@@ -310,7 +243,7 @@ def convert_to_inferable_quantizers(self):
310243
layer_weights_list.append(getattr(self.layer, weight_attr)) # quantized weights
311244
layer_weights_list.extend(self.layer.get_weights()) # non quantized weights
312245
inferable_quantizers_wrapper.layer.set_weights(layer_weights_list)
313-
inferable_quantizers_wrapper._set_activations_vars()
246+
314247
# The wrapper inference is using the weights of the quantizers so it expectes to create them by running _set_weights_vars
315248
inferable_quantizers_wrapper._set_weights_vars(False)
316249
return inferable_quantizers_wrapper
@@ -342,15 +275,13 @@ def get_quantized_weights(self) -> Dict[str, tf.Tensor]:
342275
class KerasQuantizationWrapper(object):
343276
def __init__(self,
344277
layer,
345-
weights_quantizers: Dict[str, BaseInferableQuantizer] = None,
346-
activation_quantizers: List[BaseInferableQuantizer] = None):
278+
weights_quantizers: Dict[str, BaseInferableQuantizer] = None):
347279
"""
348280
Keras Quantization Wrapper takes a keras layer and quantizers and infer a quantized layer.
349281
350282
Args:
351283
layer: A keras layer.
352284
weights_quantizers: A dictionary between a weight's name to its quantizer.
353-
activation_quantizers: A list of activations quantization, one for each layer output.
354285
"""
355286
Logger.critical('Installing tensorflow and tensorflow_model_optimization is mandatory '
356287
'when using KerasQuantizationWrapper. '

tests/keras_tests/test_keras_quantization_wrapper.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,6 @@ def initialize_quantization(self, tensor_shape, name, layer):
4242
return {}
4343

4444

45-
class ZeroActivationsQuantizer:
46-
"""
47-
A dummy quantizer for test usage - "quantize" the layer's activation to 0
48-
"""
49-
50-
def __call__(self,
51-
inputs: tf.Tensor,
52-
training: bool = True) -> tf.Tensor:
53-
return inputs * 0
54-
55-
def initialize_quantization(self, tensor_shape, name, layer):
56-
return {}
57-
58-
5945
class TestKerasQuantizationWrapper(unittest.TestCase):
6046

6147
def setUp(self):
@@ -86,17 +72,3 @@ def test_weights_quantization_wrapper(self):
8672
outputs = wrapper.call(call_inputs.astype('float32'))
8773
self.assertTrue((outputs == conv_layer(call_inputs)).numpy().all())
8874

89-
def test_activation_quantization_wrapper(self):
90-
conv_layer = self.model.layers[1]
91-
92-
wrapper = KerasQuantizationWrapper(conv_layer, activation_quantizers=[ZeroActivationsQuantizer()])
93-
94-
# build
95-
wrapper.build(self.input_shapes)
96-
(act_quantizer) = wrapper._activation_vars[0]
97-
self.assertTrue(isinstance(act_quantizer, ZeroActivationsQuantizer))
98-
99-
# apply the wrapper on inputs
100-
call_inputs = self.inputs[0]
101-
outputs = wrapper.call(call_inputs.astype('float32'))
102-
self.assertTrue((outputs == 0).numpy().all())

0 commit comments

Comments
 (0)