1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414# ==============================================================================
15+ import abc
1516
17+ import typing
1618
1719import numpy as np
1820import tensorflow as tf
@@ -65,10 +67,23 @@ def create_networks(self):
6567 model = keras .Model (inputs = inputs , outputs = outputs )
6668 return model
6769
70+ def get_resource_utilization (self ):
71+ raise NotImplementedError ()
72+
73+ @typing .final
6874 def compare (self , quantized_model , float_model , input_x = None , quantization_info : UserInformation = None ):
69- # This is a base test, so it does not check a thing. Only actual tests of mixed precision
70- # compare things to test.
71- raise NotImplementedError
75+ # call concrete validation of the specific test
76+ self ._compare (quantized_model , float_model , input_x , quantization_info )
77+ # make sure the final utilization satisfies the target constraints
78+ target_ru = self .get_resource_utilization ()
79+ if target_ru .is_any_restricted ():
80+ self .unit_test .assertTrue (
81+ target_ru .is_satisfied_by (quantization_info .final_resource_utilization ))
82+
83+ @abc .abstractmethod
84+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info : UserInformation = None ):
85+ # test-specific validation, to be implemented by each test
86+ raise NotImplementedError ()
7287
7388
7489class MixedPrecisionManuallyConfiguredTest (MixedPrecisionBaseTest ):
@@ -95,7 +110,7 @@ def get_resource_utilization(self):
95110 # set manually)
96111 return ResourceUtilization (1 )
97112
98- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
113+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
99114 assert quantization_info .mixed_precision_cfg == [2 , 1 ]
100115 conv_layers = get_layers_from_model_by_type (quantized_model , layers .Conv2D )
101116 self .unit_test .assertTrue (np .unique (conv_layers [0 ].weights [0 ]).flatten ().shape [0 ] <= 4 )
@@ -114,7 +129,7 @@ def get_mixed_precision_config(self):
114129 return mct .core .MixedPrecisionQuantizationConfig (num_of_images = 1 ,
115130 distance_weighting_method = self .distance_metric )
116131
117- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
132+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
118133 conv_layers = get_layers_from_model_by_type (quantized_model , layers .Conv2D )
119134 self .unit_test .assertTrue (any ([b != 0 for b in quantization_info .mixed_precision_cfg ]),
120135 "At least one of the conv layers is expected to be quantized to meet the required "
@@ -147,7 +162,7 @@ def get_mixed_precision_config(self):
147162 distance_weighting_method = self .distance_metric ,
148163 use_hessian_based_scores = True )
149164
150- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
165+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
151166 conv_layers = get_layers_from_model_by_type (quantized_model , layers .Conv2D )
152167 self .unit_test .assertTrue (any ([b != 0 for b in quantization_info .mixed_precision_cfg ]),
153168 "At least one of the conv layers is expected to be quantized to meet the required "
@@ -220,7 +235,7 @@ def create_networks(self):
220235 def get_resource_utilization (self ):
221236 return ResourceUtilization (1790 )
222237
223- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
238+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
224239 # We just needed to verify that the graph finalization is working without failing.
225240 # The actual quantization is not interesting for the sake of this test, so we just verify some
226241 # degenerated things to see that everything worked.
@@ -242,7 +257,7 @@ def get_resource_utilization(self):
242257 # Resource Utilization is for 4 bits on average
243258 return ResourceUtilization (17920 * 4 / 8 )
244259
245- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
260+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
246261 conv_layers = get_layers_from_model_by_type (quantized_model , layers .Conv2D )
247262 assert (quantization_info .mixed_precision_cfg == [1 , 1 ]).all ()
248263 for i in range (32 ): # quantized per channel
@@ -283,7 +298,7 @@ def create_networks(self):
283298 model = keras .Model (inputs = inputs , outputs = outputs )
284299 return model
285300
286- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
301+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
287302 conv_layers = get_layers_from_model_by_type (quantized_model , layers .Conv2D )
288303 self .unit_test .assertTrue ((quantization_info .mixed_precision_cfg != 0 ).any ())
289304
@@ -308,7 +323,7 @@ def get_resource_utilization(self):
308323 # Resource Utilization is for 2 bits on average
309324 return ResourceUtilization (17920 * 2 / 8 )
310325
311- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
326+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
312327 conv_layers = get_layers_from_model_by_type (quantized_model , layers .Conv2D )
313328 assert (quantization_info .mixed_precision_cfg == [2 , 2 ]).all ()
314329 for i in range (32 ): # quantized per channel
@@ -335,7 +350,7 @@ def __init__(self, unit_test):
335350 def get_resource_utilization (self ):
336351 return self .target_total_ru
337352
338- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
353+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
339354 # No need to verify quantization configuration here since this test is similar to other tests we have,
340355 # we're only interested in the ResourceUtilization
341356 self .unit_test .assertTrue (quantization_info .final_resource_utilization .activation_memory <=
@@ -351,7 +366,7 @@ def __init__(self, unit_test):
351366 def get_resource_utilization (self ):
352367 return self .target_total_ru
353368
354- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
369+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
355370 # No need to verify quantization configuration here since this test is similar to other tests we have,
356371 # we're only interested in the ResourceUtilization
357372 self .unit_test .assertTrue (
@@ -373,7 +388,7 @@ def create_networks(self):
373388 model = keras .Model (inputs = inputs , outputs = x )
374389 return model
375390
376- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
391+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
377392 self .unit_test .assertTrue (len (quantization_info .mixed_precision_cfg ) == 1 )
378393 self .unit_test .assertTrue (quantization_info .mixed_precision_cfg [0 ] == 1 )
379394
@@ -426,7 +441,7 @@ def get_resource_utilization(self):
426441 # resource utilization is infinity -> should give best model - 8bits
427442 return ResourceUtilization (17919 )
428443
429- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
444+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
430445 conv_layers = get_layers_from_model_by_type (quantized_model , layers .Conv2D )
431446 assert (quantization_info .mixed_precision_cfg == [0 , 1 ]).all ()
432447 for i in range (32 ): # quantized per channel
@@ -449,7 +464,7 @@ def get_mixed_precision_config(self):
449464 def get_resource_utilization (self ):
450465 return ResourceUtilization (17919 )
451466
452- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
467+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
453468 conv_layers = get_layers_from_model_by_type (quantized_model , layers .Conv2D )
454469 assert any ([(quantization_info .mixed_precision_cfg == [1 , 0 ]).all (),
455470 (quantization_info .mixed_precision_cfg == [0 , 1 ]).all ()])
@@ -526,7 +541,7 @@ def get_tpc(self):
526541 def get_resource_utilization (self ):
527542 return ResourceUtilization (1535 )
528543
529- def compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
544+ def _compare (self , quantized_model , float_model , input_x = None , quantization_info = None ):
530545 wrapper_layers = get_layers_from_model_by_type (quantized_model , KerasQuantizationWrapper )
531546 weights_bits = wrapper_layers [0 ].weights_quantizers [KERNEL ].num_bits
532547 self .unit_test .assertTrue (weights_bits == 4 )
0 commit comments