Skip to content

Commit 8ee3002

Browse files
committed
Correcting comments on pull requests_2
1 parent 1f30c99 commit 8ee3002

File tree

1 file changed

+116
-112
lines changed

1 file changed

+116
-112
lines changed

tests_pytest/common_tests/core/common/quantization/test_manual_bitwidth_selection.py

Lines changed: 116 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from model_compression_toolkit.core.common.graph.edge import Edge
99
from tests_pytest.test_util.graph_builder_utils import build_node
1010

11-
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR
11+
TEST_KERNEL = 'kernel'
12+
TEST_BIAS = 'bias'
1213

1314
### dummy layer classes
1415
class Conv2D:
@@ -29,11 +30,11 @@ class Dense:
2930
### test model
3031
def get_test_graph():
3132
n1 = build_node('input', layer_class=InputLayer)
32-
conv1 = build_node('conv1', layer_class=Conv2D)
33+
conv1 = build_node('conv1', layer_class=Conv2D, canonical_weights={TEST_KERNEL: [1,2], TEST_BIAS: [3,4]})
3334
add1 = build_node('add1', layer_class=Add)
3435
conv2 = build_node('conv2', layer_class=Conv2D)
3536
bn1 = build_node('bn1', layer_class=BatchNormalization)
36-
relu = build_node('relu1', layer_class=ReLU)
37+
relu = build_node('relu1', layer_class=ReLU, canonical_weights={TEST_KERNEL: [1,2], TEST_BIAS: [3,4]})
3738
add2 = build_node('add2', layer_class=Add)
3839
flatten = build_node('flatten', layer_class=Flatten)
3940
fc = build_node('fc', layer_class=Dense)
@@ -55,41 +56,25 @@ def get_test_graph():
5556
return graph
5657

5758
class TestBitWidthConfig:
58-
# test case
59-
setter_test_input_0 = {"activation": (None, None),
60-
"weights": (None, None, None)}
61-
setter_test_input_1 = {"activation": (NodeTypeFilter(ReLU), [16]),
62-
"weights": (None, None, None)}
63-
setter_test_input_2 = {"activation": (None, None),
64-
"weights": (NodeNameFilter("conv2"), [8], KERNEL_ATTR)}
65-
setter_test_input_3 = {"activation": (NodeTypeFilter(ReLU), [16]),
66-
"weights": (NodeNameFilter("conv2"), [8], KERNEL_ATTR)}
67-
setter_test_input_4 = {"activation": ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16, 8]),
68-
"weights": ([NodeTypeFilter(Conv2D), NodeNameFilter("fc")], [16, 2], [KERNEL_ATTR, BIAS_ATTR])}
69-
70-
setter_test_expected_0 = {"activation": (None, None),
71-
"weights": (None, None, None)}
72-
setter_test_expected_1 = {"activation": ([NodeTypeFilter, ReLU, 16]),
73-
"weights": (None, None, None)}
74-
setter_test_expected_2 = {"activation": (None, None),
75-
"weights": ([NodeNameFilter, "conv2", 8, KERNEL_ATTR]) }
76-
setter_test_expected_3 = {"activation": ([NodeTypeFilter, ReLU, 16]),
77-
"weights": ([NodeNameFilter, "conv2", 8, KERNEL_ATTR])}
78-
setter_test_expected_4 = {"activation": ([NodeTypeFilter, ReLU, 16], [NodeNameFilter, "conv1", 8]),
79-
"weights": ([NodeTypeFilter, Conv2D, 16, KERNEL_ATTR], [NodeNameFilter, "fc", 2, BIAS_ATTR])}
80-
81-
82-
# test : BitWidthConfig set_manual_activation_bit_width, set_manual_weights_bit_width
59+
# test case for set_manual_activation_bit_width
60+
test_input_0 = (None, None)
61+
test_input_1 = (NodeTypeFilter(ReLU), 16)
62+
test_input_2 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16])
63+
test_input_3 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16, 8])
64+
65+
test_expected_0 = ("The filters cannot be None.", None)
66+
test_expected_1 = (NodeTypeFilter, ReLU, 16)
67+
test_expected_2 = ([NodeTypeFilter, ReLU, 16], [NodeNameFilter, "conv1", 16])
68+
test_expected_3 = ([NodeTypeFilter, ReLU, 16], [NodeNameFilter, "conv1", 8])
69+
8370
@pytest.mark.parametrize(("inputs", "expected"), [
84-
(setter_test_input_0, setter_test_expected_0),
85-
(setter_test_input_1, setter_test_expected_1),
86-
(setter_test_input_2, setter_test_expected_2),
87-
(setter_test_input_3, setter_test_expected_3),
88-
(setter_test_input_4, setter_test_expected_4),
71+
(test_input_0, test_expected_0),
72+
(test_input_1, test_expected_1),
73+
(test_input_2, test_expected_2),
74+
(test_input_3, test_expected_3),
8975
])
90-
def test_bit_width_config_setter(self, inputs, expected):
91-
92-
def check_param(mb_cfg, exp):
76+
def test_set_manual_activation_bit_width(self, inputs, expected):
77+
def check_param_for_activation(mb_cfg, exp):
9378
### check setting config class (expected ManualBitWidthSelection)
9479
assert type(mb_cfg) == ManualBitWidthSelection
9580

