Skip to content

Commit fc7a6f8

Browse files
author
Ofir Gordon
committed
extend test verifications
1 parent 6b1d66e commit fc7a6f8

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

tests_pytest/keras_tests/e2e_tests/post_training_quantization_tests/test_quantization_methods.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121

2222
from mct_quantizers import QuantizationMethod, KerasQuantizationWrapper
2323
from mct_quantizers.keras.metadata import MetadataLayer
24+
from mct_quantizers.keras.quantizers import WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, \
25+
WeightsUniformInferableQuantizer
2426
from model_compression_toolkit.core.common.user_info import UserInformation
25-
from model_compression_toolkit.core.keras.constants import KERNEL, DEPTHWISE_KERNEL
27+
from model_compression_toolkit.core.keras.constants import KERNEL
2628
from model_compression_toolkit.ptq import keras_post_training_quantization
2729
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
2830
AttributeQuantizationConfig, Signedness
@@ -115,22 +117,23 @@ def tpc(quant_method, per_channel):
115117
return _get_tpc(quant_method, per_channel)
116118

117119

118-
@pytest.fixture(params=[model_basic, model_residual])
119-
def model(request):
120-
return request.param()
120+
@pytest.fixture(params=[(model_basic, {"expected_num_quantized": 1}), (model_residual, {"expected_num_quantized": 3})])
121+
def model_scenario(request):
122+
return request.param
121123

122124

123125
class TestPTQWithQuantizationMethods:
124126
# TODO: add tests for:
125127
# 1) activation only, W&A, LUT quantizer (separate)
126128
# 2) advanced models and operators
127129

128-
def test_ptq_weights_only_quantization_methods(self, model, rep_data_gen, quant_method, per_channel, tpc):
129-
130+
def test_ptq_weights_only_quantization_methods(self, model_scenario, rep_data_gen, quant_method, per_channel, tpc):
131+
model, expected_values = model_scenario
132+
model = model()
130133
q_model, quantization_info = keras_post_training_quantization(model, rep_data_gen,
131134
target_platform_capabilities=tpc)
132135

133-
self._verify_quantized_model_structure(model, q_model, quantization_info)
136+
self._verify_quantized_model_structure(q_model, quantization_info, expected_values['expected_num_quantized'])
134137

135138
# Assert quantization properties
136139
quantized_conv_layers = [l for l in q_model.layers if isinstance(l, KerasQuantizationWrapper)]
@@ -150,17 +153,20 @@ def _verify_weights_quantizer_params(weights_quantizer, exp_params_shape, quant_
150153
assert weights_quantizer.quantization_method[0] == quant_method
151154

152155
if quant_method == QuantizationMethod.POWER_OF_TWO:
156+
assert isinstance(weights_quantizer, WeightsPOTInferableQuantizer)
153157
assert len(weights_quantizer.threshold) == exp_params_shape
154158
for t in weights_quantizer.threshold:
155159
assert np.log2(np.abs(t)).astype(int) == np.log2(np.abs(t))
156160
elif quant_method == QuantizationMethod.SYMMETRIC:
161+
assert isinstance(weights_quantizer, WeightsSymmetricInferableQuantizer)
157162
assert len(weights_quantizer.threshold) == exp_params_shape
158163
elif quant_method == QuantizationMethod.UNIFORM:
164+
assert isinstance(weights_quantizer, WeightsUniformInferableQuantizer)
159165
assert len(weights_quantizer.min_range) == exp_params_shape
160166
assert len(weights_quantizer.max_range) == exp_params_shape
161167

162168
@staticmethod
163-
def _verify_quantized_model_structure(model, q_model, quantization_info):
169+
def _verify_quantized_model_structure(q_model, quantization_info, expected_num_quantized):
164170
assert isinstance(q_model, keras.Model)
165171
assert quantization_info is not None and isinstance(quantization_info, UserInformation)
166172

@@ -169,9 +175,6 @@ def _verify_quantized_model_structure(model, q_model, quantization_info):
169175
"Expects BN folding in quantized model."
170176
assert len([l for l in q_model.layers if isinstance(l, MetadataLayer)]) == 1, \
171177
"Expects quantized model to have a metadata stored in a dedicated layer."
172-
original_conv_layers = [l for l in model.layers if
173-
isinstance(l, (layers.Conv2D, layers.DepthwiseConv2D, layers.Dense))]
174-
quantized_conv_layers = [l for l in q_model.layers if isinstance(l, KerasQuantizationWrapper)]
175-
assert len(original_conv_layers) > 0
176-
assert len(original_conv_layers) == len(quantized_conv_layers), \
178+
quantized_layers = [l for l in q_model.layers if isinstance(l, KerasQuantizationWrapper)]
179+
assert len(quantized_layers) == expected_num_quantized, \
177180
"Expects all conv layers from the original model to be wrapped with a KerasQuantizationWrapper."

0 commit comments

Comments
 (0)