@@ -31,7 +31,7 @@ class ManualBitWidthSelection:
3131 bit_width (int): The bit width to be applied to the selected nodes.
3232 """
3333 filter : BaseNodeMatcher
34- bit_width : int = 0
34+ bit_width : int
3535
3636@dataclass
3737class ManualWeightsBitWidthSelection (ManualBitWidthSelection ):
@@ -40,26 +40,31 @@ class ManualWeightsBitWidthSelection(ManualBitWidthSelection):
4040
4141 Attributes:
4242 filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
43- bit_width (int ): The bit width to be applied to the selected nodes.
43+ attr (str ): The attribute used to select nodes for bit width manipulation. [KERNEL_ATTR|BIAS_ATTR]
4444 """
45- kernel_bit_width : int = 0
46- bias_bit_width : int = 0
45+ attr : str
4746
48- def _expand_to_list_filter_and_bit_width (
47+ def _expand_to_list_core (
4948 filters : Union [List [BaseNodeMatcher ]],
50- bit_widths : Union [List [int ], int ]
51- ):
52- filters = [filters ] if not isinstance (filters , list ) else filters
53- bit_widths = [bit_widths ] if not isinstance (bit_widths , list ) else bit_widths
54- if len (bit_widths ) > 1 and len (bit_widths ) != len (filters ):
55- Logger .critical (f"Configuration Error: The number of provided bit_width values { len (bit_widths )} "
49+ vals : Union [List [any ], any ]):
50+ vals = [vals ] if not isinstance (vals , list ) else vals
51+ if len (vals ) > 1 and len (vals ) != len (filters ):
52+ Logger .critical (f"Configuration Error: The number of provided bit_width values { len (vals )} "
5653 f"must match the number of filters { len (filters )} , or a single bit_width value "
5754 f"should be provided for all filters." )
58- elif len (bit_widths ) == 1 and len (filters ) > 1 :
59- bit_widths = [bit_widths [0 ] for f in filters ]
55+ elif len (vals ) == 1 and len (filters ) > 1 :
56+ vals = [vals [0 ] for f in filters ]
57+ return vals
6058
61- return bit_widths , filters
59+ def _expand_to_list_filter_and_bit_width (
60+ filters : Union [List [BaseNodeMatcher ]],
61+ bit_widths : Union [List [int ], int ],
62+ attrs : Union [List [str ], str ] = None ):
63+ filters = [filters ] if not isinstance (filters , list ) else filters
64+ bit_widths = _expand_to_list_core (filters , bit_widths )
65+ attrs = _expand_to_list_core (filters , attrs )
6266
67+ return attrs , bit_widths , filters
6368
6469def _make_nodes_to_change_bit_width (graph , manual_bit_width_selection_list ):
6570 unit_nodes_to_change_bit_width = {}
@@ -77,7 +82,7 @@ def _make_nodes_to_change_bit_width(graph, manual_bit_width_selection_list):
7782 if isinstance (manual_bit_width_selection_list , ManualBitWidthSelection ):
7883 unit_nodes_to_change_bit_width .update ({n : manual_bit_width_selection .bit_width })
7984 elif isinstance (manual_bit_width_selection_list , ManualWeightsBitWidthSelection ):
80- unit_nodes_to_change_bit_width .update ({n : [manual_bit_width_selection .kernel_bit_width , manual_bit_width_selection .bias_bit_width ]})
85+ unit_nodes_to_change_bit_width .update ({n : [manual_bit_width_selection .bit_width , manual_bit_width_selection .attr ]})
8186
8287
8388 return unit_nodes_to_change_bit_width
@@ -109,30 +114,27 @@ def set_manual_activation_bit_width(self,
109114 bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
110115 If a single value is given it will be applied to all the filters
111116 """
112- bit_widths , filters = _expand_to_list_filter_and_bit_width (filters , bit_widths )
117+ _ , bit_widths , filters = _expand_to_list_filter_and_bit_width (filters , bit_widths )
113118 for bit_width , filter in zip (bit_widths , filters ):
114119 self .manual_activation_bit_width_selection_list += [ManualBitWidthSelection (filter , bit_width )]
115120
116121 def set_manual_weights_bit_width (self ,
117122 filters : Union [List [BaseNodeMatcher ], BaseNodeMatcher ],
118- kernel_bit_widths : Union [List [int ], int ],
119- bias_bit_widths : Union [List [int ], int ]
123+ bit_widths : Union [List [int ], int ],
124+ attrs : Union [List [str ], str ]
120125 ):
121126 """
122127 Add a manual bit-width selection for weights to the configuration.
123128
124129 Args:
125130 filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
126- kernel_bit_widths (Union[List[int], int]): The bit widths for kernel to be applied to the selected nodes.
127- bias_bit_widths (Union[List[int ], int ]): The bit widths for bias to be applied to the selected nodes.
131+ bit_widths (Union[List[int], int]): The bit widths for kernel to be applied to the selected nodes.
132+ attrs (Union[List[str ], str ]): The attributes used to select nodes for bit width manipulation. [KERNEL_ATTR|BIAS_ATTR]
128133 If a single value is given it will be applied to all the filters
129134 """
130- kernel_bit_widths , filters = _expand_to_list_filter_and_bit_width (filters , kernel_bit_widths )
131- bias_bit_widths , filters = _expand_to_list_filter_and_bit_width (filters , bias_bit_widths )
132- print ("kernel_bit_widths" , kernel_bit_widths )
133- print ("bias_bit_widths" , bias_bit_widths )
134- for kernel_bit_width , bias_bit_width , filter in zip (kernel_bit_widths , bias_bit_widths , filters ):
135- self .manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection (filter , kernel_bit_width = kernel_bit_width , bias_bit_width = bias_bit_width )]
135+ attrs , bit_widths , filters = _expand_to_list_filter_and_bit_width (filters , bit_widths , attrs )
136+ for attr , bit_width , filter in zip (attrs , bit_widths , filters ):
137+ self .manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection (filter , bit_width , attr )]
136138
137139 def get_nodes_to_manipulate_bit_widths (self , graph : Graph ) -> NodesToChangeBitWidth :
138140 """
0 commit comments