Skip to content

Commit 37d2dcf

Browse files
committed
fixing for manual weights selection bitwidth(kernel,bias)
1 parent c51a200 commit 37d2dcf

File tree

2 files changed

+41
-37
lines changed

2 files changed

+41
-37
lines changed

model_compression_toolkit/core/common/quantization/bit_width_config.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class ManualBitWidthSelection:
3131
bit_width (int): The bit width to be applied to the selected nodes.
3232
"""
3333
filter: BaseNodeMatcher
34-
bit_width: int = 0
34+
bit_width: int
3535

3636
@dataclass
3737
class ManualWeightsBitWidthSelection(ManualBitWidthSelection):
@@ -40,26 +40,31 @@ class ManualWeightsBitWidthSelection(ManualBitWidthSelection):
4040
4141
Attributes:
4242
filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
43-
bit_width (int): The bit width to be applied to the selected nodes.
43+
attr (str): The attribute used to select nodes for bit width manipulation. [KERNEL_ATTR|BIAS_ATTR]
4444
"""
45-
kernel_bit_width: int = 0
46-
bias_bit_width: int = 0
45+
attr: str
4746

48-
def _expand_to_list_filter_and_bit_width(
47+
def _expand_to_list_core(
4948
filters: Union[List[BaseNodeMatcher]],
50-
bit_widths: Union[List[int], int]
51-
):
52-
filters = [filters] if not isinstance(filters, list) else filters
53-
bit_widths = [bit_widths] if not isinstance(bit_widths, list) else bit_widths
54-
if len(bit_widths) > 1 and len(bit_widths) != len(filters):
55-
Logger.critical(f"Configuration Error: The number of provided bit_width values {len(bit_widths)} "
49+
vals: Union[List[any], any]):
50+
vals = [vals] if not isinstance(vals, list) else vals
51+
if len(vals) > 1 and len(vals) != len(filters):
52+
Logger.critical(f"Configuration Error: The number of provided bit_width values {len(vals)} "
5653
f"must match the number of filters {len(filters)}, or a single bit_width value "
5754
f"should be provided for all filters.")
58-
elif len(bit_widths) == 1 and len(filters) > 1:
59-
bit_widths = [bit_widths[0] for f in filters]
55+
elif len(vals) == 1 and len(filters) > 1:
56+
vals = [vals[0] for f in filters]
57+
return vals
6058

61-
return bit_widths, filters
59+
def _expand_to_list_filter_and_bit_width(
60+
filters: Union[List[BaseNodeMatcher]],
61+
bit_widths: Union[List[int], int],
62+
attrs: Union[List[str], str] = None):
63+
filters = [filters] if not isinstance(filters, list) else filters
64+
bit_widths = _expand_to_list_core(filters, bit_widths)
65+
attrs = _expand_to_list_core(filters, attrs)
6266

67+
return attrs, bit_widths, filters
6368

6469
def _make_nodes_to_change_bit_width(graph, manual_bit_width_selection_list):
6570
unit_nodes_to_change_bit_width = {}
@@ -77,7 +82,7 @@ def _make_nodes_to_change_bit_width(graph, manual_bit_width_selection_list):
7782
if isinstance(manual_bit_width_selection_list, ManualBitWidthSelection):
7883
unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
7984
elif isinstance(manual_bit_width_selection_list, ManualWeightsBitWidthSelection):
80-
unit_nodes_to_change_bit_width.update({n: [manual_bit_width_selection.kernel_bit_width, manual_bit_width_selection.bias_bit_width]})
85+
unit_nodes_to_change_bit_width.update({n: [manual_bit_width_selection.bit_width, manual_bit_width_selection.attr]})
8186

8287

8388
return unit_nodes_to_change_bit_width
@@ -109,30 +114,27 @@ def set_manual_activation_bit_width(self,
109114
bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
110115
If a single value is given it will be applied to all the filters
111116
"""
112-
bit_widths, filters = _expand_to_list_filter_and_bit_width(filters, bit_widths)
117+
_, bit_widths, filters = _expand_to_list_filter_and_bit_width(filters, bit_widths)
113118
for bit_width, filter in zip (bit_widths, filters):
114119
self.manual_activation_bit_width_selection_list += [ManualBitWidthSelection(filter, bit_width)]
115120

116121
def set_manual_weights_bit_width(self,
117122
filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
118-
kernel_bit_widths: Union[List[int], int],
119-
bias_bit_widths: Union[List[int], int]
123+
bit_widths: Union[List[int], int],
124+
attrs: Union[List[str], str]
120125
):
121126
"""
122127
Add a manual bit-width selection for weights to the configuration.
123128
124129
Args:
125130
filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
126-
kernel_bit_widths (Union[List[int], int]): The bit widths for kernel to be applied to the selected nodes.
127-
bias_bit_widths (Union[List[int], int]): The bit widths for bias to be applied to the selected nodes.
131+
bit_widths (Union[List[int], int]): The bit widths for kernel to be applied to the selected nodes.
132+
attrs (Union[List[str], str]): The attributes used to select nodes for bit width manipulation. [KERNEL_ATTR|BIAS_ATTR]
128133
If a single value is given it will be applied to all the filters
129134
"""
130-
kernel_bit_widths, filters = _expand_to_list_filter_and_bit_width(filters, kernel_bit_widths)
131-
bias_bit_widths, filters = _expand_to_list_filter_and_bit_width(filters, bias_bit_widths)
132-
print("kernel_bit_widths", kernel_bit_widths)
133-
print("bias_bit_widths", bias_bit_widths)
134-
for kernel_bit_width, bias_bit_width, filter in zip (kernel_bit_widths, bias_bit_widths, filters):
135-
self.manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection(filter, kernel_bit_width=kernel_bit_width, bias_bit_width=bias_bit_width)]
135+
attrs, bit_widths, filters = _expand_to_list_filter_and_bit_width(filters, bit_widths, attrs)
136+
for attr, bit_width, filter in zip (attrs, bit_widths, filters):
137+
self.manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection(filter, bit_width, attr)]
136138

137139
def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> NodesToChangeBitWidth:
138140
"""

tests_pytest/common_tests/core/common/quantization/test_manual_weights_bitwidth_selection.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from model_compression_toolkit.core.common.graph.edge import Edge
99
from tests_pytest.test_util.graph_builder_utils import build_node
1010

11+
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR
12+
1113
### dummy layer classes
1214
class Conv2D:
1315
pass
@@ -61,22 +63,22 @@ class TestBitWidthConfig:
6163
setter_test_input_1 = {"activation": (NodeTypeFilter(ReLU), [16]),
6264
"weights": (None, None, None)}
6365
setter_test_input_2 = {"activation": (None, None),
64-
"weights": (NodeNameFilter("conv2"), [8], 16)}
66+
"weights": (NodeNameFilter("conv2"), [8], KERNEL_ATTR)}
6567
setter_test_input_3 = {"activation": (NodeTypeFilter(ReLU), [16]),
66-
"weights": (NodeNameFilter("conv2"), [8], 16)}
68+
"weights": (NodeNameFilter("conv2"), [8], KERNEL_ATTR)}
6769
setter_test_input_4 = {"activation": ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16, 8]),
68-
"weights": ([NodeTypeFilter(Conv2D), NodeNameFilter("fc")], [16, 2], [8, 4])}
70+
"weights": ([NodeTypeFilter(Conv2D), NodeNameFilter("fc")], [16, 2], [KERNEL_ATTR, BIAS_ATTR])}
6971

