@@ -44,9 +44,10 @@ class ScopeFilterTest(BaseKerasFeatureNetworkTest):
4444 - Check attribute changes
4545 '''
4646
47- def __init__ (self , unit_test , activation_n_bits : int = 3 , weights_n_bits : int = 3 ):
48- self .activation_n_bits = activation_n_bits
49- self .weights_n_bits = weights_n_bits
47+ def __init__ (self , unit_test ):
48+ self .activation_n_bits = 5
49+ self .weights_n_bits = 3
50+ self .weights_n_bits2 = 2
5051 self .kernel = 3
5152 self .num_conv_channels = 4
5253 self .scope = 'scope'
@@ -73,12 +74,9 @@ def get_debug_config(self):
7374 EditRule (filter = NodeNameScopeFilter (self .scope ),
7475 action = ChangeCandidatesWeightsQuantConfigAttr (attr_name = KERNEL ,
7576 weights_n_bits = self .weights_n_bits )),
76- EditRule (filter = NodeNameScopeFilter ('change_2' ),
77- action = ChangeCandidatesWeightsQuantConfigAttr (attr_name = KERNEL ,
78- enable_weights_quantization = True )),
7977 EditRule (filter = NodeNameScopeFilter ('change_2' ) or NodeNameScopeFilter ('does_not_exist' ),
8078 action = ChangeCandidatesWeightsQuantConfigAttr (attr_name = KERNEL ,
81- enable_weights_quantization = False ))
79+ weights_n_bits = self . weights_n_bits2 ))
8280 ]
8381 return mct .core .DebugConfig (network_editor = network_editor )
8482
@@ -107,10 +105,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
107105 self .unit_test .assertTrue (
108106 len (np .unique (conv_layers [1 ].get_quantized_weights ()['kernel' ].numpy ())) in [2 ** (self .weights_n_bits ) - 1 ,
109107 2 ** (self .weights_n_bits )])
108+ self .unit_test .assertTrue (
109+ len (np .unique (conv_layers [2 ].get_quantized_weights ()['kernel' ].numpy ())) in [2 ** (self .weights_n_bits2 ) - 1 ,
110+ 2 ** (self .weights_n_bits2 )])
111+
110112 # check that this conv's weights did not change
111113 self .unit_test .assertTrue (np .all (conv_layers [0 ].get_quantized_weights ()['kernel' ].numpy () == self .conv_w ))
112- # check that this conv's weights did not change
113- self .unit_test .assertTrue (np .all (conv_layers [2 ].kernel == self .conv_w ))
114114 holder_layers = get_layers_from_model_by_type (quantized_model , KerasActivationQuantizationHolder )
115115 self .unit_test .assertTrue (holder_layers [1 ].activation_holder_quantizer .get_config ()['num_bits' ] == 16 )
116116 self .unit_test .assertTrue (holder_layers [2 ].activation_holder_quantizer .get_config ()['num_bits' ] == self .activation_n_bits )
0 commit comments