Skip to content

Commit b07fc5d

Browse files
committed
correcting accrding to the feedback comments
1 parent 424dde6 commit b07fc5d

File tree

3 files changed

+72
-52
lines changed

3 files changed

+72
-52
lines changed

model_compression_toolkit/core/common/quantization/bit_width_config.py

Lines changed: 62 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,54 @@ class ManualBitWidthSelection:
3333
filter: BaseNodeMatcher
3434
bit_width: int
3535

36+
@dataclass
37+
class ManualWeightsBitWidthSelection(ManualBitWidthSelection):
38+
"""
39+
Class to encapsulate the manual weights bit width selection configuration for a specific filter.
40+
41+
Attributes:
42+
filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
43+
bit_width (int): The bit width to be applied to the selected nodes.
44+
"""
45+
val: int
46+
47+
48+
def _set_manual_bit_width(filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
49+
bit_widths: Union[List[int], int]):
50+
51+
filters = [filters] if not isinstance(filters, list) else filters
52+
bit_widths = [bit_widths] if not isinstance(bit_widths, list) else bit_widths
53+
if len(bit_widths) > 1 and len(bit_widths) != len(filters):
54+
Logger.critical(f"Configuration Error: The number of provided bit_width values {len(bit_widths)} "
55+
f"must match the number of filters {len(filters)}, or a single bit_width value "
56+
f"should be provided for all filters.")
57+
elif len(bit_widths) == 1 and len(filters) > 1:
58+
bit_widths = [bit_widths[0] for f in filters]
59+
60+
return bit_widths, filters
61+
62+
63+
def _make_nodes_to_change_bit_width(graph, manual_bit_width_selection_list):
64+
unit_nodes_to_change_bit_width = {}
65+
for manual_bit_width_selection in manual_bit_width_selection_list:
66+
filtered_nodes = graph.filter(manual_bit_width_selection.filter)
67+
if len(filtered_nodes) == 0:
68+
Logger.critical(
69+
f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
70+
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
71+
for n in filtered_nodes:
72+
# check if a manual configuration exists for this node
73+
if n in unit_nodes_to_change_bit_width:
74+
Logger.info(
75+
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}. A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
76+
unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
77+
78+
return unit_nodes_to_change_bit_width
79+
80+
@dataclass
81+
class NodesToChangeBitWidth:
82+
activation_nodes_to_change_bit_width: Dict
83+
weights_nodes_to_change_bit_width: Dict
3684

3785
@dataclass
3886
class BitWidthConfig:
@@ -49,47 +97,33 @@ def set_manual_activation_bit_width(self,
4997
filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
5098
bit_widths: Union[List[int], int]):
5199
"""
52-
Add a manual bit-width selection to the configuration.
100+
Add a manual bit-width selection for activation to the configuration.
53101
54102
Args:
55-
filter (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
56-
bit_width (Union[List[int], int]): The bit widths to be applied to the selected nodes.
103+
filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
104+
bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
57105
If a single value is given it will be applied to all the filters
58106
"""
59-
filters = [filters] if not isinstance(filters, list) else filters
60-
bit_widths = [bit_widths] if not isinstance(bit_widths, list) else bit_widths
61-
if len(bit_widths) > 1 and len(bit_widths) != len(filters):
62-
Logger.critical(f"Configuration Error: The number of provided bit_width values {len(bit_widths)} "
63-
f"must match the number of filters {len(filters)}, or a single bit_width value "
64-
f"should be provided for all filters.")
65-
elif len(bit_widths) == 1 and len(filters) > 1:
66-
bit_widths = [bit_widths[0] for f in filters]
107+
bit_widths, filters = _set_manual_bit_width(filters, bit_widths)
67108
for bit_width, filter in zip (bit_widths, filters):
68109
self.manual_activation_bit_width_selection_list += [ManualBitWidthSelection(filter, bit_width)]
69110

70111
def set_manual_weights_bit_width(self,
71112
filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
72113
bit_widths: Union[List[int], int]):
73114
"""
74-
Add a manual bit-width selection to the configuration.
115+
Add a manual bit-width selection for weights to the configuration.
75116
76117
Args:
77-
filter (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
78-
bit_width (Union[List[int], int]): The bit widths to be applied to the selected nodes.
118+
filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
119+
bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
79120
If a single value is given it will be applied to all the filters
80121
"""
81-
filters = [filters] if not isinstance(filters, list) else filters
82-
bit_widths = [bit_widths] if not isinstance(bit_widths, list) else bit_widths
83-
if len(bit_widths) > 1 and len(bit_widths) != len(filters):
84-
Logger.critical(f"Configuration Error: The number of provided bit_width values {len(bit_widths)} "
85-
f"must match the number of filters {len(filters)}, or a single bit_width value "
86-
f"should be provided for all filters.")
87-
elif len(bit_widths) == 1 and len(filters) > 1:
88-
bit_widths = [bit_widths[0] for f in filters]
122+
bit_widths, filters = _set_manual_bit_width(filters, bit_widths)
89123
for bit_width, filter in zip (bit_widths, filters):
90124
self.manual_weights_bit_width_selection_list += [ManualBitWidthSelection(filter, bit_width)]
91125

92-
def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict:
126+
def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> NodesToChangeBitWidth:
93127
"""
94128
Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
95129
@@ -99,25 +133,9 @@ def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict:
99133
Returns:
100134
Dict: A dictionary mapping nodes to their new bit-widths.
101135
"""
102-
def make_nodes_to_change_bit_width(manual_bit_width_selection_list):
103-
unit_nodes_to_change_bit_width = {}
104-
for manual_bit_width_selection in manual_bit_width_selection_list:
105-
filtered_nodes = graph.filter(manual_bit_width_selection.filter)
106-
if len(filtered_nodes) == 0:
107-
Logger.critical(
108-
f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
109-
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
110-
for n in filtered_nodes:
111-
# check if a manual configuration exists for this node
112-
if n in unit_nodes_to_change_bit_width:
113-
Logger.info(
114-
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}. A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
115-
unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
116-
117-
return unit_nodes_to_change_bit_width
118-
119-
activation_nodes_to_change_bit_width = make_nodes_to_change_bit_width(self.manual_activation_bit_width_selection_list)
120-
weights_nodes_to_change_bit_width = make_nodes_to_change_bit_width(self.manual_weights_bit_width_selection_list)
121-
122-
nodes_to_change_bit_width = {ACTIVATION_ATTRIBUTE: activation_nodes_to_change_bit_width, WEIGHTS_ATTRIBUTE: weights_nodes_to_change_bit_width}
136+
activation_nodes_to_change_bit_width = _make_nodes_to_change_bit_width(graph, self.manual_activation_bit_width_selection_list)
137+
weights_nodes_to_change_bit_width = _make_nodes_to_change_bit_width(graph, self.manual_weights_bit_width_selection_list)
138+
139+
#nodes_to_change_bit_width = {ACTIVATION_ATTRIBUTE: activation_nodes_to_change_bit_width, WEIGHTS_ATTRIBUTE: weights_nodes_to_change_bit_width}
140+
nodes_to_change_bit_width = NodesToChangeBitWidth(activation_nodes_to_change_bit_width, weights_nodes_to_change_bit_width)
123141
return nodes_to_change_bit_width

model_compression_toolkit/core/common/quantization/set_node_quantization_config.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from mct_quantizers.common.constants import WEIGHTS_N_BITS, ACTIVATION_N_BITS
2121
from model_compression_toolkit.constants import WEIGHTS_ATTRIBUTE, ACTIVATION_ATTRIBUTE
2222
from model_compression_toolkit.core.common import BaseNode
23-
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
23+
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig, NodesToChangeBitWidth
2424
from model_compression_toolkit.logger import Logger
2525
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
2626
from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -67,11 +67,13 @@ def set_quantization_configuration_to_graph(graph: Graph,
6767
Logger.warning("Using the HMSE error method for weights quantization parameters search. "
6868
"Note: This method may significantly increase runtime during the parameter search process.")
6969

70-
nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_bit_widths(graph)
70+
nodes_to_manipulate_bit_widths = NodesToChangeBitWidth() if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_bit_widths(graph)
71+
print('nodes_to_manipulate_bit_widths', nodes_to_manipulate_bit_widths)
72+
print('zzz', nodes_to_manipulate_bit_widths.activation_nodes_to_change_bit_width)
7173

7274
for n in graph.nodes:
73-
manual_bit_width_override = {ACTIVATION_ATTRIBUTE: nodes_to_manipulate_bit_widths.get(ACTIVATION_ATTRIBUTE).get(n),
74-
WEIGHTS_ATTRIBUTE: nodes_to_manipulate_bit_widths.get(WEIGHTS_ATTRIBUTE).get(n)}
75+
manual_bit_width_override = {ACTIVATION_ATTRIBUTE: nodes_to_manipulate_bit_widths.activation_nodes_to_change_bit_width.get(n),
76+
WEIGHTS_ATTRIBUTE: nodes_to_manipulate_bit_widths.weights_nodes_to_change_bit_width.get(n)}
7577
set_quantization_configs_to_node(node=n,
7678
graph=graph,
7779
quant_config=quant_config,

tests_pytest/common_tests/core/common/quantization/test_manual_weights_bitwidth_selection.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,15 @@ def test_BitWidthConfig_getter(self, inputs, expected):
179179
get_manual_bit_dict = manual_bit_cfg.get_nodes_to_manipulate_bit_widths(graph)
180180

181181
if activation[0] is not None:
182-
for idx, (key, val) in enumerate(get_manual_bit_dict["activation"].items()):
182+
for idx, (key, val) in enumerate(get_manual_bit_dict.activation_nodes_to_change_bit_width.items()):
183183
assert str(key) == list(activation_expected.keys())[idx]
184184
assert val == list(activation_expected.values())[idx]
185185
else:
186-
assert get_manual_bit_dict["activation"] == activation_expected
186+
assert get_manual_bit_dict.activation_nodes_to_change_bit_width == activation_expected
187187

188188
if weights[0] is not None:
189-
for idx, (key, val) in enumerate(get_manual_bit_dict["weights"].items()):
189+
for idx, (key, val) in enumerate(get_manual_bit_dict.weights_nodes_to_change_bit_width.items()):
190190
assert str(key) == list(weights_expected.keys())[idx]
191191
assert val == list(weights_expected.values())[idx]
192192
else:
193-
assert get_manual_bit_dict["weights"] == weights_expected
193+
assert get_manual_bit_dict.weights_nodes_to_change_bit_width == weights_expected

0 commit comments

Comments
 (0)