Skip to content

Commit d411b8d

Browse files
committed
fixed PR-FB for manual weights selection bitwidth
1 parent 37d2dcf commit d411b8d

File tree

2 files changed

+85
-72
lines changed

2 files changed

+85
-72
lines changed

model_compression_toolkit/core/common/quantization/bit_width_config.py

Lines changed: 76 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher
2121
from model_compression_toolkit.logger import Logger
2222

23+
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
2324

2425
@dataclass
2526
class ManualBitWidthSelection:
@@ -40,68 +41,22 @@ class ManualWeightsBitWidthSelection(ManualBitWidthSelection):
4041
4142
Attributes:
4243
filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
43-
attr (str): The attribute used to select nodes for bit width manipulation. [KERNEL_ATTR|BIAS_ATTR]
44+
bit_width (int): The bit width to be applied to the selected nodes.
45+
attr (str): The filtered node's attributes to apply bit-width manipulation to.
4446
"""
45-
attr: str
46-
47-
def _expand_to_list_core(
48-
filters: Union[List[BaseNodeMatcher]],
49-
vals: Union[List[any], any]):
50-
vals = [vals] if not isinstance(vals, list) else vals
51-
if len(vals) > 1 and len(vals) != len(filters):
52-
Logger.critical(f"Configuration Error: The number of provided bit_width values {len(vals)} "
53-
f"must match the number of filters {len(filters)}, or a single bit_width value "
54-
f"should be provided for all filters.")
55-
elif len(vals) == 1 and len(filters) > 1:
56-
vals = [vals[0] for f in filters]
57-
return vals
58-
59-
def _expand_to_list_filter_and_bit_width(
60-
filters: Union[List[BaseNodeMatcher]],
61-
bit_widths: Union[List[int], int],
62-
attrs: Union[List[str], str] = None):
63-
filters = [filters] if not isinstance(filters, list) else filters
64-
bit_widths = _expand_to_list_core(filters, bit_widths)
65-
attrs = _expand_to_list_core(filters, attrs)
66-
67-
return attrs, bit_widths, filters
68-
69-
def _make_nodes_to_change_bit_width(graph, manual_bit_width_selection_list):
70-
unit_nodes_to_change_bit_width = {}
71-
for manual_bit_width_selection in manual_bit_width_selection_list:
72-
filtered_nodes = graph.filter(manual_bit_width_selection.filter)
73-
if len(filtered_nodes) == 0:
74-
Logger.critical(
75-
f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
76-
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
77-
for n in filtered_nodes:
78-
# check if a manual configuration exists for this node
79-
if n in unit_nodes_to_change_bit_width:
80-
Logger.info(
81-
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}. A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
82-
if isinstance(manual_bit_width_selection_list, ManualBitWidthSelection):
83-
unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
84-
elif isinstance(manual_bit_width_selection_list, ManualWeightsBitWidthSelection):
85-
unit_nodes_to_change_bit_width.update({n: [manual_bit_width_selection.bit_width, manual_bit_width_selection.attr]})
86-
87-
88-
return unit_nodes_to_change_bit_width
89-
90-
@dataclass
91-
class NodesToChangeBitWidth:
92-
activation_nodes_to_change_bit_width: Dict
93-
weights_nodes_to_change_bit_width: Dict
47+
attr: WeightAttrT
9448

9549
@dataclass
9650
class BitWidthConfig:
9751
"""
9852
Class to manage manual bit-width configurations.
9953
10054
Attributes:
101-
manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects defining manual bit-width configurations.
55+
manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects for activation defining manual bit-width configurations.
56+
manual_activation_bit_width_selection_list (List[ManualWeightsBitWidthSelection]): A list of ManualWeightsBitWidthSelection for weights objects defining manual bit-width configurations.
10257
"""
10358
manual_activation_bit_width_selection_list: List[ManualBitWidthSelection] = field(default_factory=list)
104-
manual_weights_bit_width_selection_list: List[ManualBitWidthSelection] = field(default_factory=list)
59+
manual_weights_bit_width_selection_list: List[ManualWeightsBitWidthSelection] = field(default_factory=list)
10560

10661
def set_manual_activation_bit_width(self,
10762
filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
@@ -114,7 +69,7 @@ def set_manual_activation_bit_width(self,
11469
bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
11570
If a single value is given it will be applied to all the filters
11671
"""
117-
_, bit_widths, filters = _expand_to_list_filter_and_bit_width(filters, bit_widths)
72+
_, bit_widths, filters = self._expand_to_list_filter_and_bit_width(filters, bit_widths)
11873
for bit_width, filter in zip (bit_widths, filters):
11974
self.manual_activation_bit_width_selection_list += [ManualBitWidthSelection(filter, bit_width)]
12075

