2020from model_compression_toolkit .core .common .matchers .node_matcher import BaseNodeMatcher
2121from model_compression_toolkit .logger import Logger
2222
23+ from model_compression_toolkit .core .common .graph .base_node import WeightAttrT
2324
2425@dataclass
2526class ManualBitWidthSelection :
@@ -40,68 +41,22 @@ class ManualWeightsBitWidthSelection(ManualBitWidthSelection):
4041
4142 Attributes:
4243 filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
43- attr (str): The attribute used to select nodes for bit width manipulation. [KERNEL_ATTR|BIAS_ATTR]
44+ bit_width (int): The bit width to be applied to the selected nodes.
45+ attr (str): The filtered node's attributes to apply bit-width manipulation to.
4446 """
45- attr : str
46-
47- def _expand_to_list_core (
48- filters : Union [List [BaseNodeMatcher ]],
49- vals : Union [List [any ], any ]):
50- vals = [vals ] if not isinstance (vals , list ) else vals
51- if len (vals ) > 1 and len (vals ) != len (filters ):
52- Logger .critical (f"Configuration Error: The number of provided bit_width values { len (vals )} "
53- f"must match the number of filters { len (filters )} , or a single bit_width value "
54- f"should be provided for all filters." )
55- elif len (vals ) == 1 and len (filters ) > 1 :
56- vals = [vals [0 ] for f in filters ]
57- return vals
58-
59- def _expand_to_list_filter_and_bit_width (
60- filters : Union [List [BaseNodeMatcher ]],
61- bit_widths : Union [List [int ], int ],
62- attrs : Union [List [str ], str ] = None ):
63- filters = [filters ] if not isinstance (filters , list ) else filters
64- bit_widths = _expand_to_list_core (filters , bit_widths )
65- attrs = _expand_to_list_core (filters , attrs )
66-
67- return attrs , bit_widths , filters
68-
69- def _make_nodes_to_change_bit_width (graph , manual_bit_width_selection_list ):
70- unit_nodes_to_change_bit_width = {}
71- for manual_bit_width_selection in manual_bit_width_selection_list :
72- filtered_nodes = graph .filter (manual_bit_width_selection .filter )
73- if len (filtered_nodes ) == 0 :
74- Logger .critical (
75- f"Node Filtering Error: No nodes found in the graph for filter { manual_bit_width_selection .filter .__dict__ } "
76- f"to change their bit width to { manual_bit_width_selection .bit_width } ." )
77- for n in filtered_nodes :
78- # check if a manual configuration exists for this node
79- if n in unit_nodes_to_change_bit_width :
80- Logger .info (
81- f"Node { n } has an existing manual bit width configuration of { unit_nodes_to_change_bit_width .get (n )} . A new manual configuration request of { manual_bit_width_selection .bit_width } has been received, and the previous value is being overridden." )
82- if isinstance (manual_bit_width_selection_list , ManualBitWidthSelection ):
83- unit_nodes_to_change_bit_width .update ({n : manual_bit_width_selection .bit_width })
84- elif isinstance (manual_bit_width_selection_list , ManualWeightsBitWidthSelection ):
85- unit_nodes_to_change_bit_width .update ({n : [manual_bit_width_selection .bit_width , manual_bit_width_selection .attr ]})
86-
87-
88- return unit_nodes_to_change_bit_width
89-
90- @dataclass
91- class NodesToChangeBitWidth :
92- activation_nodes_to_change_bit_width : Dict
93- weights_nodes_to_change_bit_width : Dict
47+ attr : WeightAttrT
9448
9549@dataclass
9650class BitWidthConfig :
9751 """
9852 Class to manage manual bit-width configurations.
9953
10054 Attributes:
101- manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects defining manual bit-width configurations.
55+ manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects for activation defining manual bit-width configurations.
56+ manual_activation_bit_width_selection_list (List[ManualWeightsBitWidthSelection]): A list of ManualWeightsBitWidthSelection for weights objects defining manual bit-width configurations.
10257 """
10358 manual_activation_bit_width_selection_list : List [ManualBitWidthSelection ] = field (default_factory = list )
104- manual_weights_bit_width_selection_list : List [ManualBitWidthSelection ] = field (default_factory = list )
59+ manual_weights_bit_width_selection_list : List [ManualWeightsBitWidthSelection ] = field (default_factory = list )
10560
10661 def set_manual_activation_bit_width (self ,
10762 filters : Union [List [BaseNodeMatcher ], BaseNodeMatcher ],
@@ -114,7 +69,7 @@ def set_manual_activation_bit_width(self,
11469 bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
11570 If a single value is given it will be applied to all the filters
11671 """
117- _ , bit_widths , filters = _expand_to_list_filter_and_bit_width (filters , bit_widths )
72+ _ , bit_widths , filters = self . _expand_to_list_filter_and_bit_width (filters , bit_widths )
11873 for bit_width , filter in zip (bit_widths , filters ):
11974 self .manual_activation_bit_width_selection_list += [ManualBitWidthSelection (filter , bit_width )]
12075
@@ -128,15 +83,15 @@ def set_manual_weights_bit_width(self,
12883
12984 Args:
13085 filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
131- bit_widths (Union[List[int], int]): The bit widths for kernel to be applied to the selected nodes.
132- attrs (Union[List[str], str]): The attributes used to select nodes for bit width manipulation. [KERNEL_ATTR|BIAS_ATTR]
86+ bit_widths (Union[List[int], int]): The bit widths for specified by attrs to be applied to the selected nodes.
87+ attrs (Union[List[str], str]): The filtered node's attributes to apply bit- width manipulation to.
13388 If a single value is given it will be applied to all the filters
13489 """
135- attrs , bit_widths , filters = _expand_to_list_filter_and_bit_width (filters , bit_widths , attrs )
90+ attrs , bit_widths , filters = self . _expand_to_list_filter_and_bit_width (filters , bit_widths , attrs )
13691 for attr , bit_width , filter in zip (attrs , bit_widths , filters ):
13792 self .manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection (filter , bit_width , attr )]
13893
139- def get_nodes_to_manipulate_bit_widths (self , graph : Graph ) -> NodesToChangeBitWidth :
94+ def get_nodes_to_manipulate_activation_bit_widths (self , graph : Graph ) -> Dict :
14095 """
14196 Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
14297
@@ -146,8 +101,68 @@ def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> NodesToChangeBitWi
146101 Returns:
147102 Dict: A dictionary mapping nodes to their new bit-widths.
148103 """
149- activation_nodes_to_change_bit_width = _make_nodes_to_change_bit_width (graph , self .manual_activation_bit_width_selection_list )
150- weights_nodes_to_change_bit_width = _make_nodes_to_change_bit_width (graph , self .manual_weights_bit_width_selection_list )
104+ activation_nodes_to_change_bit_width = self ._construct_node_to_new_bit_mapping (graph , self .manual_activation_bit_width_selection_list )
105+
106+ return activation_nodes_to_change_bit_width
107+
108+ def get_nodes_to_manipulate_weights_bit_widths (self , graph : Graph ) -> Dict :
109+ """
110+ Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
151111
152- nodes_to_change_bit_width = NodesToChangeBitWidth (activation_nodes_to_change_bit_width , weights_nodes_to_change_bit_width )
153- return nodes_to_change_bit_width
112+ Args:
113+ graph (Graph): The graph containing the nodes to be filtered and manipulated.
114+
115+ Returns:
116+ Dict: A dictionary mapping nodes to their new bit-widths.
117+ """
118+ weights_nodes_to_change_bit_width = self ._construct_node_to_new_bit_mapping (graph , self .manual_weights_bit_width_selection_list )
119+
120+ return weights_nodes_to_change_bit_width
121+
122+
123+
124+ def _expand_to_list_core (
125+ self ,
126+ filters : Union [List [BaseNodeMatcher ]],
127+ vals : Union [List [any ], any ]):
128+ vals = [vals ] if not isinstance (vals , list ) else vals
129+ if len (vals ) > 1 and len (vals ) != len (filters ):
130+ Logger .critical (f"Configuration Error: The number of provided bit_width values { len (vals )} "
131+ f"must match the number of filters { len (filters )} , or a single bit_width value "
132+ f"should be provided for all filters." )
133+ elif len (vals ) == 1 and len (filters ) > 1 :
134+ vals = [vals [0 ] for f in filters ]
135+ return vals
136+
137+ def _expand_to_list_filter_and_bit_width (
138+ self ,
139+ filters : Union [List [BaseNodeMatcher ]],
140+ bit_widths : Union [List [int ], int ],
141+ attrs : Union [List [str ], str ] = None ):
142+ filters = [filters ] if not isinstance (filters , list ) else filters
143+ bit_widths = self ._expand_to_list_core (filters , bit_widths )
144+ attrs = self ._expand_to_list_core (filters , attrs )
145+
146+ return attrs , bit_widths , filters
147+
148+ def _construct_node_to_new_bit_mapping (self , graph , manual_bit_width_selection_list ):
149+ unit_nodes_to_change_bit_width = {}
150+ for manual_bit_width_selection in manual_bit_width_selection_list :
151+ filtered_nodes = graph .filter (manual_bit_width_selection .filter )
152+ if len (filtered_nodes ) == 0 :
153+ Logger .critical (
154+ f"Node Filtering Error: No nodes found in the graph for filter { manual_bit_width_selection .filter .__dict__ } "
155+ f"to change their bit width to { manual_bit_width_selection .bit_width } ." )
156+ for n in filtered_nodes :
157+ if n .get_node_weights_attributes () is False :
158+ Logger .critical (f'The requested attribute to change the bit width for { n } is not existing.' )
159+ # check if a manual configuration exists for this node
160+ if n in unit_nodes_to_change_bit_width :
161+ Logger .info (
162+ f"Node { n } has an existing manual bit width configuration of { unit_nodes_to_change_bit_width .get (n )} . A new manual configuration request of { manual_bit_width_selection .bit_width } has been received, and the previous value is being overridden." )
163+ if isinstance (manual_bit_width_selection_list , ManualBitWidthSelection ):
164+ unit_nodes_to_change_bit_width .update ({n : manual_bit_width_selection .bit_width })
165+ elif isinstance (manual_bit_width_selection_list , ManualWeightsBitWidthSelection ):
166+ unit_nodes_to_change_bit_width .update ({n : [manual_bit_width_selection .bit_width , manual_bit_width_selection .attr ]})
167+
168+ return unit_nodes_to_change_bit_width
0 commit comments