7072
setter_test_expected_0 = {"activation": (None, None),
7173
"weights": (None, None, None)}
7274
setter_test_expected_1 = {"activation": ([NodeTypeFilter, ReLU, 16]),
7375
"weights": (None, None, None)}
7476
setter_test_expected_2 = {"activation": (None, None),
75-
"weights": ([NodeNameFilter, "conv2", 8, 16]) }
77+
"weights": ([NodeNameFilter, "conv2", 8, KERNEL_ATTR]) }
7678
setter_test_expected_3 = {"activation": ([NodeTypeFilter, ReLU, 16]),
77-
"weights": ([NodeNameFilter, "conv2", 8, 16])}
79+
"weights": ([NodeNameFilter, "conv2", 8, KERNEL_ATTR])}
7880
setter_test_expected_4 = {"activation": ([NodeTypeFilter, ReLU, 16], [NodeNameFilter, "conv1", 8]),
79-
"weights": ([NodeTypeFilter, Conv2D, 16, 8], [NodeNameFilter, "fc", 2, 4])}
81+
"weights": ([NodeTypeFilter, Conv2D, 16, KERNEL_ATTR], [NodeNameFilter, "fc", 2, BIAS_ATTR])}
8082

8183

8284
# test : BitWidthConfig set_manual_activation_bit_width, set_manual_weights_bit_width
@@ -119,8 +121,8 @@ def check_param_for_weights(mb_cfg, exp):
119121
assert mb_cfg.filter.node_name == exp[1]
120122

121123
### check setting bit_width
122-
assert mb_cfg.kernel_bit_width == exp[2]
123-
assert mb_cfg.bias_bit_width == exp[3]
124+
assert mb_cfg.bit_width == exp[2]
125+
assert mb_cfg.attr == exp[3]
124126
else:
125127
assert mb_cfg.filter is None
126128

@@ -163,11 +165,11 @@ def check_param_for_weights(mb_cfg, exp):
163165
getter_test_expected_1 = {"activation":{"ReLU:relu1": 16},
164166
"weights": {}}
165167
getter_test_expected_2 = {"activation":{},
166-
"weights": {"Conv2D:conv2": [8, 16]}}
168+
"weights": {"Conv2D:conv2": [8, KERNEL_ATTR]}}
167169
getter_test_expected_3 = {"activation": {"ReLU:relu1": 16},
168-
"weights": {"Conv2D:conv2": [8, 16]}}
170+
"weights": {"Conv2D:conv2": [8, KERNEL_ATTR]}}
169171
getter_test_expected_4 = {"activation": {"ReLU:relu1": 16, "Conv2D:conv1": 8},
170-
"weights": {"Conv2D:conv1": [16, 8], "Conv2D:conv2": [16, 8], "Dense:fc": [2, 4]}}
172+
"weights": {"Conv2D:conv1": [16, KERNEL_ATTR], "Conv2D:conv2": [16, KERNEL_ATTR], "Dense:fc": [2, BIAS_ATTR]}}
171173

172174
# test : BitWidthConfig get_nodes_to_manipulate_bit_widths
173175
@pytest.mark.parametrize(("inputs", "expected"), [

0 commit comments

Comments
 (0)