Skip to content

Commit 1f30c99

Browse files
committed
Correcting comments on pull requests
1 parent 5b88fe2 commit 1f30c99

File tree

2 files changed

+74
-36
lines changed

2 files changed

+74
-36
lines changed

model_compression_toolkit/core/common/quantization/bit_width_config.py

Lines changed: 74 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from dataclasses import dataclass, field
1616
from typing import List, Union, Dict
1717

18-
from model_compression_toolkit.constants import WEIGHTS_ATTRIBUTE, ACTIVATION_ATTRIBUTE
1918
from model_compression_toolkit.core.common import Graph
2019
from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher
2120
from model_compression_toolkit.logger import Logger
@@ -25,25 +24,25 @@
2524
@dataclass
2625
class ManualBitWidthSelection:
2726
"""
28-
Class to encapsulate the manual bit width selection configuration for a specific filter.
27+
Class to encapsulate the manual bit width selection configuration for a specific filter.
2928
30-
Attributes:
29+
Attributes:
3130
filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
3231
bit_width (int): The bit width to be applied to the selected nodes.
33-
"""
32+
"""
3433
filter: BaseNodeMatcher
3534
bit_width: int
3635

3736
@dataclass
3837
class ManualWeightsBitWidthSelection(ManualBitWidthSelection):
3938
"""
40-
Class to encapsulate the manual weights bit width selection configuration for a specific filter.
39+
Class to encapsulate the manual weights bit width selection configuration for a specific filter.
4140
42-
Attributes:
41+
Attributes:
4342
filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
4443
bit_width (int): The bit width to be applied to the selected nodes.
4544
attr (str): The filtered node's attributes to apply bit-width manipulation to.
46-
"""
45+
"""
4746
attr: WeightAttrT
4847

4948
@dataclass
@@ -53,7 +52,7 @@ class BitWidthConfig:
5352
5453
Attributes:
5554
manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects for activation defining manual bit-width configurations.
56-
manual_activation_bit_width_selection_list (List[ManualWeightsBitWidthSelection]): A list of ManualWeightsBitWidthSelection for weights objects defining manual bit-width configurations.
55+
manual_weights_bit_width_selection_list (List[ManualWeightsBitWidthSelection]): A list of ManualWeightsBitWidthSelection for weights objects defining manual bit-width configurations.
5756
"""
5857
manual_activation_bit_width_selection_list: List[ManualBitWidthSelection] = field(default_factory=list)
5958
manual_weights_bit_width_selection_list: List[ManualWeightsBitWidthSelection] = field(default_factory=list)
@@ -69,31 +68,35 @@ def set_manual_activation_bit_width(self,
6968
bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
7069
If a single value is given it will be applied to all the filters
7170
"""
72-
_, bit_widths, filters = self._expand_to_list_filter_and_bit_width(filters, bit_widths)
71+
if filters is None:
72+
Logger.critical(f"The filters cannot be None.")
73+
_, bit_widths, filters = self._expand_to_list(filters, bit_widths)
7374
for bit_width, filter in zip (bit_widths, filters):
7475
self.manual_activation_bit_width_selection_list += [ManualBitWidthSelection(filter, bit_width)]
7576

7677
def set_manual_weights_bit_width(self,
7778
filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
7879
bit_widths: Union[List[int], int],
79-
attrs: Union[List[str], str]
80+
attrs: Union[List[WeightAttrT], WeightAttrT]
8081
):
8182
"""
8283
Add a manual bit-width selection for weights to the configuration.
8384
8485
Args:
8586
filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
8687
bit_widths (Union[List[int], int]): The bit widths for specified by attrs to be applied to the selected nodes.
87-
attrs (Union[List[str], str]): The filtered node's attributes to apply bit-width manipulation to.
88+
attrs (Union[List[WeightAttrT], WeightAttrT]): The filtered node's attributes to apply bit-width manipulation to.
8889
If a single value is given it will be applied to all the filters
8990
"""
90-
attrs, bit_widths, filters = self._expand_to_list_filter_and_bit_width(filters, bit_widths, attrs)
91+
if filters is None:
92+
Logger.critical(f"The filters cannot be None.")
93+
attrs, bit_widths, filters = self._expand_to_list(filters, bit_widths, attrs)
9194
for attr, bit_width, filter in zip (attrs, bit_widths, filters):
9295
self.manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection(filter, bit_width, attr)]
9396

9497
def get_nodes_to_manipulate_activation_bit_widths(self, graph: Graph) -> Dict:
9598
"""
96-
Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
99+
Retrieve nodes from the graph that need their bit-widths for activation changed according to the manual bit-width selections.
97100
98101
Args:
99102
graph (Graph): The graph containing the nodes to be filtered and manipulated.
@@ -102,12 +105,11 @@ def get_nodes_to_manipulate_activation_bit_widths(self, graph: Graph) -> Dict:
102105
Dict: A dictionary mapping nodes to their new bit-widths.
103106
"""
104107
activation_nodes_to_change_bit_width = self._construct_node_to_new_bit_mapping(graph, self.manual_activation_bit_width_selection_list)
105-
106108
return activation_nodes_to_change_bit_width
107109