@@ -128,15 +83,15 @@ def set_manual_weights_bit_width(self,
12883
12984
Args:
13085
filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
131-
bit_widths (Union[List[int], int]): The bit widths for kernel to be applied to the selected nodes.
132-
attrs (Union[List[str], str]): The attributes used to select nodes for bit width manipulation. [KERNEL_ATTR|BIAS_ATTR]
86+
bit_widths (Union[List[int], int]): The bit widths for specified by attrs to be applied to the selected nodes.
87+
attrs (Union[List[str], str]): The filtered node's attributes to apply bit-width manipulation to.
13388
If a single value is given it will be applied to all the filters
13489
"""
135-
attrs, bit_widths, filters = _expand_to_list_filter_and_bit_width(filters, bit_widths, attrs)
90+
attrs, bit_widths, filters = self._expand_to_list_filter_and_bit_width(filters, bit_widths, attrs)
13691
for attr, bit_width, filter in zip (attrs, bit_widths, filters):
13792
self.manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection(filter, bit_width, attr)]
13893

139-
def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> NodesToChangeBitWidth:
94+
def get_nodes_to_manipulate_activation_bit_widths(self, graph: Graph) -> Dict:
14095
"""
14196
Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
14297
@@ -146,8 +101,68 @@ def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> NodesToChangeBitWi
146101
Returns:
147102
Dict: A dictionary mapping nodes to their new bit-widths.
148103
"""
149-
activation_nodes_to_change_bit_width = _make_nodes_to_change_bit_width(graph, self.manual_activation_bit_width_selection_list)
150-
weights_nodes_to_change_bit_width = _make_nodes_to_change_bit_width(graph, self.manual_weights_bit_width_selection_list)
104+
activation_nodes_to_change_bit_width = self._construct_node_to_new_bit_mapping(graph, self.manual_activation_bit_width_selection_list)
105+
106+
return activation_nodes_to_change_bit_width
107+
108+
def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict:
109+
"""
110+
Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
151111
152-
nodes_to_change_bit_width = NodesToChangeBitWidth(activation_nodes_to_change_bit_width, weights_nodes_to_change_bit_width)
153-
return nodes_to_change_bit_width
112+
Args:
113+
graph (Graph): The graph containing the nodes to be filtered and manipulated.
114+
115+
Returns:
116+
Dict: A dictionary mapping nodes to their new bit-widths.
117+
"""
118+
weights_nodes_to_change_bit_width = self._construct_node_to_new_bit_mapping(graph, self.manual_weights_bit_width_selection_list)
119+
120+
return weights_nodes_to_change_bit_width
121+
122+
123+
124+
def _expand_to_list_core(
125+
self,
126+
filters: Union[List[BaseNodeMatcher]],
127+
vals: Union[List[any], any]):
128+
vals = [vals] if not isinstance(vals, list) else vals
129+
if len(vals) > 1 and len(vals) != len(filters):
130+
Logger.critical(f"Configuration Error: The number of provided bit_width values {len(vals)} "
131+
f"must match the number of filters {len(filters)}, or a single bit_width value "
132+
f"should be provided for all filters.")
133+
elif len(vals) == 1 and len(filters) > 1:
134+
vals = [vals[0] for f in filters]
135+
return vals
136+
137+
def _expand_to_list_filter_and_bit_width(
138+
self,
139+
filters: Union[List[BaseNodeMatcher]],
140+
bit_widths: Union[List[int], int],
141+
attrs: Union[List[str], str] = None):
142+
filters = [filters] if not isinstance(filters, list) else filters
143+
bit_widths = self._expand_to_list_core(filters, bit_widths)
144+
attrs = self._expand_to_list_core(filters, attrs)
145+
146+
return attrs, bit_widths, filters
147+
148+
def _construct_node_to_new_bit_mapping(self, graph, manual_bit_width_selection_list):
149+
unit_nodes_to_change_bit_width = {}
150+
for manual_bit_width_selection in manual_bit_width_selection_list:
151+
filtered_nodes = graph.filter(manual_bit_width_selection.filter)
152+
if len(filtered_nodes) == 0:
153+
Logger.critical(
154+
f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
155+
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
156+
for n in filtered_nodes:
157+
if n.get_node_weights_attributes() is False:
158+
Logger.critical(f'The requested attribute to change the bit width for {n} is not existing.')
159+
# check if a manual configuration exists for this node
160+
if n in unit_nodes_to_change_bit_width:
161+
Logger.info(
162+
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}. A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
163+
if isinstance(manual_bit_width_selection_list, ManualBitWidthSelection):
164+
unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
165+
elif isinstance(manual_bit_width_selection_list, ManualWeightsBitWidthSelection):
166+
unit_nodes_to_change_bit_width.update({n: [manual_bit_width_selection.bit_width, manual_bit_width_selection.attr]})
167+
168+
return unit_nodes_to_change_bit_width

tests_pytest/common_tests/core/common/quantization/test_manual_weights_bitwidth_selection.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ def get_test_graph():
5555
return graph
5656

5757
class TestBitWidthConfig:
58-
59-
#######################################################################################################
60-
### test case
58+
# test case
6159
setter_test_input_0 = {"activation": (None, None),
6260
"weights": (None, None, None)}
6361
setter_test_input_1 = {"activation": (NodeTypeFilter(ReLU), [16]),
@@ -89,7 +87,7 @@ class TestBitWidthConfig:
8987
(setter_test_input_3, setter_test_expected_3),
9088
(setter_test_input_4, setter_test_expected_4),
9189
])
92-
def test_BitWidthConfig_setter(self, inputs, expected):
90+
def test_bit_width_config_setter(self, inputs, expected):
9391

9492
def check_param(mb_cfg, exp):
9593
### check setting config class (expected ManualBitWidthSelection)
@@ -157,7 +155,6 @@ def check_param_for_weights(mb_cfg, exp):
157155
check_param_for_weights(w_mb_cfg, weights_expected[idx])
158156

159157

160-
#######################################################################################################
161158
### test case
162159
### Note: setter inputs reuse getters test inputs
163160
getter_test_expected_0 = {"activation":{},
@@ -179,7 +176,7 @@ def check_param_for_weights(mb_cfg, exp):
179176
(setter_test_input_3, getter_test_expected_3),
180177
(setter_test_input_4, getter_test_expected_4),
181178
])
182-
def test_BitWidthConfig_getter(self, inputs, expected):
179+
def test_bit_width_config_getter(self, inputs, expected):
183180

184181
graph = get_test_graph()
185182

@@ -195,19 +192,20 @@ def test_BitWidthConfig_getter(self, inputs, expected):
195192
if weights[0] is not None:
196193
manual_bit_cfg.set_manual_weights_bit_width(weights[0], weights[1], weights[2])
197194

198-
get_manual_bit_dict = manual_bit_cfg.get_nodes_to_manipulate_bit_widths(graph)
195+
get_manual_bit_dict_activation = manual_bit_cfg.get_nodes_to_manipulate_activation_bit_widths(graph)
196+
get_manual_bit_dict_weights = manual_bit_cfg.get_nodes_to_manipulate_weights_bit_widths(graph)
199197

200198
if activation[0] is not None:
201-
for idx, (key, val) in enumerate(get_manual_bit_dict.activation_nodes_to_change_bit_width.items()):
199+
for idx, (key, val) in enumerate(get_manual_bit_dict_activation.items()):
202200
assert str(key) == list(activation_expected.keys())[idx]
203201
assert val == list(activation_expected.values())[idx]
204202
else:
205-
assert get_manual_bit_dict.activation_nodes_to_change_bit_width == activation_expected
203+
assert get_manual_bit_dict_activation == activation_expected
206204

207205
if weights[0] is not None:
208-
for idx, (key, val) in enumerate(get_manual_bit_dict.weights_nodes_to_change_bit_width.items()):
206+
for idx, (key, val) in enumerate(get_manual_bit_dict_weights.items()):
209207
assert str(key) == list(weights_expected.keys())[idx]
210208
assert val[0] == list(weights_expected.values())[idx][0]
211209
assert val[1] == list(weights_expected.values())[idx][1]
212210
else:
213-
assert get_manual_bit_dict.weights_nodes_to_change_bit_width == weights_expected
211+
assert get_manual_bit_dict_weights == weights_expected

0 commit comments

Comments
 (0)