@@ -103,7 +103,7 @@ def get_nodes_to_manipulate_activation_bit_widths(self, graph: Graph) -> Dict:
103103 Returns:
104104 Dict: A dictionary mapping nodes to their new bit-widths.
105105 """
106- activation_nodes_to_change_bit_width = self ._construct_node_to_new_bit_mapping (graph , self . manual_activation_bit_width_selection_list )
106+ activation_nodes_to_change_bit_width = self ._construct_node_to_new_activation_bit_mapping (graph )
107107 return activation_nodes_to_change_bit_width
108108
109109 def get_nodes_to_manipulate_weights_bit_widths (self , graph : Graph ) -> Dict :
@@ -116,7 +116,7 @@ def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict:
116116 Returns:
117117 Dict: A dictionary mapping nodes to their new bit-widths.
118118 """
119- weights_nodes_to_change_bit_width = self ._construct_node_to_new_bit_mapping (graph , self . manual_weights_bit_width_selection_list )
119+ weights_nodes_to_change_bit_width = self ._construct_node_to_new_weights_bit_mapping (graph )
120120 return weights_nodes_to_change_bit_width
121121
122122 @staticmethod
@@ -164,10 +164,9 @@ def _expand_to_list(
164164 attrs = BitWidthConfig ._expand_to_list_core (filters , attrs )
165165 return attrs , bit_widths , filters
166166
167- @staticmethod
168- def _construct_node_to_new_bit_mapping (graph , manual_bit_width_selection_list ) -> Dict :
167+ def _construct_node_to_new_activation_bit_mapping (self , graph ) -> Dict :
169168 """
170- Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
169+ Retrieve nodes from the graph that need their activation bit-widths changed according to the manual bit-width selections.
171170
172171 Args:
173172 graph (Graph): The graph containing the nodes to be filtered and manipulated.
@@ -176,30 +175,61 @@ def _construct_node_to_new_bit_mapping(graph, manual_bit_width_selection_list) -
176175 Dict: A dictionary retrieved nodes from the graph.
177176 """
178177 unit_nodes_to_change_bit_width = {}
179- for manual_bit_width_selection in manual_bit_width_selection_list :
178+ for manual_bit_width_selection in self . manual_activation_bit_width_selection_list :
180179 filtered_nodes = graph .filter (manual_bit_width_selection .filter )
181180 if len (filtered_nodes ) == 0 :
182181 Logger .critical (
183182 f"Node Filtering Error: No nodes found in the graph for filter { manual_bit_width_selection .filter .__dict__ } "
184183 f"to change their bit width to { manual_bit_width_selection .bit_width } ." )
185184 for n in filtered_nodes :
186- if type (manual_bit_width_selection ) is ManualBitWidthSelection :
187- # check if a manual configuration exists for this node
188- if n in unit_nodes_to_change_bit_width :
189- Logger .info (
190- f"Node { n } has an existing manual bit width configuration of { unit_nodes_to_change_bit_width .get (n )} ."
191- f"A new manual configuration request of { manual_bit_width_selection .bit_width } has been received, and the previous value is being overridden." )
192- unit_nodes_to_change_bit_width .update ({n : manual_bit_width_selection .bit_width })
193- elif type (manual_bit_width_selection ) is ManualWeightsBitWidthSelection :
194- if len (n .get_node_weights_attributes ()) == 0 :
195- Logger .critical (f'The requested attribute to change the bit width for { n } does not exist.' )
196-
197- if n in unit_nodes_to_change_bit_width :
198- if unit_nodes_to_change_bit_width [n ][1 ] == manual_bit_width_selection .attr :
185+ # check if a manual configuration exists for this node
186+ if n in unit_nodes_to_change_bit_width :
187+ Logger .info (
188+ f"Node { n } has an existing manual bit width configuration of { unit_nodes_to_change_bit_width .get (n )} ."
189+ f"A new manual configuration request of { manual_bit_width_selection .bit_width } has been received, and the previous value is being overridden." )
190+ unit_nodes_to_change_bit_width .update ({n : manual_bit_width_selection .bit_width })
191+ return unit_nodes_to_change_bit_width
192+
193+ def _construct_node_to_new_weights_bit_mapping (self , graph ) -> Dict :
194+ """
195+ Retrieve nodes from the graph that need their weights bit-widths changed according to the manual bit-width selections.
196+
197+ Args:
198+ graph (Graph): The graph containing the nodes to be filtered and manipulated.
199+
200+ Returns:
201+ Dict: A dictionary retrieved nodes from the graph.
202+ """
203+ unit_nodes_to_change_bit_width = {}
204+
205+ for manual_bit_width_selection in self .manual_weights_bit_width_selection_list :
206+ filtered_nodes = graph .filter (manual_bit_width_selection .filter )
207+ if len (filtered_nodes ) == 0 :
208+ Logger .critical (
209+ f"Node Filtering Error: No nodes found in the graph for filter { manual_bit_width_selection .filter .__dict__ } "
210+ f"to change their bit width to { manual_bit_width_selection .bit_width } ." )
211+
212+ for n in filtered_nodes :
213+ attr_to_change_bit_width = []
214+
215+ attrs_str = n .get_node_weights_attributes ()
216+ if len (attrs_str ) == 0 :
217+ Logger .critical (f'The requested attribute { manual_bit_width_selection .attr } to change the bit width for { n } does not exist.' )
218+
219+ attr = [attr_str for attr_str in attrs_str if attr_str .find (manual_bit_width_selection .attr ) != - 1 ]
220+ if len (attr ) == 0 :
221+ Logger .critical (f'The requested attribute { manual_bit_width_selection .attr } to change the bit width for { n } does not exist.' )
222+
223+ if n in unit_nodes_to_change_bit_width :
224+ attr_to_change_bit_width = unit_nodes_to_change_bit_width [n ]
225+ for i , attr_to_bitwidth in enumerate (attr_to_change_bit_width ):
226+ if attr_to_bitwidth [1 ] == manual_bit_width_selection .attr :
227+ del attr_to_change_bit_width [i ]
199228 Logger .info (
200- f"Node { n } has an existing manual bit width configuration of { unit_nodes_to_change_bit_width . get ( n ) } ."
229+ f"Node { n } has an existing manual bit width configuration of { manual_bit_width_selection . attr } ."
201230 f"A new manual configuration request of { manual_bit_width_selection .bit_width } has been received, and the previous value is being overridden." )
202- unit_nodes_to_change_bit_width .update ({n : [manual_bit_width_selection .bit_width , manual_bit_width_selection .attr ]})
203- else :
204- Logger .critical (f'The type of manual_bit_width_selection_list must be ManualBitWidthSelection or ManualWeightsBitWidthSelection.' )
231+
232+ attr_to_change_bit_width .append ([manual_bit_width_selection .bit_width , manual_bit_width_selection .attr ])
233+ unit_nodes_to_change_bit_width .update ({n : attr_to_change_bit_width })
234+
205235 return unit_nodes_to_change_bit_width
0 commit comments