@@ -106,8 +91,40 @@ def check_param(mb_cfg, exp):
10691
else:
10792
assert mb_cfg.filter is None
10893

109-
def check_param_for_weights(mb_cfg, exp):
110-
### check setting config class (expected ManualBitWidthSelection)
94+
manual_bit_cfg = BitWidthConfig()
95+
try:
96+
manual_bit_cfg.set_manual_activation_bit_width(inputs[0], inputs[1])
97+
### check Activation
98+
if len(manual_bit_cfg.manual_activation_bit_width_selection_list) == 1:
99+
for a_mb_cfg in manual_bit_cfg.manual_activation_bit_width_selection_list:
100+
print(a_mb_cfg, expected)
101+
check_param_for_activation(a_mb_cfg, expected)
102+
else:
103+
for idx, a_mb_cfg in enumerate(manual_bit_cfg.manual_activation_bit_width_selection_list):
104+
check_param_for_activation(a_mb_cfg, expected[idx])
105+
except Exception as e:
106+
assert str(e) == expected[0]
107+
108+
# test case for set_manual_weights_bit_width
109+
test_input_0 = (None, None, None)
110+
test_input_1 = (NodeTypeFilter(ReLU), 16, TEST_KERNEL)
111+
test_input_2 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16], [TEST_KERNEL])
112+
test_input_3 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16, 8], [TEST_KERNEL, TEST_BIAS])
113+
114+
test_expected_0 = ("The filters cannot be None.", None, None)
115+
test_expected_1 = (NodeTypeFilter, ReLU, 16, TEST_KERNEL)
116+
test_expected_2 = ([NodeTypeFilter, ReLU, 16, TEST_KERNEL], [NodeNameFilter, "conv1", 16, TEST_KERNEL])
117+
test_expected_3 = ([NodeTypeFilter, ReLU, 16, TEST_KERNEL], [NodeNameFilter, "conv1", 8, TEST_BIAS])
118+
119+
@pytest.mark.parametrize(("inputs", "expected"), [
120+
(test_input_0, test_expected_0),
121+
(test_input_1, test_expected_1),
122+
(test_input_2, test_expected_2),
123+
(test_input_3, test_expected_3),
124+
])
125+
def test_set_manual_weights_bit_width(self, inputs, expected):
126+
def check_param_weights(mb_cfg, exp):
127+
### check setting config class (expected ManualWeightsBitWidthSelection)
111128
assert type(mb_cfg) == ManualWeightsBitWidthSelection
112129

113130
### check setting filter for NodeFilter and NodeInfo
@@ -118,94 +135,81 @@ def check_param_for_weights(mb_cfg, exp):
118135
elif isinstance(mb_cfg.filter, NodeNameFilter):
119136
assert mb_cfg.filter.node_name == exp[1]
120137

121-
### check setting bit_width
138+
### check setting bit_width and attr
122139
assert mb_cfg.bit_width == exp[2]
123140
assert mb_cfg.attr == exp[3]
124141
else:
125142
assert mb_cfg.filter is None
126143

