1515from dataclasses import dataclass , field
1616from typing import List , Union , Dict
1717
18- from model_compression_toolkit .constants import WEIGHTS_ATTRIBUTE , ACTIVATION_ATTRIBUTE
1918from model_compression_toolkit .core .common import Graph
2019from model_compression_toolkit .core .common .matchers .node_matcher import BaseNodeMatcher
2120from model_compression_toolkit .logger import Logger
2524@dataclass
2625class ManualBitWidthSelection :
2726 """
28- Class to encapsulate the manual bit width selection configuration for a specific filter.
27+ Class to encapsulate the manual bit width selection configuration for a specific filter.
2928
30- Attributes:
29+ Attributes:
3130 filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
3231 bit_width (int): The bit width to be applied to the selected nodes.
33- """
32+ """
3433 filter : BaseNodeMatcher
3534 bit_width : int
3635
3736@dataclass
3837class ManualWeightsBitWidthSelection (ManualBitWidthSelection ):
3938 """
40- Class to encapsulate the manual weights bit width selection configuration for a specific filter.
39+ Class to encapsulate the manual weights bit width selection configuration for a specific filter.
4140
42- Attributes:
41+ Attributes:
4342 filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
4443 bit_width (int): The bit width to be applied to the selected nodes.
4544 attr (str): The filtered node's attributes to apply bit-width manipulation to.
46- """
45+ """
4746 attr : WeightAttrT
4847
4948@dataclass
@@ -53,7 +52,7 @@ class BitWidthConfig:
5352
5453 Attributes:
5554 manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects for activation defining manual bit-width configurations.
56- manual_activation_bit_width_selection_list (List[ManualWeightsBitWidthSelection]): A list of ManualWeightsBitWidthSelection for weights objects defining manual bit-width configurations.
55+ manual_weights_bit_width_selection_list (List[ManualWeightsBitWidthSelection]): A list of ManualWeightsBitWidthSelection for weights objects defining manual bit-width configurations.
5756 """
5857 manual_activation_bit_width_selection_list : List [ManualBitWidthSelection ] = field (default_factory = list )
5958 manual_weights_bit_width_selection_list : List [ManualWeightsBitWidthSelection ] = field (default_factory = list )
@@ -69,31 +68,35 @@ def set_manual_activation_bit_width(self,
6968 bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
7069 If a single value is given it will be applied to all the filters
7170 """
72- _ , bit_widths , filters = self ._expand_to_list_filter_and_bit_width (filters , bit_widths )
71+ if filters is None :
72+ Logger .critical (f"The filters cannot be None." )
73+ _ , bit_widths , filters = self ._expand_to_list (filters , bit_widths )
7374 for bit_width , filter in zip (bit_widths , filters ):
7475 self .manual_activation_bit_width_selection_list += [ManualBitWidthSelection (filter , bit_width )]
7576
7677 def set_manual_weights_bit_width (self ,
7778 filters : Union [List [BaseNodeMatcher ], BaseNodeMatcher ],
7879 bit_widths : Union [List [int ], int ],
79- attrs : Union [List [str ], str ]
80+ attrs : Union [List [WeightAttrT ], WeightAttrT ]
8081 ):
8182 """
8283 Add a manual bit-width selection for weights to the configuration.
8384
8485 Args:
8586 filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
8687 bit_widths (Union[List[int], int]): The bit widths for specified by attrs to be applied to the selected nodes.
87- attrs (Union[List[str ], str ]): The filtered node's attributes to apply bit-width manipulation to.
88+ attrs (Union[List[WeightAttrT ], WeightAttrT ]): The filtered node's attributes to apply bit-width manipulation to.
8889 If a single value is given it will be applied to all the filters
8990 """
90- attrs , bit_widths , filters = self ._expand_to_list_filter_and_bit_width (filters , bit_widths , attrs )
91+ if filters is None :
92+ Logger .critical (f"The filters cannot be None." )
93+ attrs , bit_widths , filters = self ._expand_to_list (filters , bit_widths , attrs )
9194 for attr , bit_width , filter in zip (attrs , bit_widths , filters ):
9295 self .manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection (filter , bit_width , attr )]
9396
9497 def get_nodes_to_manipulate_activation_bit_widths (self , graph : Graph ) -> Dict :
9598 """
96- Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
99+ Retrieve nodes from the graph that need their bit-widths for activation changed according to the manual bit-width selections.
97100
98101 Args:
99102 graph (Graph): The graph containing the nodes to be filtered and manipulated.
@@ -102,12 +105,11 @@ def get_nodes_to_manipulate_activation_bit_widths(self, graph: Graph) -> Dict:
102105 Dict: A dictionary mapping nodes to their new bit-widths.
103106 """
104107 activation_nodes_to_change_bit_width = self ._construct_node_to_new_bit_mapping (graph , self .manual_activation_bit_width_selection_list )
105-
106108 return activation_nodes_to_change_bit_width
107109
108110 def get_nodes_to_manipulate_weights_bit_widths (self , graph : Graph ) -> Dict :
109111 """
110- Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
112+ Retrieve nodes from the graph that need their bit-widths for weights changed according to the manual bit-width selections.
111113
112114 Args:
113115 graph (Graph): The graph containing the nodes to be filtered and manipulated.
@@ -116,15 +118,22 @@ def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict:
116118 Dict: A dictionary mapping nodes to their new bit-widths.
117119 """
118120 weights_nodes_to_change_bit_width = self ._construct_node_to_new_bit_mapping (graph , self .manual_weights_bit_width_selection_list )
119-
120121 return weights_nodes_to_change_bit_width
121122
123+ @staticmethod
124+ def _expand_to_list_core (
125+ filters : Union [List [BaseNodeMatcher ], BaseNodeMatcher ],
126+ vals : Union [List [Union [WeightAttrT , int ]], Union [WeightAttrT , int ]]) -> list :
127+ """
128+ Extend the length of vals to match the length of filters.
122129
130+ Args:
131+ filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
132+ vals Union[List[Union[WeightAttrT, int], Union[WeightAttrT, int]]]): The bit widths or The filtered node's attributes.
123133
124- def _expand_to_list_core (
125- self ,
126- filters : Union [List [BaseNodeMatcher ]],
127- vals : Union [List [any ], any ]):
134+ Returns:
135+ list: Extended vals to match the length of filters.
136+ """
128137 vals = [vals ] if not isinstance (vals , list ) else vals
129138 if len (vals ) > 1 and len (vals ) != len (filters ):
130139 Logger .critical (f"Configuration Error: The number of provided bit_width values { len (vals )} "
@@ -134,18 +143,39 @@ def _expand_to_list_core(
134143 vals = [vals [0 ] for f in filters ]
135144 return vals
136145
137- def _expand_to_list_filter_and_bit_width (
138- self ,
146+ @ staticmethod
147+ def _expand_to_list (
139148 filters : Union [List [BaseNodeMatcher ]],
140149 bit_widths : Union [List [int ], int ],
141- attrs : Union [List [str ], str ] = None ):
142- filters = [filters ] if not isinstance (filters , list ) else filters
143- bit_widths = self ._expand_to_list_core (filters , bit_widths )
144- attrs = self ._expand_to_list_core (filters , attrs )
150+ attrs : Union [List [WeightAttrT ], WeightAttrT ] = None ) -> [List ]:
151+ """
152+ Extend the length of filters, bit-widths and The filtered node's attributes to match the length of filters.
153+
154+ Args:
155+ filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
156+ bit_widths (Union[List[int], int]): The bit widths for specified by attrs to be applied to the selected nodes.
157+ attrs (Union[List[WeightAttrT], WeightAttrT]): The filtered node's attributes to apply bit-width manipulation to.
145158
159+ Returns:
160+ [List]: A List of extended input arguments.
161+ """
162+ filters = [filters ] if not isinstance (filters , list ) else filters
163+ bit_widths = BitWidthConfig ._expand_to_list_core (filters , bit_widths )
164+ if attrs is not None :
165+ attrs = BitWidthConfig ._expand_to_list_core (filters , attrs )
146166 return attrs , bit_widths , filters
147167
148- def _construct_node_to_new_bit_mapping (self , graph , manual_bit_width_selection_list ):
168+ @staticmethod
169+ def _construct_node_to_new_bit_mapping (graph , manual_bit_width_selection_list ) -> Dict :
170+ """
171+ Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
172+
173+ Args:
174+ graph (Graph): The graph containing the nodes to be filtered and manipulated.
175+
176+ Returns:
177+ Dict: A dictionary retrieved nodes from the graph.
178+ """
149179 unit_nodes_to_change_bit_width = {}
150180 for manual_bit_width_selection in manual_bit_width_selection_list :
151181 filtered_nodes = graph .filter (manual_bit_width_selection .filter )
@@ -154,15 +184,23 @@ def _construct_node_to_new_bit_mapping(self, graph, manual_bit_width_selection_l
154184 f"Node Filtering Error: No nodes found in the graph for filter { manual_bit_width_selection .filter .__dict__ } "
155185 f"to change their bit width to { manual_bit_width_selection .bit_width } ." )
156186 for n in filtered_nodes :
157- if n .get_node_weights_attributes () is False :
158- Logger .critical (f'The requested attribute to change the bit width for { n } is not existing.' )
159- # check if a manual configuration exists for this node
160- if n in unit_nodes_to_change_bit_width :
161- Logger .info (
162- f"Node { n } has an existing manual bit width configuration of { unit_nodes_to_change_bit_width .get (n )} . A new manual configuration request of { manual_bit_width_selection .bit_width } has been received, and the previous value is being overridden." )
163- if isinstance (manual_bit_width_selection_list , ManualBitWidthSelection ):
187+ if type (manual_bit_width_selection ) is ManualBitWidthSelection :
188+ # check if a manual configuration exists for this node
189+ if n in unit_nodes_to_change_bit_width :
190+ Logger .info (
191+ f"Node { n } has an existing manual bit width configuration of { unit_nodes_to_change_bit_width .get (n )} ."
192+ f"A new manual configuration request of { manual_bit_width_selection .bit_width } has been received, and the previous value is being overridden." )
164193 unit_nodes_to_change_bit_width .update ({n : manual_bit_width_selection .bit_width })
165- elif isinstance (manual_bit_width_selection_list , ManualWeightsBitWidthSelection ):
194+ elif type (manual_bit_width_selection ) is ManualWeightsBitWidthSelection :
195+ if len (n .get_node_weights_attributes ()) == 0 :
196+ Logger .critical (f'The requested attribute to change the bit width for { n } is not existing.' )
197+
198+ if n in unit_nodes_to_change_bit_width :
199+ if unit_nodes_to_change_bit_width [n ][1 ] == manual_bit_width_selection .attr :
200+ Logger .info (
201+ f"Node { n } has an existing manual bit width configuration of { unit_nodes_to_change_bit_width .get (n )} ."
202+ f"A new manual configuration request of { manual_bit_width_selection .bit_width } has been received, and the previous value is being overridden." )
166203 unit_nodes_to_change_bit_width .update ({n : [manual_bit_width_selection .bit_width , manual_bit_width_selection .attr ]})
167-
204+ else :
205+ Logger .critical (f'The type of manual_bit_width_selection_list must be ManualBitWidthSelection or ManualWeightsBitWidthSelection.' )
168206 return unit_nodes_to_change_bit_width
0 commit comments