Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
a81fd1e
pytorch_activation_threshold_search(Debugging and exploration code re…
gouda-youichi Jan 28, 2025
4c66805
Merge branch 'sony:main' into main
gouda-youichi Jan 29, 2025
66a3bd9
pytorch_activation_threshold_search(refine a little)
gouda-youichi Jan 29, 2025
ffe0e6d
Merge branch 'main' of https://github.com/gouda-youichi/model_optimiz…
gouda-youichi Jan 29, 2025
746cb57
Fixed ofirgo-san's review comment.
gouda-youichi Jan 31, 2025
8128264
Fixed ofirgo-san's review comment.[2]
gouda-youichi Jan 31, 2025
67b78b5
Fixed ofirgo-san's review comment.[3]
gouda-youichi Jan 31, 2025
722380b
PR comment correction (in progress)
gouda-youichi Feb 3, 2025
5564182
[DEBUG]about printout layer_type
gouda-youichi Feb 3, 2025
d5481a0
PR comment correction(fixed)
gouda-youichi Feb 4, 2025
2f6e5fc
Merge branch 'sony:main' into main
gouda-youichi Feb 5, 2025
947f56c
Merge branch 'main' of https://github.com/gouda-youichi/model_optimiz…
gouda-youichi Feb 27, 2025
0926861
Merge branch 'main' of https://github.com/gouda-youichi/model_optimiz…
gouda-youichi Feb 27, 2025
7d4ba38
add set_manual_weights_bit_width to bit_width_config.py
gouda-youichi Feb 28, 2025
6eec34b
modified get_nodes_to_manipulate_bit_widths and etc
gouda-youichi Feb 28, 2025
6c3c262
All implementations completed. Not yet tested.
gouda-youichi Mar 4, 2025
50af865
Merge branch 'sony:main' into adding_weights_manual_selection_bitwidth
gouda-youichi Mar 6, 2025
9411042
revert original setting for tpc.
gouda-youichi Mar 6, 2025
b283c81
add test for weights_manual_selection_bitwidth
gouda-youichi Mar 6, 2025
a776f4c
modified test_manual_weights_bitwidth_selection.py
gouda-youichi Mar 11, 2025
424dde6
revert tpc.py
gouda-youichi Mar 12, 2025
b07fc5d
correcting according to the feedback comments
gouda-youichi Mar 12, 2025
02dce0c
correcting according to the feedback comments
gouda-youichi Mar 12, 2025
464f1da
correcting according to the feedback comments.
gouda-youichi Mar 12, 2025
bc0cf98
Revert unnecessary modifications
gouda-youichi Mar 12, 2025
70db8e8
Revert set_node_quantization_config.py
gouda-youichi Mar 12, 2025
c51a200
fixing for manual weights selection bitwidth(kernel,bias)
gouda-youichi Mar 13, 2025
37d2dcf
fixing for manual weights selection bitwidth(kernel,bias)
gouda-youichi Mar 13, 2025
d411b8d
fixed PR-FB for manual weights selection bitwidth
gouda-youichi Mar 14, 2025
5b88fe2
rename test script
gouda-youichi Mar 14, 2025
1f30c99
Correcting comments on pull requests
gouda-youichi Mar 19, 2025
8ee3002
Correcting comments on pull requests_2
gouda-youichi Mar 19, 2025
2a3c40c
Fixed comments on pull requests.
gouda-youichi Mar 19, 2025
1af4e11
Merge branch 'sony:main' into adding_weights_manual_selection_bitwidth
kawakami-masaki0 Mar 24, 2025
7ba5a17
fixing for manual weights selection bitwidth
kawakami-masaki0 Mar 24, 2025
afe57b4
add __init__.py
kawakami-masaki0 Mar 24, 2025
c49ba74
fix imported module
kawakami-masaki0 Mar 24, 2025
e7737d6
fixed to check for integer or string
kawakami-masaki0 Mar 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions model_compression_toolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
# In Mixed-Precision, a node can have multiple candidates for weights and activations quantization configuration.
# In order to display a single view of a node (for example, for logging in TensorBoard) we need to track the attributes
# that are shared among different candidates:
WEIGHTS_ATTRIBUTE = 'weights'
ACTIVATION_ATTRIBUTE = 'activation'
WEIGHTS_NBITS_ATTRIBUTE = 'weights_n_bits'
CORRECTED_BIAS_ATTRIBUTE = 'corrected_bias'
ACTIVATION_N_BITS_ATTRIBUTE = 'activation_n_bits'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def apply(self, input_node_object: BaseNode) -> bool:
if input_node_object.is_match_type(self.operation):
return True


class NodeFrameworkAttrMatcher(node_matcher.BaseNodeMatcher):
"""
Class NodeFrameworkAttrMatcher to check if a node's attribute has a specific value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dataclasses import dataclass, field
from typing import List, Union, Dict

from model_compression_toolkit.constants import WEIGHTS_ATTRIBUTE, ACTIVATION_ATTRIBUTE
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher
from model_compression_toolkit.logger import Logger
Expand Down Expand Up @@ -42,6 +43,7 @@ class BitWidthConfig:
manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects defining manual bit-width configurations.
"""
manual_activation_bit_width_selection_list: List[ManualBitWidthSelection] = field(default_factory=list)
manual_weights_bit_width_selection_list: List[ManualBitWidthSelection] = field(default_factory=list)

def set_manual_activation_bit_width(self,
filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
Expand All @@ -65,6 +67,28 @@ def set_manual_activation_bit_width(self,
for bit_width, filter in zip (bit_widths, filters):
self.manual_activation_bit_width_selection_list += [ManualBitWidthSelection(filter, bit_width)]

def set_manual_weights_bit_width(self,
                                 filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
                                 bit_widths: Union[List[int], int]):
    """
    Add manual weights bit-width selections to the configuration.

    Every filter is paired with a bit-width and appended, as a ManualBitWidthSelection,
    to manual_weights_bit_width_selection_list.

    Args:
        filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for weights bit-width manipulation.
        bit_widths (Union[List[int], int]): The weights bit widths to be applied to the selected nodes.
            If a single value is given it will be applied to all the filters.
    """
    # Normalize both arguments to lists so a scalar and a list are handled uniformly.
    filters = filters if isinstance(filters, list) else [filters]
    bit_widths = bit_widths if isinstance(bit_widths, list) else [bit_widths]

    if len(bit_widths) == 1:
        # A single bit-width is shared by every filter (builds a new list, the caller's list is untouched).
        bit_widths = bit_widths * len(filters)
    elif len(bit_widths) != len(filters):
        # Also catches an empty bit_widths list, which zip() below would otherwise silently ignore.
        Logger.critical(f"Configuration Error: The number of provided bit_width values {len(bit_widths)} "
                        f"must match the number of filters {len(filters)}, or a single bit_width value "
                        f"should be provided for all filters.")

    self.manual_weights_bit_width_selection_list.extend(
        ManualBitWidthSelection(node_filter, bit_width)
        for bit_width, node_filter in zip(bit_widths, filters))

def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict:
"""
Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
Expand All @@ -75,16 +99,25 @@ def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict:
Returns:
Dict: A dictionary mapping nodes to their new bit-widths.
"""
nodes_to_change_bit_width = {}
for manual_bit_width_selection in self.manual_activation_bit_width_selection_list:
filtered_nodes = graph.filter(manual_bit_width_selection.filter)
if len(filtered_nodes) == 0:
Logger.critical(f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
for n in filtered_nodes:
# check if a manual configuration exists for this node
if n in nodes_to_change_bit_width:
Logger.info(
f"Node {n} has an existing manual bit width configuration of {nodes_to_change_bit_width.get(n)}. A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
def make_nodes_to_change_bit_width(manual_bit_width_selection_list):
unit_nodes_to_change_bit_width = {}
for manual_bit_width_selection in manual_bit_width_selection_list:
filtered_nodes = graph.filter(manual_bit_width_selection.filter)
if len(filtered_nodes) == 0:
Logger.critical(
f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
for n in filtered_nodes:
# check if a manual configuration exists for this node
if n in unit_nodes_to_change_bit_width:
Logger.info(
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}. A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})

return unit_nodes_to_change_bit_width

activation_nodes_to_change_bit_width = make_nodes_to_change_bit_width(self.manual_activation_bit_width_selection_list)
weights_nodes_to_change_bit_width = make_nodes_to_change_bit_width(self.manual_weights_bit_width_selection_list)

nodes_to_change_bit_width = {ACTIVATION_ATTRIBUTE: activation_nodes_to_change_bit_width, WEIGHTS_ATTRIBUTE: weights_nodes_to_change_bit_width}
return nodes_to_change_bit_width
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@


import copy
from typing import List, Tuple, Optional
from typing import List, Tuple, Dict, Optional

from mct_quantizers.common.constants import ACTIVATION_N_BITS
from mct_quantizers.common.constants import WEIGHTS_N_BITS, ACTIVATION_N_BITS
from model_compression_toolkit.constants import WEIGHTS_ATTRIBUTE, ACTIVATION_ATTRIBUTE
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
from model_compression_toolkit.logger import Logger
Expand All @@ -32,6 +33,7 @@
get_activation_quantization_params_fn, get_weights_quantization_params_fn
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
get_weights_quantization_fn
from model_compression_toolkit.target_platform_capabilities.constants import BIAS, KERNEL_ATTR
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
QuantizationConfigOptions
Expand Down Expand Up @@ -68,13 +70,15 @@ def set_quantization_configuration_to_graph(graph: Graph,
nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_bit_widths(graph)

for n in graph.nodes:
manual_bit_width_override = {ACTIVATION_ATTRIBUTE: nodes_to_manipulate_bit_widths.get(ACTIVATION_ATTRIBUTE).get(n),
WEIGHTS_ATTRIBUTE: nodes_to_manipulate_bit_widths.get(WEIGHTS_ATTRIBUTE).get(n)}
set_quantization_configs_to_node(node=n,
graph=graph,
quant_config=quant_config,
fw_info=graph.fw_info,
fqc=graph.fqc,
mixed_precision_enable=mixed_precision_enable,
manual_bit_width_override=nodes_to_manipulate_bit_widths.get(n))
manual_bit_width_override=manual_bit_width_override)
return graph


Expand Down Expand Up @@ -150,7 +154,7 @@ def set_quantization_configs_to_node(node: BaseNode,
fw_info: FrameworkInfo,
fqc: FrameworkQuantizationCapabilities,
mixed_precision_enable: bool = False,
manual_bit_width_override: Optional[int] = None):
manual_bit_width_override: Optional[Dict] = None):
"""
Create and set quantization configurations to a node (for both weights and activation).

Expand Down Expand Up @@ -320,7 +324,7 @@ def filter_qc_options_with_manual_bit_width(
node: BaseNode,
node_qc_options_list: List[OpQuantizationConfig],
base_config: OpQuantizationConfig,
manual_bit_width_override: Optional[int],
manual_bit_width_override: Optional[Dict],
mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
"""
Update the quantization configurations for a node, allowing manual bit-width overrides if specified.
Expand All @@ -329,36 +333,98 @@ def filter_qc_options_with_manual_bit_width(
node (BaseNode): A node to set quantization configuration candidates to.
node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
base_config (OpQuantizationConfig): Base quantization config for the node.
manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width.
manual_bit_width_override (Optional[Dict]): Specifies a custom bit-width to override the node's activation and weights bit-width.
mixed_precision_enable (bool): Whether mixed precision is enabled.

Returns:
Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
"""
if manual_bit_width_override is None:
base_config, node_qc_options_list = activation_qc_options_with_manual_bit_width(node,
node_qc_options_list,
base_config,
manual_bit_width_override.get(ACTIVATION_ATTRIBUTE),
mixed_precision_enable)
base_config, node_qc_options_list = weights_qc_options_with_manual_bit_width(node,
node_qc_options_list,
base_config,
manual_bit_width_override.get(WEIGHTS_ATTRIBUTE),
mixed_precision_enable)
return base_config, node_qc_options_list


def activation_qc_options_with_manual_bit_width(
        node: BaseNode,
        node_qc_options_list: List[OpQuantizationConfig],
        base_config: OpQuantizationConfig,
        activation_manual_bit_width_override: Optional[int],
        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
    """
    Filter a node's quantization config options by a manually selected activation bit-width.

    Args:
        node (BaseNode): The node whose quantization configuration candidates are being filtered.
        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
        base_config (OpQuantizationConfig): Base quantization config for the node.
        activation_manual_bit_width_override (Optional[int]): Requested activation bit-width.
            If None, the inputs are returned unchanged.
        mixed_precision_enable (bool): Whether mixed precision is enabled.

    Returns:
        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the
        list of quantization configs whose activation bit-width equals the requested one.
    """
    if activation_manual_bit_width_override is None:
        return base_config, node_qc_options_list

    # Retain only the options with activation bits equal to activation_manual_bit_width_override.
    node_qc_options_list = [op_cfg for op_cfg in node_qc_options_list
                            if op_cfg.activation_n_bits == activation_manual_bit_width_override]

    if not node_qc_options_list:
        Logger.critical(f"Manually selected activation bit-width {activation_manual_bit_width_override} is invalid for node {node}.")
    else:
        # Update the base_config to one of the values from the filtered node_qc_options_list.
        # First, check if a configuration similar to the original base_config but with activation bits equal to
        # activation_manual_bit_width_override exists. If it does, use it as the base_config.
        # If not, choose a different configuration from node_qc_options_list.
        Logger.info(f"Setting node {node} bit-width to manually selected bit-width: {activation_manual_bit_width_override} bits.")
        # The override must be passed as a keyword argument: a `{key, value}` set literal would be taken as
        # the positional attr_to_edit argument and leave the activation bit-width unedited.
        updated_base_config = base_config.clone_and_edit(**{ACTIVATION_N_BITS: activation_manual_bit_width_override})
        if updated_base_config in node_qc_options_list:
            # Point the base_config to the matching option of the filtered list.
            base_config = node_qc_options_list[node_qc_options_list.index(updated_base_config)]
        else:
            # No option equals the edited base config, so fall back to the first remaining option.
            # The list is known to be non-empty in this branch.
            base_config = node_qc_options_list[0]
            if not mixed_precision_enable:
                Logger.info(
                    f"Request received to select {activation_manual_bit_width_override} activation bits. However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
                    f" Overriding base_config with an option that uses {activation_manual_bit_width_override} bit activations.")  # pragma: no cover

    return base_config, node_qc_options_list


def weights_qc_options_with_manual_bit_width(
        node: BaseNode,
        node_qc_options_list: List[OpQuantizationConfig],
        base_config: OpQuantizationConfig,
        weights_manual_bit_width_override: Optional[int],
        mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
    """
    Filter a node's quantization config options by a manually selected weights bit-width.

    An option is retained when at least one entry of its attr_weights_configs_mapping has
    weights_n_bits equal to the requested bit-width.

    Args:
        node (BaseNode): The node whose quantization configuration candidates are being filtered.
        node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
        base_config (OpQuantizationConfig): Base quantization config for the node.
        weights_manual_bit_width_override (Optional[int]): Requested weights bit-width.
            If None, the inputs are returned unchanged.
        mixed_precision_enable (bool): Whether mixed precision is enabled.

    Returns:
        Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the
        filtered list of quantization configs (each retained option appears once).
    """
    if weights_manual_bit_width_override is None:
        return base_config, node_qc_options_list

    # Retain only the options with weights bits equal to weights_manual_bit_width_override, and collect
    # the weights attributes that carry that bit-width.
    node_qc_options_weights_list, target_key_list = [], []
    for op_cfg in node_qc_options_list:
        matching_attrs = [attr for attr, attr_cfg in op_cfg.attr_weights_configs_mapping.items()
                          if attr_cfg.weights_n_bits == weights_manual_bit_width_override]
        if not matching_attrs:
            continue
        # Append each option once, even if several of its attributes match, to avoid duplicate candidates.
        node_qc_options_weights_list.append(op_cfg)
        for attr in matching_attrs:
            if attr not in target_key_list:
                target_key_list.append(attr)

    if not node_qc_options_weights_list:
        Logger.critical(f"Manually selected weights bit-width {weights_manual_bit_width_override} is invalid for node {node}.")
    else:
        # Update the base_config to one of the values from the filtered list.
        # First, check if a configuration similar to the original base_config but with weights bits equal to
        # weights_manual_bit_width_override (for one of the matching attributes) exists.
        # If it does, use it as the base_config. If not, choose a different configuration from the filtered list.
        Logger.info(f"Setting node {node} bit-width to manually selected bit-width: {weights_manual_bit_width_override} bits.")
        for target_key in target_key_list:
            updated_base_config = base_config.clone_and_edit(
                attr_to_edit={target_key: {WEIGHTS_N_BITS: weights_manual_bit_width_override}})
            if updated_base_config in node_qc_options_weights_list:
                # Point the base_config to the matching option and stop, so that a later attribute
                # cannot replace this match with the fallback option.
                base_config = node_qc_options_weights_list[node_qc_options_weights_list.index(updated_base_config)]
                break
        else:
            # None of the edited base configs is among the options, so fall back to the first remaining option.
            base_config = node_qc_options_weights_list[0]
            if not mixed_precision_enable:
                Logger.info(
                    f"Request received to select {weights_manual_bit_width_override} weights bits. However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
                    f" Overriding base_config with an option that uses {weights_manual_bit_width_override} bit widths.")  # pragma: no cover

    return base_config, node_qc_options_weights_list
Loading
Loading