1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414# ==============================================================================
15-
16-
1715import copy
18- from typing import List , Tuple , Optional
16+ from typing import List , Tuple , Dict , Optional
1917
20- from mct_quantizers .common .constants import ACTIVATION_N_BITS
18+ from mct_quantizers .common .constants import WEIGHTS_N_BITS , ACTIVATION_N_BITS
19+ from model_compression_toolkit .constants import WEIGHTS , ACTIVATION
2120from model_compression_toolkit .core .common import BaseNode
2221from model_compression_toolkit .core .common .quantization .bit_width_config import BitWidthConfig
2322from model_compression_toolkit .logger import Logger
3029from model_compression_toolkit .core .common .quantization .quantization_config import QuantizationConfig , \
3130 QuantizationErrorMethod
3231from model_compression_toolkit .core .common .quantization .quantization_params_fn_selection import \
33- get_activation_quantization_params_fn , get_weights_quantization_params_fn
34- from model_compression_toolkit .core .common .quantization .quantization_fn_selection import \
35- get_weights_quantization_fn
32+ get_activation_quantization_params_fn
3633from model_compression_toolkit .target_platform_capabilities .schema .schema_functions import max_input_activation_n_bits
3734from model_compression_toolkit .target_platform_capabilities .schema .mct_current_schema import OpQuantizationConfig , \
3835 QuantizationConfigOptions
3936from model_compression_toolkit .target_platform_capabilities .targetplatform2framework .framework_quantization_capabilities import \
4037 FrameworkQuantizationCapabilities
38+ from model_compression_toolkit .core .common .graph .base_node import WeightAttrT
4139
4240
4341def set_quantization_configuration_to_graph (graph : Graph ,
@@ -66,16 +64,19 @@ def set_quantization_configuration_to_graph(graph: Graph,
6664 Logger .warning ("Using the HMSE error method for weights quantization parameters search. "
6765 "Note: This method may significantly increase runtime during the parameter search process." )
6866
69- nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config .get_nodes_to_manipulate_activation_bit_widths (graph )
67+ nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config .get_nodes_to_manipulate_activation_bit_widths (graph )
68+ nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config .get_nodes_to_manipulate_weights_bit_widths (graph )
7069
7170 for n in graph .nodes :
71+ manual_bit_width_override = {ACTIVATION : nodes_to_manipulate_activation_bit_widths .get (n ),
72+ WEIGHTS : nodes_to_manipulate_weights_bit_widths .get (n )}
7273 set_quantization_configs_to_node (node = n ,
7374 graph = graph ,
7475 quant_config = quant_config ,
7576 fw_info = graph .fw_info ,
7677 fqc = graph .fqc ,
7778 mixed_precision_enable = mixed_precision_enable ,
78- manual_bit_width_override = nodes_to_manipulate_bit_widths . get ( n ) )
79+ manual_bit_width_override = manual_bit_width_override )
7980 return graph
8081
8182
@@ -151,7 +152,7 @@ def set_quantization_configs_to_node(node: BaseNode,
151152 fw_info : FrameworkInfo ,
152153 fqc : FrameworkQuantizationCapabilities ,
153154 mixed_precision_enable : bool = False ,
154- manual_bit_width_override : Optional [int ] = None ):
155+ manual_bit_width_override : Optional [Dict ] = None ):
155156 """
156157 Create and set quantization configurations to a node (for both weights and activation).
157158
@@ -167,8 +168,11 @@ def set_quantization_configs_to_node(node: BaseNode,
167168 node_qc_options = node .get_qco (fqc )
168169 base_config , node_qc_options_list = filter_node_qco_by_graph (node , fqc , graph , node_qc_options )
169170
170- # If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override,
171+ # If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation and weights bits equal to manual_bit_width_override,
171172 # and update base_config accordingly.
173+ if manual_bit_width_override is None :
174+ manual_bit_width_override = {ACTIVATION : None , WEIGHTS : None }
175+
172176 base_config , node_qc_options_list = filter_qc_options_with_manual_bit_width (
173177 node = node ,
174178 node_qc_options_list = node_qc_options_list ,
@@ -322,7 +326,7 @@ def filter_qc_options_with_manual_bit_width(
322326 node : BaseNode ,
323327 node_qc_options_list : List [OpQuantizationConfig ],
324328 base_config : OpQuantizationConfig ,
325- manual_bit_width_override : Optional [int ],
329+ manual_bit_width_override : Optional [Dict ],
326330 mixed_precision_enable : bool ) -> Tuple [OpQuantizationConfig , List [OpQuantizationConfig ]]:
327331 """
328332 Update the quantization configurations for a node, allowing manual bit-width overrides if specified.
@@ -331,36 +335,169 @@ def filter_qc_options_with_manual_bit_width(
331335 node (BaseNode): A node to set quantization configuration candidates to.
332336 node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
333337 base_config (OpQuantizationConfig): Base quantization config for the node.
334- manual_bit_width_override (Optional[int ]): Specifies a custom bit-width to override the node's activation bit-width.
338+ manual_bit_width_override (Optional[Dict ]): Specifies a custom bit-width to override the node's activation and weights bit-width.
335339 mixed_precision_enable (bool): Whether mixed precision is enabled.
336340
337341 Returns:
338342 Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
339343 """
340- if manual_bit_width_override is None :
344+ base_config , node_qc_options_list = filter_activation_qc_options_with_manual_bit_width (node ,
345+ node_qc_options_list ,
346+ base_config ,
347+ manual_bit_width_override .get (ACTIVATION ),
348+ mixed_precision_enable )
349+
350+ base_config , node_qc_options_list = filter_weights_qc_options_with_manual_bit_width (node ,
351+ node_qc_options_list ,
352+ base_config ,
353+ manual_bit_width_override .get (WEIGHTS ),
354+ mixed_precision_enable )
355+ return base_config , node_qc_options_list
356+
357+
358+ def filter_activation_qc_options_with_manual_bit_width (
359+ node : BaseNode ,
360+ node_qc_options_list : List [OpQuantizationConfig ],
361+ base_config : OpQuantizationConfig ,
362+ activation_manual_bit_width_override : Optional [int ],
363+ mixed_precision_enable : bool ) -> Tuple [OpQuantizationConfig , List [OpQuantizationConfig ]]:
364+ """
365+ Update the activation quantization configurations for a node, allowing manual bit-width overrides if specified.
366+
367+ Args:
368+ node (BaseNode): A node to set quantization configuration candidates to.
369+ node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
370+ base_config (OpQuantizationConfig): Base quantization config for the node.
371+ activation_manual_bit_width_override (Optional[Dict]): Specifies a custom bit-width to override the node's activation bit-width.
372+ mixed_precision_enable (bool): Whether mixed precision is enabled.
373+
374+ Returns:
375+ Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
376+ """
377+ if activation_manual_bit_width_override is None :
341378 return base_config , node_qc_options_list
342379
343- # Filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override .
380+ # Filter node_qc_options_list to retain only the options with activation bits equal to activation_manual_bit_width_override .
344381 node_qc_options_list = [op_cfg for op_cfg in node_qc_options_list if
345- manual_bit_width_override == op_cfg .activation_n_bits ]
346-
382+ activation_manual_bit_width_override == op_cfg .activation_n_bits ]
347383 if len (node_qc_options_list ) == 0 :
348- Logger .critical (f"Manually selected activation bit-width { manual_bit_width_override } is invalid for node { node } ." )
384+ Logger .critical (f"Manually selected activation bit-width { activation_manual_bit_width_override } is invalid for node { node } ." )
349385 else :
350386 # Update the base_config to one of the values from the filtered node_qc_options_list.
351- # First, check if a configuration similar to the original base_config but with activation bits equal to manual_bit_width_override exists.
387+ # First, check if a configuration similar to the original base_config but with activation bits equal to activation_manual_bit_width_override exists.
352388 # If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
353- Logger .info (f"Setting node { node } bit-width to manually selected bit-width: { manual_bit_width_override } bits." )
354- updated_base_config = base_config .clone_and_edit ({ACTIVATION_N_BITS , manual_bit_width_override })
389+ Logger .info (f"Setting node { node } bit-width to manually selected bit-width: { activation_manual_bit_width_override } bits." )
390+ updated_base_config = base_config .clone_and_edit ({ACTIVATION_N_BITS , activation_manual_bit_width_override })
355391 if updated_base_config in node_qc_options_list :
356- # If a base_config with the specified manual_bit_width_override exists in the node_qc_options_list,
392+ # If a base_config with the specified activation_manual_bit_width_override exists in the node_qc_options_list,
357393 # point the base_config to this option.
358394 base_config = node_qc_options_list [node_qc_options_list .index (updated_base_config )]
359395 else :
360396 # Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
361397 base_config = node_qc_options_list [0 ]
362398 if len (node_qc_options_list ) > 0 and not mixed_precision_enable :
363399 Logger .info (
364- f"Request received to select { manual_bit_width_override } activation bits. However, the base configuration for layer type { node .type } is missing in the node_qc_options_list."
365- f" Overriding base_config with an option that uses { manual_bit_width_override } bit activations." ) # pragma: no cover
366- return base_config , node_qc_options_list
400+ f"Request received to select { activation_manual_bit_width_override } activation bits. However, the base configuration for layer type { node .type } is missing in the node_qc_options_list."
401+ f" Overriding base_config with an option that uses { activation_manual_bit_width_override } bit activations." ) # pragma: no cover
402+
403+ return base_config , node_qc_options_list
404+
405+
406+ def filter_weights_qc_options_with_manual_bit_width (
407+ node : BaseNode ,
408+ node_qc_options_list : List [OpQuantizationConfig ],
409+ base_config : OpQuantizationConfig ,
410+ weights_manual_bit_width_override : Optional [Tuple [int , WeightAttrT ]],
411+ mixed_precision_enable : bool ) -> Tuple [OpQuantizationConfig , List [OpQuantizationConfig ]]:
412+ """
413+ Update the weights quantization configurations for a node, allowing manual bit-width overrides if specified.
414+
415+ Args:
416+ node (BaseNode): A node to set quantization configuration candidates to.
417+ node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
418+ base_config (OpQuantizationConfig): Base quantization config for the node.
419+ weights_manual_bit_width_override (Optional[[int, WeightAttrT]]): Specifies a custom bit-width to override the node's weights bit-width.
420+ mixed_precision_enable (bool): Whether mixed precision is enabled.
421+
422+ Returns:
423+ Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
424+ """
425+ if not weights_manual_bit_width_override :
426+ return base_config , node_qc_options_list
427+
428+ # Filter node_qc_options_list to retain only the options with weights bits equal to weights_manual_bit_width_override.
429+ node_qc_options_weights_list = _filter_options (node_qc_options_list , weights_manual_bit_width_override )
430+
431+ if len (node_qc_options_weights_list ) == 0 :
432+ Logger .critical (f"Manually selected weights bit-width { weights_manual_bit_width_override } is invalid for node { node } ." )
433+ else :
434+ # Update the base_config to one of the values from the filtered node_qc_options_list.
435+ # First, check if a configuration similar to the original base_config but with weights bits equal to weights_manual_bit_width_override exists.
436+ # If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
437+ updated_base_config = base_config .clone_and_edit ()
438+
439+ for bit_width , attr in weights_manual_bit_width_override :
440+ Logger .info (f"Setting node { node } bit-width to manually selected { attr } bit-width: { bit_width } bits." )
441+ updated_base_config = updated_base_config .clone_and_edit (attr_to_edit = {attr : {WEIGHTS_N_BITS : bit_width }})
442+
443+ if updated_base_config in node_qc_options_weights_list :
444+ # If a base_config with the specified weights_manual_bit_width_override exists in the node_qc_options_list,
445+ # point the base_config to this option.
446+ base_config = node_qc_options_weights_list [node_qc_options_weights_list .index (updated_base_config )]
447+ else :
448+ # Choose a different configuration from node_qc_options_list. If multiple options exist, issue a warning.
449+ base_config = node_qc_options_weights_list [0 ]
450+ if len (node_qc_options_weights_list ) > 0 and not mixed_precision_enable :
451+ Logger .info (
452+ f"Request received to select weights bit-widths { weights_manual_bit_width_override } ."
453+ f"However, the base configuration for layer type { node .type } is missing in the node_qc_options_list."
454+ f" Overriding base_config with an option that uses manually selected weights bit-widths { weights_manual_bit_width_override } ." ) # pragma: no cover
455+
456+ return base_config , node_qc_options_weights_list
457+
458+
459+ def _is_valid_option (
460+ op_cfg : OpQuantizationConfig ,
461+ attr : WeightAttrT ,
462+ bit_width : int ) -> bool :
463+ """
464+ Judge whether the specified option is valid based on the specified attribute and bit width.
465+
466+ Args:
467+ op_cfg (OpQuantizationConfig): The quantization configuration to be judged.
468+ attr (WeightAttrT): The filtered node's attributes to apply bit-width manipulation to.
469+ bit_width (int): The bit width to be applied to the selected nodes.
470+
471+ Returns:
472+ Result to judge whether the specified option is valid based on the specified attribute and bit width
473+ """
474+ weights_attrs = op_cfg .attr_weights_configs_mapping .keys ()
475+
476+ if attr not in weights_attrs :
477+ return False
478+
479+ weights_n_bits = op_cfg .attr_weights_configs_mapping [attr ].weights_n_bits
480+ return weights_n_bits == bit_width
481+
482+
483+ def _filter_options (
484+ node_qc_options_list : List [OpQuantizationConfig ],
485+ weights_manual_bit_width_override : Tuple [int , WeightAttrT ]) -> List [OpQuantizationConfig ]:
486+ """
487+ Filter the options based on the specified bit width and attribute.
488+
489+ Args:
490+ node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
491+ weights_manual_bit_width_override (Tuple[int, WeightAttrT])): Specifies a custom bit-width to override the node's weights bit-width.
492+
493+ Returns:
494+ List[OpQuantizationConfig]: Filtered the options based on the specified bit width and attribute.
495+ """
496+ filtered_options = []
497+
498+ for bit_width , attr in weights_manual_bit_width_override :
499+ for op_cfg in node_qc_options_list :
500+ if _is_valid_option (op_cfg , attr , bit_width ):
501+ filtered_options .append (op_cfg )
502+
503+ return filtered_options
0 commit comments