Skip to content

Commit 9c76b53

Browse files
irenabirenab
authored andcommitted
fix tests
1 parent cef47e1 commit 9c76b53

4 files changed

Lines changed: 10 additions & 14 deletions

File tree

tests/keras_tests/feature_networks_tests/feature_networks/network_editor/node_filter_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ class ScopeFilterTest(BaseKerasFeatureNetworkTest):
4444
- Check attribute changes
4545
'''
4646

47-
def __init__(self, unit_test, activation_n_bits: int = 3, weights_n_bits: int = 3):
48-
self.activation_n_bits = activation_n_bits
49-
self.weights_n_bits = weights_n_bits
47+
def __init__(self, unit_test):
48+
self.activation_n_bits = 5
49+
self.weights_n_bits = 3
50+
self.weights_n_bits2 = 2
5051
self.kernel = 3
5152
self.num_conv_channels = 4
5253
self.scope = 'scope'
@@ -73,12 +74,9 @@ def get_debug_config(self):
7374
EditRule(filter=NodeNameScopeFilter(self.scope),
7475
action=ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL,
7576
weights_n_bits=self.weights_n_bits)),
76-
EditRule(filter=NodeNameScopeFilter('change_2'),
77-
action=ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL,
78-
enable_weights_quantization=True)),
7977
EditRule(filter=NodeNameScopeFilter('change_2') or NodeNameScopeFilter('does_not_exist'),
8078
action=ChangeCandidatesWeightsQuantConfigAttr(attr_name=KERNEL,
81-
enable_weights_quantization=False))
79+
weights_n_bits=self.weights_n_bits2))
8280
]
8381
return mct.core.DebugConfig(network_editor=network_editor)
8482

@@ -107,10 +105,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
107105
self.unit_test.assertTrue(
108106
len(np.unique(conv_layers[1].get_quantized_weights()['kernel'].numpy())) in [2 ** (self.weights_n_bits) - 1,
109107
2 ** (self.weights_n_bits)])
108+
self.unit_test.assertTrue(
109+
len(np.unique(conv_layers[2].get_quantized_weights()['kernel'].numpy())) in [2 ** (self.weights_n_bits2) - 1,
110+
2 ** (self.weights_n_bits2)])
111+
110112
# check that this conv's weights did not change
111113
self.unit_test.assertTrue(np.all(conv_layers[0].get_quantized_weights()['kernel'].numpy() == self.conv_w))
112-
# check that this conv's weights did not change
113-
self.unit_test.assertTrue(np.all(conv_layers[2].kernel == self.conv_w))
114114
holder_layers = get_layers_from_model_by_type(quantized_model, KerasActivationQuantizationHolder)
115115
self.unit_test.assertTrue(holder_layers[1].activation_holder_quantizer.get_config()['num_bits'] == 16)
116116
self.unit_test.assertTrue(holder_layers[2].activation_holder_quantizer.get_config()['num_bits'] == self.activation_n_bits)

tests/keras_tests/function_tests/test_activation_weights_composition_substitution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def test_two_conv_net_compose_after_split_activation_only(self):
214214

215215
graph.skip_validation_check = False
216216

217-
self._verify_two_conv_with_split_test(graph, v_graph, 9, 3)
217+
self._verify_two_conv_with_split_test(graph, v_graph, 3, 3)
218218

219219
def test_all_weights_layers_composition(self):
220220
in_model = multiple_weights_nodes_model()

tests_pytest/common_tests/unit_tests/core/quantization/test_node_quantization_config.py renamed to tests_pytest/common_tests/unit_tests/core/quantization/test_node_activation_quantization_config.py

File renamed without changes.

tests_pytest/keras_tests/integration_tests/core/fusion/test_graph_with_fusing_metadata_keras.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,10 @@
1919

2020
import keras
2121

22-
2322
class TestGraphWithFusionMetadataKeras(BaseGraphWithFusingMetadataTest, KerasFwMixin):
2423

2524
layer_class_relu = keras.layers.ReLU
2625

27-
def test_disable_act_quantization(self, graph_with_fusion_metadata):
28-
super().test_disable_act_quantization(graph_with_fusion_metadata)
29-
3026
def _data_gen(self):
3127
return self.get_basic_data_gen(shapes=[(1, 3, 5, 5)])()
3228

0 commit comments

Comments
 (0)