108110
def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict:
109111
"""
110-
Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
112+
Retrieve nodes from the graph that need their bit-widths for weights changed according to the manual bit-width selections.
111113
112114
Args:
113115
graph (Graph): The graph containing the nodes to be filtered and manipulated.
@@ -116,15 +118,22 @@ def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict:
116118
Dict: A dictionary mapping nodes to their new bit-widths.
117119
"""
118120
weights_nodes_to_change_bit_width = self._construct_node_to_new_bit_mapping(graph, self.manual_weights_bit_width_selection_list)
119-
120121
return weights_nodes_to_change_bit_width
121122

123+
@staticmethod
124+
def _expand_to_list_core(
125+
filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
126+
vals: Union[List[Union[WeightAttrT, int]], Union[WeightAttrT, int]]) -> list:
127+
"""
128+
Extend the length of vals to match the length of filters.
122129
130+
Args:
131+
filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
132+
vals Union[List[Union[WeightAttrT, int], Union[WeightAttrT, int]]]): The bit widths or The filtered node's attributes.
123133
124-
def _expand_to_list_core(
125-
self,
126-
filters: Union[List[BaseNodeMatcher]],
127-
vals: Union[List[any], any]):
134+
Returns:
135+
list: Extended vals to match the length of filters.
136+
"""
128137
vals = [vals] if not isinstance(vals, list) else vals
129138
if len(vals) > 1 and len(vals) != len(filters):
130139
Logger.critical(f"Configuration Error: The number of provided bit_width values {len(vals)} "
@@ -134,18 +143,39 @@ def _expand_to_list_core(
134143
vals = [vals[0] for f in filters]
135144
return vals
136145

137-
def _expand_to_list_filter_and_bit_width(
138-
self,
146+
@staticmethod
147+
def _expand_to_list(
139148
filters: Union[List[BaseNodeMatcher]],
140149
bit_widths: Union[List[int], int],
141-
attrs: Union[List[str], str] = None):
142-
filters = [filters] if not isinstance(filters, list) else filters
143-
bit_widths = self._expand_to_list_core(filters, bit_widths)
144-
attrs = self._expand_to_list_core(filters, attrs)
150+
attrs: Union[List[WeightAttrT], WeightAttrT] = None) -> [List]:
151+
"""
152+
Extend the length of filters, bit-widths and The filtered node's attributes to match the length of filters.
153+
154+
Args:
155+
filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
156+
bit_widths (Union[List[int], int]): The bit widths for specified by attrs to be applied to the selected nodes.
157+
attrs (Union[List[WeightAttrT], WeightAttrT]): The filtered node's attributes to apply bit-width manipulation to.
145158
159+
Returns:
160+
[List]: A List of extended input arguments.
161+
"""
162+
filters = [filters] if not isinstance(filters, list) else filters
163+
bit_widths = BitWidthConfig._expand_to_list_core(filters, bit_widths)
164+
if attrs is not None:
165+
attrs = BitWidthConfig._expand_to_list_core(filters, attrs)
146166
return attrs, bit_widths, filters
147167

148-
def _construct_node_to_new_bit_mapping(self, graph, manual_bit_width_selection_list):
168+
@staticmethod
169+
def _construct_node_to_new_bit_mapping(graph, manual_bit_width_selection_list) -> Dict:
170+
"""
171+
Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
172+
173+
Args:
174+
graph (Graph): The graph containing the nodes to be filtered and manipulated.
175+
176+
Returns:
177+
Dict: A dictionary retrieved nodes from the graph.
178+
"""
149179
unit_nodes_to_change_bit_width = {}
150180
for manual_bit_width_selection in manual_bit_width_selection_list:
151181
filtered_nodes = graph.filter(manual_bit_width_selection.filter)
@@ -154,15 +184,23 @@ def _construct_node_to_new_bit_mapping(self, graph, manual_bit_width_selection_l
154184
f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
155185
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
156186
for n in filtered_nodes:
157-
if n.get_node_weights_attributes() is False:
158-
Logger.critical(f'The requested attribute to change the bit width for {n} is not existing.')
159-
# check if a manual configuration exists for this node
160-
if n in unit_nodes_to_change_bit_width:
161-
Logger.info(
162-
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}. A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
163-
if isinstance(manual_bit_width_selection_list, ManualBitWidthSelection):
187+
if type(manual_bit_width_selection) is ManualBitWidthSelection:
188+
# check if a manual configuration exists for this node
189+
if n in unit_nodes_to_change_bit_width:
190+
Logger.info(
191+
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}."
192+
f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
164193
unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
165-
elif isinstance(manual_bit_width_selection_list, ManualWeightsBitWidthSelection):
194+
elif type(manual_bit_width_selection) is ManualWeightsBitWidthSelection:
195+
if len(n.get_node_weights_attributes()) == 0:
196+
Logger.critical(f'The requested attribute to change the bit width for {n} is not existing.')
197+
198+
if n in unit_nodes_to_change_bit_width:
199+
if unit_nodes_to_change_bit_width[n][1] == manual_bit_width_selection.attr:
200+
Logger.info(
201+
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}."
202+
f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
166203
unit_nodes_to_change_bit_width.update({n: [manual_bit_width_selection.bit_width, manual_bit_width_selection.attr]})
167-
204+
else:
205+
Logger.critical(f'The type of manual_bit_width_selection_list must be ManualBitWidthSelection or ManualWeightsBitWidthSelection.')
168206
return unit_nodes_to_change_bit_width

tests_pytest/common_tests/core/common/quantization/test_manual_activation_and_weights_bitwidth_selection.py renamed to tests_pytest/common_tests/core/common/quantization/test_manual_bitwidth_selection.py

File renamed without changes.

0 commit comments

Comments
 (0)