88from model_compression_toolkit .core .common .graph .edge import Edge
99from tests_pytest .test_util .graph_builder_utils import build_node
1010
11- from model_compression_toolkit .target_platform_capabilities .constants import KERNEL_ATTR , BIAS_ATTR
11+ TEST_KERNEL = 'kernel'
12+ TEST_BIAS = 'bias'
1213
1314### dummy layer classes
1415class Conv2D :
@@ -29,11 +30,11 @@ class Dense:
2930### test model
3031def get_test_graph ():
3132 n1 = build_node ('input' , layer_class = InputLayer )
32- conv1 = build_node ('conv1' , layer_class = Conv2D )
33+ conv1 = build_node ('conv1' , layer_class = Conv2D , canonical_weights = { TEST_KERNEL : [ 1 , 2 ], TEST_BIAS : [ 3 , 4 ]} )
3334 add1 = build_node ('add1' , layer_class = Add )
3435 conv2 = build_node ('conv2' , layer_class = Conv2D )
3536 bn1 = build_node ('bn1' , layer_class = BatchNormalization )
36- relu = build_node ('relu1' , layer_class = ReLU )
37+ relu = build_node ('relu1' , layer_class = ReLU , canonical_weights = { TEST_KERNEL : [ 1 , 2 ], TEST_BIAS : [ 3 , 4 ]} )
3738 add2 = build_node ('add2' , layer_class = Add )
3839 flatten = build_node ('flatten' , layer_class = Flatten )
3940 fc = build_node ('fc' , layer_class = Dense )
@@ -55,41 +56,25 @@ def get_test_graph():
5556 return graph
5657
5758class TestBitWidthConfig :
58- # test case
59- setter_test_input_0 = {"activation" : (None , None ),
60- "weights" : (None , None , None )}
61- setter_test_input_1 = {"activation" : (NodeTypeFilter (ReLU ), [16 ]),
62- "weights" : (None , None , None )}
63- setter_test_input_2 = {"activation" : (None , None ),
64- "weights" : (NodeNameFilter ("conv2" ), [8 ], KERNEL_ATTR )}
65- setter_test_input_3 = {"activation" : (NodeTypeFilter (ReLU ), [16 ]),
66- "weights" : (NodeNameFilter ("conv2" ), [8 ], KERNEL_ATTR )}
67- setter_test_input_4 = {"activation" : ([NodeTypeFilter (ReLU ), NodeNameFilter ("conv1" )], [16 , 8 ]),
68- "weights" : ([NodeTypeFilter (Conv2D ), NodeNameFilter ("fc" )], [16 , 2 ], [KERNEL_ATTR , BIAS_ATTR ])}
69-
70- setter_test_expected_0 = {"activation" : (None , None ),
71- "weights" : (None , None , None )}
72- setter_test_expected_1 = {"activation" : ([NodeTypeFilter , ReLU , 16 ]),
73- "weights" : (None , None , None )}
74- setter_test_expected_2 = {"activation" : (None , None ),
75- "weights" : ([NodeNameFilter , "conv2" , 8 , KERNEL_ATTR ]) }
76- setter_test_expected_3 = {"activation" : ([NodeTypeFilter , ReLU , 16 ]),
77- "weights" : ([NodeNameFilter , "conv2" , 8 , KERNEL_ATTR ])}
78- setter_test_expected_4 = {"activation" : ([NodeTypeFilter , ReLU , 16 ], [NodeNameFilter , "conv1" , 8 ]),
79- "weights" : ([NodeTypeFilter , Conv2D , 16 , KERNEL_ATTR ], [NodeNameFilter , "fc" , 2 , BIAS_ATTR ])}
80-
81-
82- # test : BitWidthConfig set_manual_activation_bit_width, set_manual_weights_bit_width
59+ # test case for set_manual_activation_bit_width
60+ test_input_0 = (None , None )
61+ test_input_1 = (NodeTypeFilter (ReLU ), 16 )
62+ test_input_2 = ([NodeTypeFilter (ReLU ), NodeNameFilter ("conv1" )], [16 ])
63+ test_input_3 = ([NodeTypeFilter (ReLU ), NodeNameFilter ("conv1" )], [16 , 8 ])
64+
65+ test_expected_0 = ("The filters cannot be None." , None )
66+ test_expected_1 = (NodeTypeFilter , ReLU , 16 )
67+ test_expected_2 = ([NodeTypeFilter , ReLU , 16 ], [NodeNameFilter , "conv1" , 16 ])
68+ test_expected_3 = ([NodeTypeFilter , ReLU , 16 ], [NodeNameFilter , "conv1" , 8 ])
69+
8370 @pytest .mark .parametrize (("inputs" , "expected" ), [
84- (setter_test_input_0 , setter_test_expected_0 ),
85- (setter_test_input_1 , setter_test_expected_1 ),
86- (setter_test_input_2 , setter_test_expected_2 ),
87- (setter_test_input_3 , setter_test_expected_3 ),
88- (setter_test_input_4 , setter_test_expected_4 ),
71+ (test_input_0 , test_expected_0 ),
72+ (test_input_1 , test_expected_1 ),
73+ (test_input_2 , test_expected_2 ),
74+ (test_input_3 , test_expected_3 ),
8975 ])
90- def test_bit_width_config_setter (self , inputs , expected ):
91-
92- def check_param (mb_cfg , exp ):
76+ def test_set_manual_activation_bit_width (self , inputs , expected ):
77+ def check_param_for_activation (mb_cfg , exp ):
9378 ### check setting config class (expected ManualBitWidthSelection)
9479 assert type (mb_cfg ) == ManualBitWidthSelection
9580
@@ -106,8 +91,40 @@ def check_param(mb_cfg, exp):
10691 else :
10792 assert mb_cfg .filter is None
10893
109- def check_param_for_weights (mb_cfg , exp ):
110- ### check setting config class (expected ManualBitWidthSelection)
94+ manual_bit_cfg = BitWidthConfig ()
95+ try :
96+ manual_bit_cfg .set_manual_activation_bit_width (inputs [0 ], inputs [1 ])
97+ ### check Activation
98+ if len (manual_bit_cfg .manual_activation_bit_width_selection_list ) == 1 :
99+ for a_mb_cfg in manual_bit_cfg .manual_activation_bit_width_selection_list :
100+ print (a_mb_cfg , expected )
101+ check_param_for_activation (a_mb_cfg , expected )
102+ else :
103+ for idx , a_mb_cfg in enumerate (manual_bit_cfg .manual_activation_bit_width_selection_list ):
104+ check_param_for_activation (a_mb_cfg , expected [idx ])
105+ except Exception as e :
106+ assert str (e ) == expected [0 ]
107+
108+ # test case for set_manual_weights_bit_width
109+ test_input_0 = (None , None , None )
110+ test_input_1 = (NodeTypeFilter (ReLU ), 16 , TEST_KERNEL )
111+ test_input_2 = ([NodeTypeFilter (ReLU ), NodeNameFilter ("conv1" )], [16 ], [TEST_KERNEL ])
112+ test_input_3 = ([NodeTypeFilter (ReLU ), NodeNameFilter ("conv1" )], [16 , 8 ], [TEST_KERNEL , TEST_BIAS ])
113+
114+ test_expected_0 = ("The filters cannot be None." , None , None )
115+ test_expected_1 = (NodeTypeFilter , ReLU , 16 , TEST_KERNEL )
116+ test_expected_2 = ([NodeTypeFilter , ReLU , 16 , TEST_KERNEL ], [NodeNameFilter , "conv1" , 16 , TEST_KERNEL ])
117+ test_expected_3 = ([NodeTypeFilter , ReLU , 16 , TEST_KERNEL ], [NodeNameFilter , "conv1" , 8 , TEST_BIAS ])
118+
119+ @pytest .mark .parametrize (("inputs" , "expected" ), [
120+ (test_input_0 , test_expected_0 ),
121+ (test_input_1 , test_expected_1 ),
122+ (test_input_2 , test_expected_2 ),
123+ (test_input_3 , test_expected_3 ),
124+ ])
125+ def test_set_manual_weights_bit_width (self , inputs , expected ):
126+ def check_param_weights (mb_cfg , exp ):
127+ ### check setting config class (expected ManualWeightsBitWidthSelection)
111128 assert type (mb_cfg ) == ManualWeightsBitWidthSelection
112129
113130 ### check setting filter for NodeFilter and NodeInfo
@@ -118,94 +135,81 @@ def check_param_for_weights(mb_cfg, exp):
118135 elif isinstance (mb_cfg .filter , NodeNameFilter ):
119136 assert mb_cfg .filter .node_name == exp [1 ]
120137
121- ### check setting bit_width
138+ ### check setting bit_width and attr
122139 assert mb_cfg .bit_width == exp [2 ]
123140 assert mb_cfg .attr == exp [3 ]
124141 else :
125142 assert mb_cfg .filter is None
126143
127- activation = inputs ["activation" ]
128- weights = inputs ["weights" ]
144+ manual_bit_cfg = BitWidthConfig ()
145+ try :
146+ manual_bit_cfg .set_manual_weights_bit_width (inputs [0 ], inputs [1 ], inputs [2 ])
147+ ### check weights
148+ if len (manual_bit_cfg .manual_weights_bit_width_selection_list ) == 1 :
149+ for a_mb_cfg in manual_bit_cfg .manual_weights_bit_width_selection_list :
150+ print (a_mb_cfg , expected )
151+ check_param_weights (a_mb_cfg , expected )
152+ else :
153+ for idx , a_mb_cfg in enumerate (manual_bit_cfg .manual_weights_bit_width_selection_list ):
154+ check_param_weights (a_mb_cfg , expected [idx ])
155+ except Exception as e :
156+ assert str (e ) == expected [0 ]
129157
130- activation_expected = expected ["activation" ]
131- weights_expected = expected ["weights" ]
158+ # test case for get_nodes_to_manipulate_activation_bit_widths
159+ test_input_0 = (NodeTypeFilter (ReLU ), 16 )
160+ test_input_1 = (NodeNameFilter ('relu1' ), 16 )
161+ test_input_2 = ([NodeTypeFilter (ReLU ), NodeNameFilter ("conv1" )], [16 , 8 ])
132162
133- manual_bit_cfg = BitWidthConfig ()
163+ test_expected_0 = ({"ReLU:relu1" : 16 })
164+ test_expected_1 = ({"ReLU:relu1" : 16 })
165+ test_expected_2 = ({"ReLU:relu1" : 16 , "Conv2D:conv1" : 8 })
134166
135- manual_bit_cfg .set_manual_activation_bit_width (activation [0 ], activation [1 ])
136- manual_bit_cfg .set_manual_weights_bit_width (weights [0 ], weights [1 ], weights [2 ])
137-
138- ### check got object instance
139- assert isinstance (manual_bit_cfg , BitWidthConfig )
140-
141- ### check Activation
142- if len (manual_bit_cfg .manual_activation_bit_width_selection_list ) == 1 :
143- for a_mb_cfg in manual_bit_cfg .manual_activation_bit_width_selection_list :
144- check_param (a_mb_cfg , activation_expected )
145- else :
146- for idx , a_mb_cfg in enumerate (manual_bit_cfg .manual_activation_bit_width_selection_list ):
147- check_param (a_mb_cfg , activation_expected [idx ])
148-
149- ### check Weights
150- if len (manual_bit_cfg .manual_weights_bit_width_selection_list ) == 1 :
151- for w_mb_cfg in manual_bit_cfg .manual_weights_bit_width_selection_list :
152- check_param_for_weights (w_mb_cfg , weights_expected )
153- else :
154- for idx , w_mb_cfg in enumerate (manual_bit_cfg .manual_weights_bit_width_selection_list ):
155- check_param_for_weights (w_mb_cfg , weights_expected [idx ])
156-
157-
158- ### test case
159- ### Note: setter inputs reuse getters test inputs
160- getter_test_expected_0 = {"activation" :{},
161- "weights" : {}}
162- getter_test_expected_1 = {"activation" :{"ReLU:relu1" : 16 },
163- "weights" : {}}
164- getter_test_expected_2 = {"activation" :{},
165- "weights" : {"Conv2D:conv2" : [8 , KERNEL_ATTR ]}}
166- getter_test_expected_3 = {"activation" : {"ReLU:relu1" : 16 },
167- "weights" : {"Conv2D:conv2" : [8 , KERNEL_ATTR ]}}
168- getter_test_expected_4 = {"activation" : {"ReLU:relu1" : 16 , "Conv2D:conv1" : 8 },
169- "weights" : {"Conv2D:conv1" : [16 , KERNEL_ATTR ], "Conv2D:conv2" : [16 , KERNEL_ATTR ], "Dense:fc" : [2 , BIAS_ATTR ]}}
170-
171- # test : BitWidthConfig get_nodes_to_manipulate_bit_widths
172167 @pytest .mark .parametrize (("inputs" , "expected" ), [
173- (setter_test_input_0 , getter_test_expected_0 ),
174- (setter_test_input_1 , getter_test_expected_1 ),
175- (setter_test_input_2 , getter_test_expected_2 ),
176- (setter_test_input_3 , getter_test_expected_3 ),
177- (setter_test_input_4 , getter_test_expected_4 ),
168+ (test_input_0 , test_expected_0 ),
169+ (test_input_1 , test_expected_1 ),
170+ (test_input_2 , test_expected_2 ),
178171 ])
179- def test_bit_width_config_getter (self , inputs , expected ):
172+ def test_get_nodes_to_manipulate_activation_bit_widths (self , inputs , expected ):
173+ fl_list = inputs [0 ] if isinstance (inputs [0 ], list ) else [inputs [0 ]]
174+ bw_list = inputs [1 ] if isinstance (inputs [1 ], list ) else [inputs [1 ]]
175+
176+ mbws_config = []
177+ for fl , bw in zip (fl_list , bw_list ):
178+ mbws_config .append (ManualBitWidthSelection (fl , bw ))
179+ manual_bit_cfg = BitWidthConfig (manual_activation_bit_width_selection_list = mbws_config )
180180
181181 graph = get_test_graph ()
182+ get_manual_bit_dict_activation = manual_bit_cfg .get_nodes_to_manipulate_activation_bit_widths (graph )
183+ for idx , (key , val ) in enumerate (get_manual_bit_dict_activation .items ()):
184+ assert str (key ) == list (expected .keys ())[idx ]
185+ assert val == list (expected .values ())[idx ]
182186
183- activation = inputs ["activation" ]
184- weights = inputs ["weights" ]
187+ # test case for get_nodes_to_manipulate_weights_bit_widths
188+ test_input_0 = (NodeTypeFilter (ReLU ), 16 , TEST_KERNEL )
189+ test_input_1 = (NodeNameFilter ('relu1' ), 16 , TEST_BIAS )
190+ test_input_2 = ([NodeTypeFilter (ReLU ), NodeNameFilter ("conv1" )], [16 , 8 ], [TEST_KERNEL , TEST_BIAS ])
185191
186- activation_expected = expected ["activation" ]
187- weights_expected = expected ["weights" ]
192+ test_expected_0 = ({"ReLU:relu1" : [16 , TEST_KERNEL ]})
193+ test_expected_1 = ({"ReLU:relu1" : [16 , TEST_BIAS ]})
194+ test_expected_2 = ({"ReLU:relu1" : [16 , TEST_KERNEL ], "Conv2D:conv1" : [8 , TEST_BIAS ]})
188195
189- manual_bit_cfg = BitWidthConfig ()
190- if activation [0 ] is not None :
191- manual_bit_cfg .set_manual_activation_bit_width (activation [0 ], activation [1 ])
192- if weights [0 ] is not None :
193- manual_bit_cfg .set_manual_weights_bit_width (weights [0 ], weights [1 ], weights [2 ])
196+ @pytest .mark .parametrize (("inputs" , "expected" ), [
197+ (test_input_0 , test_expected_0 ),
198+ (test_input_1 , test_expected_1 ),
199+ (test_input_2 , test_expected_2 ),
200+ ])
201+ def test_get_nodes_to_manipulate_weights_bit_widths (self , inputs , expected ):
202+ fl_list = inputs [0 ] if isinstance (inputs [0 ], list ) else [inputs [0 ]]
203+ bw_list = inputs [1 ] if isinstance (inputs [1 ], list ) else [inputs [1 ]]
204+ at_list = inputs [2 ] if isinstance (inputs [2 ], list ) else [inputs [2 ]]
194205
195- get_manual_bit_dict_activation = manual_bit_cfg .get_nodes_to_manipulate_activation_bit_widths (graph )
196- get_manual_bit_dict_weights = manual_bit_cfg .get_nodes_to_manipulate_weights_bit_widths (graph )
206+ manual_weights_bit_width_config = []
207+ for fl , bw , at in zip (fl_list , bw_list , at_list ):
208+ manual_weights_bit_width_config .append (ManualWeightsBitWidthSelection (fl , bw , at ))
209+ manual_bit_cfg = BitWidthConfig (manual_weights_bit_width_selection_list = manual_weights_bit_width_config )
197210
198- if activation [0 ] is not None :
199- for idx , (key , val ) in enumerate (get_manual_bit_dict_activation .items ()):
200- assert str (key ) == list (activation_expected .keys ())[idx ]
201- assert val == list (activation_expected .values ())[idx ]
202- else :
203- assert get_manual_bit_dict_activation == activation_expected
204-
205- if weights [0 ] is not None :
206- for idx , (key , val ) in enumerate (get_manual_bit_dict_weights .items ()):
207- assert str (key ) == list (weights_expected .keys ())[idx ]
208- assert val [0 ] == list (weights_expected .values ())[idx ][0 ]
209- assert val [1 ] == list (weights_expected .values ())[idx ][1 ]
210- else :
211- assert get_manual_bit_dict_weights == weights_expected
211+ graph = get_test_graph ()
212+ get_manual_bit_dict_weights = manual_bit_cfg .get_nodes_to_manipulate_weights_bit_widths (graph )
213+ for idx , (key , val ) in enumerate (get_manual_bit_dict_weights .items ()):
214+ assert str (key ) == list (expected .keys ())[idx ]
215+ assert val == list (expected .values ())[idx ]
0 commit comments