@@ -33,6 +33,54 @@ class ManualBitWidthSelection:
3333 filter : BaseNodeMatcher
3434 bit_width : int
3535
36+ @dataclass
37+ class ManualWeightsBitWidthSelection (ManualBitWidthSelection ):
38+ """
39+ Class to encapsulate the manual weights bit width selection configuration for a specific filter.
40+
41+ Attributes:
42+ filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
43+ bit_width (int): The bit width to be applied to the selected nodes.
44+ """
45+ val : int
46+
47+
48+ def _set_manual_bit_width (filters : Union [List [BaseNodeMatcher ], BaseNodeMatcher ],
49+ bit_widths : Union [List [int ], int ]):
50+
51+ filters = [filters ] if not isinstance (filters , list ) else filters
52+ bit_widths = [bit_widths ] if not isinstance (bit_widths , list ) else bit_widths
53+ if len (bit_widths ) > 1 and len (bit_widths ) != len (filters ):
54+ Logger .critical (f"Configuration Error: The number of provided bit_width values { len (bit_widths )} "
55+ f"must match the number of filters { len (filters )} , or a single bit_width value "
56+ f"should be provided for all filters." )
57+ elif len (bit_widths ) == 1 and len (filters ) > 1 :
58+ bit_widths = [bit_widths [0 ] for f in filters ]
59+
60+ return bit_widths , filters
61+
62+
63+ def _make_nodes_to_change_bit_width (graph , manual_bit_width_selection_list ):
64+ unit_nodes_to_change_bit_width = {}
65+ for manual_bit_width_selection in manual_bit_width_selection_list :
66+ filtered_nodes = graph .filter (manual_bit_width_selection .filter )
67+ if len (filtered_nodes ) == 0 :
68+ Logger .critical (
69+ f"Node Filtering Error: No nodes found in the graph for filter { manual_bit_width_selection .filter .__dict__ } "
70+ f"to change their bit width to { manual_bit_width_selection .bit_width } ." )
71+ for n in filtered_nodes :
72+ # check if a manual configuration exists for this node
73+ if n in unit_nodes_to_change_bit_width :
74+ Logger .info (
75+ f"Node { n } has an existing manual bit width configuration of { unit_nodes_to_change_bit_width .get (n )} . A new manual configuration request of { manual_bit_width_selection .bit_width } has been received, and the previous value is being overridden." )
76+ unit_nodes_to_change_bit_width .update ({n : manual_bit_width_selection .bit_width })
77+
78+ return unit_nodes_to_change_bit_width
79+
80+ @dataclass
81+ class NodesToChangeBitWidth :
82+ activation_nodes_to_change_bit_width : Dict
83+ weights_nodes_to_change_bit_width : Dict
3684
3785@dataclass
3886class BitWidthConfig :
@@ -49,47 +97,33 @@ def set_manual_activation_bit_width(self,
4997 filters : Union [List [BaseNodeMatcher ], BaseNodeMatcher ],
5098 bit_widths : Union [List [int ], int ]):
5199 """
52- Add a manual bit-width selection to the configuration.
100+ Add a manual bit-width selection for activation to the configuration.
53101
54102 Args:
55- filter (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
56- bit_width (Union[List[int], int]): The bit widths to be applied to the selected nodes.
103+ filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
104+ bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
57105 If a single value is given it will be applied to all the filters
58106 """
59- filters = [filters ] if not isinstance (filters , list ) else filters
60- bit_widths = [bit_widths ] if not isinstance (bit_widths , list ) else bit_widths
61- if len (bit_widths ) > 1 and len (bit_widths ) != len (filters ):
62- Logger .critical (f"Configuration Error: The number of provided bit_width values { len (bit_widths )} "
63- f"must match the number of filters { len (filters )} , or a single bit_width value "
64- f"should be provided for all filters." )
65- elif len (bit_widths ) == 1 and len (filters ) > 1 :
66- bit_widths = [bit_widths [0 ] for f in filters ]
107+ bit_widths , filters = _set_manual_bit_width (filters , bit_widths )
67108 for bit_width , filter in zip (bit_widths , filters ):
68109 self .manual_activation_bit_width_selection_list += [ManualBitWidthSelection (filter , bit_width )]
69110
70111 def set_manual_weights_bit_width (self ,
71112 filters : Union [List [BaseNodeMatcher ], BaseNodeMatcher ],
72113 bit_widths : Union [List [int ], int ]):
73114 """
74- Add a manual bit-width selection to the configuration.
115+ Add a manual bit-width selection for weights to the configuration.
75116
76117 Args:
77- filter (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
78- bit_width (Union[List[int], int]): The bit widths to be applied to the selected nodes.
118+ filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
119+ bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
79120 If a single value is given it will be applied to all the filters
80121 """
81- filters = [filters ] if not isinstance (filters , list ) else filters
82- bit_widths = [bit_widths ] if not isinstance (bit_widths , list ) else bit_widths
83- if len (bit_widths ) > 1 and len (bit_widths ) != len (filters ):
84- Logger .critical (f"Configuration Error: The number of provided bit_width values { len (bit_widths )} "
85- f"must match the number of filters { len (filters )} , or a single bit_width value "
86- f"should be provided for all filters." )
87- elif len (bit_widths ) == 1 and len (filters ) > 1 :
88- bit_widths = [bit_widths [0 ] for f in filters ]
122+ bit_widths , filters = _set_manual_bit_width (filters , bit_widths )
89123 for bit_width , filter in zip (bit_widths , filters ):
90124 self .manual_weights_bit_width_selection_list += [ManualBitWidthSelection (filter , bit_width )]
91125
92- def get_nodes_to_manipulate_bit_widths (self , graph : Graph ) -> Dict :
126+ def get_nodes_to_manipulate_bit_widths (self , graph : Graph ) -> NodesToChangeBitWidth :
93127 """
94128 Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
95129
@@ -99,25 +133,9 @@ def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict:
99133 Returns:
100134 Dict: A dictionary mapping nodes to their new bit-widths.
101135 """
102- def make_nodes_to_change_bit_width (manual_bit_width_selection_list ):
103- unit_nodes_to_change_bit_width = {}
104- for manual_bit_width_selection in manual_bit_width_selection_list :
105- filtered_nodes = graph .filter (manual_bit_width_selection .filter )
106- if len (filtered_nodes ) == 0 :
107- Logger .critical (
108- f"Node Filtering Error: No nodes found in the graph for filter { manual_bit_width_selection .filter .__dict__ } "
109- f"to change their bit width to { manual_bit_width_selection .bit_width } ." )
110- for n in filtered_nodes :
111- # check if a manual configuration exists for this node
112- if n in unit_nodes_to_change_bit_width :
113- Logger .info (
114- f"Node { n } has an existing manual bit width configuration of { unit_nodes_to_change_bit_width .get (n )} . A new manual configuration request of { manual_bit_width_selection .bit_width } has been received, and the previous value is being overridden." )
115- unit_nodes_to_change_bit_width .update ({n : manual_bit_width_selection .bit_width })
116-
117- return unit_nodes_to_change_bit_width
118-
119- activation_nodes_to_change_bit_width = make_nodes_to_change_bit_width (self .manual_activation_bit_width_selection_list )
120- weights_nodes_to_change_bit_width = make_nodes_to_change_bit_width (self .manual_weights_bit_width_selection_list )
121-
122- nodes_to_change_bit_width = {ACTIVATION_ATTRIBUTE : activation_nodes_to_change_bit_width , WEIGHTS_ATTRIBUTE : weights_nodes_to_change_bit_width }
136+ activation_nodes_to_change_bit_width = _make_nodes_to_change_bit_width (graph , self .manual_activation_bit_width_selection_list )
137+ weights_nodes_to_change_bit_width = _make_nodes_to_change_bit_width (graph , self .manual_weights_bit_width_selection_list )
138+
139+ #nodes_to_change_bit_width = {ACTIVATION_ATTRIBUTE: activation_nodes_to_change_bit_width, WEIGHTS_ATTRIBUTE: weights_nodes_to_change_bit_width}
140+ nodes_to_change_bit_width = NodesToChangeBitWidth (activation_nodes_to_change_bit_width , weights_nodes_to_change_bit_width )
123141 return nodes_to_change_bit_width
0 commit comments