127-
activation = inputs["activation"]
128-
weights = inputs["weights"]
144+
manual_bit_cfg = BitWidthConfig()
145+
try:
146+
manual_bit_cfg.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2])
147+
### check weights
148+
if len(manual_bit_cfg.manual_weights_bit_width_selection_list) == 1:
149+
for a_mb_cfg in manual_bit_cfg.manual_weights_bit_width_selection_list:
150+
print(a_mb_cfg, expected)
151+
check_param_weights(a_mb_cfg, expected)
152+
else:
153+
for idx, a_mb_cfg in enumerate(manual_bit_cfg.manual_weights_bit_width_selection_list):
154+
check_param_weights(a_mb_cfg, expected[idx])
155+
except Exception as e:
156+
assert str(e) == expected[0]
129157

130-
activation_expected = expected["activation"]
131-
weights_expected = expected["weights"]
158+
# test case for get_nodes_to_manipulate_activation_bit_widths
159+
test_input_0 = (NodeTypeFilter(ReLU), 16)
160+
test_input_1 = (NodeNameFilter('relu1'), 16)
161+
test_input_2 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16, 8])
132162

133-
manual_bit_cfg = BitWidthConfig()
163+
test_expected_0 = ({"ReLU:relu1": 16})
164+
test_expected_1 = ({"ReLU:relu1": 16})
165+
test_expected_2 = ({"ReLU:relu1": 16, "Conv2D:conv1": 8})
134166

135-
manual_bit_cfg.set_manual_activation_bit_width(activation[0], activation[1])
136-
manual_bit_cfg.set_manual_weights_bit_width(weights[0], weights[1], weights[2])
137-
138-
### check got object instance
139-
assert isinstance(manual_bit_cfg, BitWidthConfig)
140-
141-
### check Activation
142-
if len(manual_bit_cfg.manual_activation_bit_width_selection_list) == 1:
143-
for a_mb_cfg in manual_bit_cfg.manual_activation_bit_width_selection_list:
144-
check_param(a_mb_cfg, activation_expected)
145-
else:
146-
for idx, a_mb_cfg in enumerate(manual_bit_cfg.manual_activation_bit_width_selection_list):
147-
check_param(a_mb_cfg, activation_expected[idx])
148-
149-
### check Weights
150-
if len(manual_bit_cfg.manual_weights_bit_width_selection_list) == 1:
151-
for w_mb_cfg in manual_bit_cfg.manual_weights_bit_width_selection_list:
152-
check_param_for_weights(w_mb_cfg, weights_expected)
153-
else:
154-
for idx, w_mb_cfg in enumerate(manual_bit_cfg.manual_weights_bit_width_selection_list):
155-
check_param_for_weights(w_mb_cfg, weights_expected[idx])
156-
157-
158-
### test case
159-
### Note: setter inputs reuse getters test inputs
160-
getter_test_expected_0 = {"activation":{},
161-
"weights": {}}
162-
getter_test_expected_1 = {"activation":{"ReLU:relu1": 16},
163-
"weights": {}}
164-
getter_test_expected_2 = {"activation":{},
165-
"weights": {"Conv2D:conv2": [8, KERNEL_ATTR]}}
166-
getter_test_expected_3 = {"activation": {"ReLU:relu1": 16},
167-
"weights": {"Conv2D:conv2": [8, KERNEL_ATTR]}}
168-
getter_test_expected_4 = {"activation": {"ReLU:relu1": 16, "Conv2D:conv1": 8},
169-
"weights": {"Conv2D:conv1": [16, KERNEL_ATTR], "Conv2D:conv2": [16, KERNEL_ATTR], "Dense:fc": [2, BIAS_ATTR]}}
170-
171-
# test : BitWidthConfig get_nodes_to_manipulate_bit_widths
172167
@pytest.mark.parametrize(("inputs", "expected"), [
173-
(setter_test_input_0, getter_test_expected_0),
174-
(setter_test_input_1, getter_test_expected_1),
175-
(setter_test_input_2, getter_test_expected_2),
176-
(setter_test_input_3, getter_test_expected_3),
177-
(setter_test_input_4, getter_test_expected_4),
168+
(test_input_0, test_expected_0),
169+
(test_input_1, test_expected_1),
170+
(test_input_2, test_expected_2),
178171
])
179-
def test_bit_width_config_getter(self, inputs, expected):
172+
def test_get_nodes_to_manipulate_activation_bit_widths(self, inputs, expected):
173+
fl_list = inputs[0] if isinstance(inputs[0], list) else [inputs[0]]
174+
bw_list = inputs[1] if isinstance(inputs[1], list) else [inputs[1]]
175+
176+
mbws_config = []
177+
for fl, bw in zip(fl_list, bw_list):
178+
mbws_config.append(ManualBitWidthSelection(fl, bw))
179+
manual_bit_cfg = BitWidthConfig(manual_activation_bit_width_selection_list=mbws_config)
180180

