2121
2222from mct_quantizers import QuantizationMethod , KerasQuantizationWrapper
2323from mct_quantizers .keras .metadata import MetadataLayer
24+ from mct_quantizers .keras .quantizers import WeightsPOTInferableQuantizer , WeightsSymmetricInferableQuantizer , \
25+ WeightsUniformInferableQuantizer
2426from model_compression_toolkit .core .common .user_info import UserInformation
25- from model_compression_toolkit .core .keras .constants import KERNEL , DEPTHWISE_KERNEL
27+ from model_compression_toolkit .core .keras .constants import KERNEL
2628from model_compression_toolkit .ptq import keras_post_training_quantization
2729from model_compression_toolkit .target_platform_capabilities .schema .mct_current_schema import OpQuantizationConfig , \
2830 AttributeQuantizationConfig , Signedness
@@ -115,22 +117,23 @@ def tpc(quant_method, per_channel):
115117 return _get_tpc (quant_method , per_channel )
116118
117119
118- @pytest .fixture (params = [model_basic , model_residual ])
119- def model (request ):
120- return request .param ()
120+ @pytest .fixture (params = [( model_basic , { "expected_num_quantized" : 1 }), ( model_residual , { "expected_num_quantized" : 3 }) ])
121+ def model_scenario (request ):
122+ return request .param
121123
122124
123125class TestPTQWithQuantizationMethods :
124126 # TODO: add tests for:
125127 # 1) activation only, W&A, LUT quantizer (separate)
126128 # 2) advanced models and operators
127129
128- def test_ptq_weights_only_quantization_methods (self , model , rep_data_gen , quant_method , per_channel , tpc ):
129-
130+ def test_ptq_weights_only_quantization_methods (self , model_scenario , rep_data_gen , quant_method , per_channel , tpc ):
131+ model , expected_values = model_scenario
132+ model = model ()
130133 q_model , quantization_info = keras_post_training_quantization (model , rep_data_gen ,
131134 target_platform_capabilities = tpc )
132135
133- self ._verify_quantized_model_structure (model , q_model , quantization_info )
136+ self ._verify_quantized_model_structure (q_model , quantization_info , expected_values [ 'expected_num_quantized' ] )
134137
135138 # Assert quantization properties
136139 quantized_conv_layers = [l for l in q_model .layers if isinstance (l , KerasQuantizationWrapper )]
@@ -150,17 +153,20 @@ def _verify_weights_quantizer_params(weights_quantizer, exp_params_shape, quant_
150153 assert weights_quantizer .quantization_method [0 ] == quant_method
151154
152155 if quant_method == QuantizationMethod .POWER_OF_TWO :
156+ assert isinstance (weights_quantizer , WeightsPOTInferableQuantizer )
153157 assert len (weights_quantizer .threshold ) == exp_params_shape
154158 for t in weights_quantizer .threshold :
155159 assert np .log2 (np .abs (t )).astype (int ) == np .log2 (np .abs (t ))
156160 elif quant_method == QuantizationMethod .SYMMETRIC :
161+ assert isinstance (weights_quantizer , WeightsSymmetricInferableQuantizer )
157162 assert len (weights_quantizer .threshold ) == exp_params_shape
158163 elif quant_method == QuantizationMethod .UNIFORM :
164+ assert isinstance (weights_quantizer , WeightsUniformInferableQuantizer )
159165 assert len (weights_quantizer .min_range ) == exp_params_shape
160166 assert len (weights_quantizer .max_range ) == exp_params_shape
161167
162168 @staticmethod
163- def _verify_quantized_model_structure (model , q_model , quantization_info ):
169+ def _verify_quantized_model_structure (q_model , quantization_info , expected_num_quantized ):
164170 assert isinstance (q_model , keras .Model )
165171 assert quantization_info is not None and isinstance (quantization_info , UserInformation )
166172
@@ -169,9 +175,6 @@ def _verify_quantized_model_structure(model, q_model, quantization_info):
169175 "Expects BN folding in quantized model."
170176 assert len ([l for l in q_model .layers if isinstance (l , MetadataLayer )]) == 1 , \
171177 "Expects quantized model to have a metadata stored in a dedicated layer."
172- original_conv_layers = [l for l in model .layers if
173- isinstance (l , (layers .Conv2D , layers .DepthwiseConv2D , layers .Dense ))]
174- quantized_conv_layers = [l for l in q_model .layers if isinstance (l , KerasQuantizationWrapper )]
175- assert len (original_conv_layers ) > 0
176- assert len (original_conv_layers ) == len (quantized_conv_layers ), \
178+ quantized_layers = [l for l in q_model .layers if isinstance (l , KerasQuantizationWrapper )]
179+ assert len (quantized_layers ) == expected_num_quantized , \
177180 "Expects all conv layers from the original model to be wrapped with a KerasQuantizationWrapper."
0 commit comments