Skip to content

Commit 18df904

Browse files
committed
add typehint and description.
1 parent c8885d8 commit 18df904

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

model_compression_toolkit/core/common/quantization/set_node_quantization_config.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def filter_weights_qc_options_with_manual_bit_width(
424424
return base_config, node_qc_options_list
425425

426426
# Filter node_qc_options_list to retain only the options with weights bits equal to weights_manual_bit_width_override.
427-
node_qc_options_weights_list = filter_options(node_qc_options_list, weights_manual_bit_width_override)
427+
node_qc_options_weights_list = _filter_options(node_qc_options_list, weights_manual_bit_width_override)
428428

429429
if len(node_qc_options_weights_list) == 0:
430430
Logger.critical(f"Manually selected weights bit-width {weights_manual_bit_width_override} is invalid for node {node}.")
@@ -454,8 +454,21 @@ def filter_weights_qc_options_with_manual_bit_width(
454454
return base_config, node_qc_options_weights_list
455455

456456

457-
def is_valid_option(op_cfg, attr, bit_width):
458-
# Check if the given option is valid based on the specified attribute and bit width.
457+
def _is_valid_option(
458+
op_cfg: OpQuantizationConfig,
459+
attr: WeightAttrT,
460+
bit_width: int) -> bool:
461+
"""
462+
Judge whether the specified option is valid based on the specified attribute and bit width.
463+
464+
Args:
465+
op_cfg (OpQuantizationConfig): The quantization configuration to be judged.
466+
attr (WeightAttrT): The filtered node's attributes to apply bit-width manipulation to.
467+
bit_width (int): The bit width to be applied to the selected nodes.
468+
469+
Returns:
470+
Result to judge whether the specified option is valid based on the specified attribute and bit width
471+
"""
459472
weights_attrs = op_cfg.attr_weights_configs_mapping.keys()
460473

461474
if attr not in weights_attrs:
@@ -465,13 +478,24 @@ def is_valid_option(op_cfg, attr, bit_width):
465478
return weights_n_bits == bit_width
466479

467480

468-
def filter_options(node_qc_options_list, weights_manual_bit_width_override):
469-
# Filter the options based on the specified bit width and attribute.
481+
def _filter_options(
482+
node_qc_options_list: List[OpQuantizationConfig],
483+
weights_manual_bit_width_override: Tuple[int, WeightAttrT]) -> List[OpQuantizationConfig]:
484+
"""
485+
Filter the options based on the specified bit width and attribute.
486+
487+
Args:
488+
node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
489+
weights_manual_bit_width_override (Tuple[int, WeightAttrT])): Specifies a custom bit-width to override the node's weights bit-width.
490+
491+
Returns:
492+
List[OpQuantizationConfig]: Filtered the options based on the specified bit width and attribute.
493+
"""
470494
filtered_options = []
471495

472496
for bit_width, attr in weights_manual_bit_width_override:
473497
for op_cfg in node_qc_options_list:
474-
if is_valid_option(op_cfg, attr, bit_width):
498+
if _is_valid_option(op_cfg, attr, bit_width):
475499
filtered_options.append(op_cfg)
476500

477501
return filtered_options

0 commit comments

Comments
 (0)