Skip to content

Commit 7ba5a17

Browse files
fixing for manual weights selection bitwidth
1 parent 1af4e11 commit 7ba5a17

File tree

3 files changed

+279
-239
lines changed

3 files changed

+279
-239
lines changed

model_compression_toolkit/core/common/quantization/bit_width_config.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def get_nodes_to_manipulate_activation_bit_widths(self, graph: Graph) -> Dict:
103103
Returns:
104104
Dict: A dictionary mapping nodes to their new bit-widths.
105105
"""
106-
activation_nodes_to_change_bit_width = self._construct_node_to_new_bit_mapping(graph, self.manual_activation_bit_width_selection_list)
106+
activation_nodes_to_change_bit_width = self._construct_node_to_new_activation_bit_mapping(graph)
107107
return activation_nodes_to_change_bit_width
108108

109109
def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict:
@@ -116,7 +116,7 @@ def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict:
116116
Returns:
117117
Dict: A dictionary mapping nodes to their new bit-widths.
118118
"""
119-
weights_nodes_to_change_bit_width = self._construct_node_to_new_bit_mapping(graph, self.manual_weights_bit_width_selection_list)
119+
weights_nodes_to_change_bit_width = self._construct_node_to_new_weights_bit_mapping(graph)
120120
return weights_nodes_to_change_bit_width
121121

122122
@staticmethod
@@ -164,10 +164,9 @@ def _expand_to_list(
164164
attrs = BitWidthConfig._expand_to_list_core(filters, attrs)
165165
return attrs, bit_widths, filters
166166

167-
@staticmethod
168-
def _construct_node_to_new_bit_mapping(graph, manual_bit_width_selection_list) -> Dict:
167+
def _construct_node_to_new_activation_bit_mapping(self, graph) -> Dict:
169168
"""
170-
Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
169+
Retrieve nodes from the graph that need their activation bit-widths changed according to the manual bit-width selections.
171170
172171
Args:
173172
graph (Graph): The graph containing the nodes to be filtered and manipulated.
@@ -176,30 +175,61 @@ def _construct_node_to_new_bit_mapping(graph, manual_bit_width_selection_list) -
176175
Dict: A dictionary retrieved nodes from the graph.
177176
"""
178177
unit_nodes_to_change_bit_width = {}
179-
for manual_bit_width_selection in manual_bit_width_selection_list:
178+
for manual_bit_width_selection in self.manual_activation_bit_width_selection_list:
180179
filtered_nodes = graph.filter(manual_bit_width_selection.filter)
181180
if len(filtered_nodes) == 0:
182181
Logger.critical(
183182
f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
184183
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
185184
for n in filtered_nodes:
186-
if type(manual_bit_width_selection) is ManualBitWidthSelection:
187-
# check if a manual configuration exists for this node
188-
if n in unit_nodes_to_change_bit_width:
189-
Logger.info(
190-
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}."
191-
f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
192-
unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
193-
elif type(manual_bit_width_selection) is ManualWeightsBitWidthSelection:
194-
if len(n.get_node_weights_attributes()) == 0:
195-
Logger.critical(f'The requested attribute to change the bit width for {n} does not exist.')
196-
197-
if n in unit_nodes_to_change_bit_width:
198-
if unit_nodes_to_change_bit_width[n][1] == manual_bit_width_selection.attr:
185+
# check if a manual configuration exists for this node
186+
if n in unit_nodes_to_change_bit_width:
187+
Logger.info(
188+
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}."
189+
f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
190+
unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
191+
return unit_nodes_to_change_bit_width
192+
193+
def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict:
194+
"""
195+
Retrieve nodes from the graph that need their weights bit-widths changed according to the manual bit-width selections.
196+
197+
Args:
198+
graph (Graph): The graph containing the nodes to be filtered and manipulated.
199+
200+
Returns:
201+
Dict: A dictionary retrieved nodes from the graph.
202+
"""
203+
unit_nodes_to_change_bit_width = {}
204+
205+
for manual_bit_width_selection in self.manual_weights_bit_width_selection_list:
206+
filtered_nodes = graph.filter(manual_bit_width_selection.filter)
207+
if len(filtered_nodes) == 0:
208+
Logger.critical(
209+
f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
210+
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
211+
212+
for n in filtered_nodes:
213+
attr_to_change_bit_width = []
214+
215+
attrs_str = n.get_node_weights_attributes()
216+
if len(attrs_str) == 0:
217+
Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')
218+
219+
attr = [attr_str for attr_str in attrs_str if attr_str.find(manual_bit_width_selection.attr) != -1]
220+
if len(attr) == 0:
221+
Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')
222+
223+
if n in unit_nodes_to_change_bit_width:
224+
attr_to_change_bit_width = unit_nodes_to_change_bit_width[n]
225+
for i, attr_to_bitwidth in enumerate(attr_to_change_bit_width):
226+
if attr_to_bitwidth[1] == manual_bit_width_selection.attr:
227+
del attr_to_change_bit_width[i]
199228
Logger.info(
200-
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}."
229+
f"Node {n} has an existing manual bit width configuration of {manual_bit_width_selection.attr}."
201230
f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
202-
unit_nodes_to_change_bit_width.update({n: [manual_bit_width_selection.bit_width, manual_bit_width_selection.attr]})
203-
else:
204-
Logger.critical(f'The type of manual_bit_width_selection_list must be ManualBitWidthSelection or ManualWeightsBitWidthSelection.')
231+
232+
attr_to_change_bit_width.append([manual_bit_width_selection.bit_width, manual_bit_width_selection.attr])
233+
unit_nodes_to_change_bit_width.update({n: attr_to_change_bit_width})
234+
205235
return unit_nodes_to_change_bit_width

model_compression_toolkit/core/common/quantization/set_node_quantization_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def set_quantization_configuration_to_graph(graph: Graph,
6565
Logger.warning("Using the HMSE error method for weights quantization parameters search. "
6666
"Note: This method may significantly increase runtime during the parameter search process.")
6767

68-
nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_bit_widths(graph)
68+
nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
6969

7070
for n in graph.nodes:
7171
set_quantization_configs_to_node(node=n,

0 commit comments

Comments
 (0)