181181
graph = get_test_graph()
182+
get_manual_bit_dict_activation = manual_bit_cfg.get_nodes_to_manipulate_activation_bit_widths(graph)
183+
for idx, (key, val) in enumerate(get_manual_bit_dict_activation.items()):
184+
assert str(key) == list(expected.keys())[idx]
185+
assert val == list(expected.values())[idx]
182186

183-
activation = inputs["activation"]
184-
weights = inputs["weights"]
187+
# test case for get_nodes_to_manipulate_weights_bit_widths
188+
test_input_0 = (NodeTypeFilter(ReLU), 16, TEST_KERNEL)
189+
test_input_1 = (NodeNameFilter('relu1'), 16, TEST_BIAS)
190+
test_input_2 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16, 8], [TEST_KERNEL, TEST_BIAS])
185191

186-
activation_expected = expected["activation"]
187-
weights_expected = expected["weights"]
192+
test_expected_0 = ({"ReLU:relu1": [16, TEST_KERNEL]})
193+
test_expected_1 = ({"ReLU:relu1": [16, TEST_BIAS]})
194+
test_expected_2 = ({"ReLU:relu1": [16, TEST_KERNEL], "Conv2D:conv1": [8, TEST_BIAS]})
188195

189-
manual_bit_cfg = BitWidthConfig()
190-
if activation[0] is not None:
191-
manual_bit_cfg.set_manual_activation_bit_width(activation[0], activation[1])
192-
if weights[0] is not None:
193-
manual_bit_cfg.set_manual_weights_bit_width(weights[0], weights[1], weights[2])
196+
@pytest.mark.parametrize(("inputs", "expected"), [
197+
(test_input_0, test_expected_0),
198+
(test_input_1, test_expected_1),
199+
(test_input_2, test_expected_2),
200+
])
201+
def test_get_nodes_to_manipulate_weights_bit_widths(self, inputs, expected):
202+
fl_list = inputs[0] if isinstance(inputs[0], list) else [inputs[0]]
203+
bw_list = inputs[1] if isinstance(inputs[1], list) else [inputs[1]]
204+
at_list = inputs[2] if isinstance(inputs[2], list) else [inputs[2]]
194205

195-
get_manual_bit_dict_activation = manual_bit_cfg.get_nodes_to_manipulate_activation_bit_widths(graph)
196-
get_manual_bit_dict_weights = manual_bit_cfg.get_nodes_to_manipulate_weights_bit_widths(graph)
206+
manual_weights_bit_width_config = []
207+
for fl, bw, at in zip(fl_list, bw_list, at_list):
208+
manual_weights_bit_width_config.append(ManualWeightsBitWidthSelection(fl, bw, at))
209+
manual_bit_cfg = BitWidthConfig(manual_weights_bit_width_selection_list=manual_weights_bit_width_config)
197210

198-
if activation[0] is not None:
199-
for idx, (key, val) in enumerate(get_manual_bit_dict_activation.items()):
200-
assert str(key) == list(activation_expected.keys())[idx]
201-
assert val == list(activation_expected.values())[idx]
202-
else:
203-
assert get_manual_bit_dict_activation == activation_expected
204-
205-
if weights[0] is not None:
206-
for idx, (key, val) in enumerate(get_manual_bit_dict_weights.items()):
207-
assert str(key) == list(weights_expected.keys())[idx]
208-
assert val[0] == list(weights_expected.values())[idx][0]
209-
assert val[1] == list(weights_expected.values())[idx][1]
210-
else:
211-
assert get_manual_bit_dict_weights == weights_expected
211+
graph = get_test_graph()
212+
get_manual_bit_dict_weights = manual_bit_cfg.get_nodes_to_manipulate_weights_bit_widths(graph)
213+
for idx, (key, val) in enumerate(get_manual_bit_dict_weights.items()):
214+
assert str(key) == list(expected.keys())[idx]
215+
assert val == list(expected.values())[idx]

0 commit comments

Comments
 (0)