Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
a81fd1e
pytorch_activation_threshold_search(Debugging and exploration code re…
gouda-youichi Jan 28, 2025
4c66805
Merge branch 'sony:main' into main
gouda-youichi Jan 29, 2025
66a3bd9
pytorch_activation_threshold_search(refine a little)
gouda-youichi Jan 29, 2025
ffe0e6d
Merge branch 'main' of https://github.com/gouda-youichi/model_optimiz…
gouda-youichi Jan 29, 2025
746cb57
Fixed ofirgo-san's review comment.
gouda-youichi Jan 31, 2025
8128264
Fixed ofirgo-san's review comment.[2]
gouda-youichi Jan 31, 2025
67b78b5
Fixed ofirgo-san's review comment.[3]
gouda-youichi Jan 31, 2025
722380b
PR comment correction (in progress)
gouda-youichi Feb 3, 2025
5564182
[DEBUG]about printout layer_type
gouda-youichi Feb 3, 2025
d5481a0
PR comment correction(fixed)
gouda-youichi Feb 4, 2025
2f6e5fc
Merge branch 'sony:main' into main
gouda-youichi Feb 5, 2025
947f56c
Merge branch 'main' of https://github.com/gouda-youichi/model_optimiz…
gouda-youichi Feb 27, 2025
0926861
Merge branch 'main' of https://github.com/gouda-youichi/model_optimiz…
gouda-youichi Feb 27, 2025
7d4ba38
add set_manual_weights_bit_width to bit_width_config.py
gouda-youichi Feb 28, 2025
6eec34b
modified get_nodes_to_manipulate_bit_widths and etc
gouda-youichi Feb 28, 2025
6c3c262
All implementations completed. Not yet tested.
gouda-youichi Mar 4, 2025
50af865
Merge branch 'sony:main' into adding_weights_manual_selection_bitwidth
gouda-youichi Mar 6, 2025
9411042
revert original setting for tpc.
gouda-youichi Mar 6, 2025
b283c81
add test for weights_manual_selection_bitwidth
gouda-youichi Mar 6, 2025
a776f4c
modified test_manual_weights_bitwidth_selection.py
gouda-youichi Mar 11, 2025
424dde6
revert tpc.py
gouda-youichi Mar 12, 2025
b07fc5d
correcting accrding to the feedback comments
gouda-youichi Mar 12, 2025
02dce0c
correcting according to the feedback comments
gouda-youichi Mar 12, 2025
464f1da
correcting according to the feedback comments.
gouda-youichi Mar 12, 2025
bc0cf98
Revert unnecessary modifications
gouda-youichi Mar 12, 2025
70db8e8
Revert set_node_quantization_config.py
gouda-youichi Mar 12, 2025
c51a200
fixing for manual weights selection bitwidth(kernel,bias)
gouda-youichi Mar 13, 2025
37d2dcf
fixing for manual weights selection bitwidth(kernel,bias)
gouda-youichi Mar 13, 2025
d411b8d
fixed PR-FB for manual weights selection bitwidth
gouda-youichi Mar 14, 2025
5b88fe2
rename test script
gouda-youichi Mar 14, 2025
1f30c99
Correcting comments on pull requests
gouda-youichi Mar 19, 2025
8ee3002
Correcting comments on pull requests_2
gouda-youichi Mar 19, 2025
2a3c40c
Fixed comments on pull requests.
gouda-youichi Mar 19, 2025
1af4e11
Merge branch 'sony:main' into adding_weights_manual_selection_bitwidth
kawakami-masaki0 Mar 24, 2025
7ba5a17
fixing for manual weights selection bitwidth
kawakami-masaki0 Mar 24, 2025
afe57b4
add __init__.py
kawakami-masaki0 Mar 24, 2025
c49ba74
fix imported module
kawakami-masaki0 Mar 24, 2025
e7737d6
fixed to check for integer or string
kawakami-masaki0 Mar 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions model_compression_toolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
# In Mixed-Precision, a node can have multiple candidates for weights and activations quantization configuration.
# In order to display a single view of a node (for example, for logging in TensorBoard) we need to track the attributes
# that are shared among different candidates:
WEIGHTS_ATTRIBUTE = 'weights'
ACTIVATION_ATTRIBUTE = 'activation'
WEIGHTS_NBITS_ATTRIBUTE = 'weights_n_bits'
CORRECTED_BIAS_ATTRIBUTE = 'corrected_bias'
ACTIVATION_N_BITS_ATTRIBUTE = 'activation_n_bits'
Expand Down
204 changes: 178 additions & 26 deletions model_compression_toolkit/core/common/quantization/bit_width_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,72 +19,224 @@
from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher
from model_compression_toolkit.logger import Logger

from model_compression_toolkit.core.common.graph.base_node import WeightAttrT

@dataclass
class ManualBitWidthSelection:
"""
Class to encapsulate the manual bit width selection configuration for a specific filter.
Class to encapsulate the manual bit width selection configuration for a specific filter.

Attributes:
Attributes:
filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
bit_width (int): The bit width to be applied to the selected nodes.
"""
"""
filter: BaseNodeMatcher
bit_width: int

@dataclass
class ManualWeightsBitWidthSelection(ManualBitWidthSelection):
"""
Class to encapsulate the manual weights bit width selection configuration for a specific filter.

Attributes:
filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation.
bit_width (int): The bit width to be applied to the selected nodes.
attr (str): The filtered node's attributes to apply bit-width manipulation to.
"""
attr: WeightAttrT

@dataclass
class BitWidthConfig:
"""
Class to manage manual bit-width configurations.

Attributes:
manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects defining manual bit-width configurations.
manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects for activation defining manual bit-width configurations.
manual_weights_bit_width_selection_list (List[ManualWeightsBitWidthSelection]): A list of ManualWeightsBitWidthSelection for weights objects defining manual bit-width configurations.
"""
manual_activation_bit_width_selection_list: List[ManualBitWidthSelection] = field(default_factory=list)
manual_weights_bit_width_selection_list: List[ManualWeightsBitWidthSelection] = field(default_factory=list)

def set_manual_activation_bit_width(self,
filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
bit_widths: Union[List[int], int]):
filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
bit_widths: Union[List[int], int]):
"""
Add a manual bit-width selection to the configuration.
Add a manual bit-width selection for activation to the configuration.

Args:
filter (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
bit_width (Union[List[int], int]): The bit widths to be applied to the selected nodes.
filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes.
If a single value is given it will be applied to all the filters
"""
filters = [filters] if not isinstance(filters, list) else filters
bit_widths = [bit_widths] if not isinstance(bit_widths, list) else bit_widths
if len(bit_widths) > 1 and len(bit_widths) != len(filters):
Logger.critical(f"Configuration Error: The number of provided bit_width values {len(bit_widths)} "
f"must match the number of filters {len(filters)}, or a single bit_width value "
f"should be provided for all filters.")
elif len(bit_widths) == 1 and len(filters) > 1:
bit_widths = [bit_widths[0] for f in filters]
if filters is None:
Logger.critical(f"The filters cannot be None.")
_, bit_widths, filters = self._expand_to_list(filters, bit_widths)
for bit_width, filter in zip (bit_widths, filters):
self.manual_activation_bit_width_selection_list += [ManualBitWidthSelection(filter, bit_width)]

def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict:
def set_manual_weights_bit_width(self,
filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
bit_widths: Union[List[int], int],
attrs: Union[List[WeightAttrT], WeightAttrT]):
"""
Add a manual bit-width selection for weights to the configuration.

Args:
filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
bit_widths (Union[List[int], int]): The bit widths for specified by attrs to be applied to the selected nodes.
attrs (Union[List[WeightAttrT], WeightAttrT]): The filtered node's attributes to apply bit-width manipulation to.
If a single value is given it will be applied to all the filters
"""
if filters is None:
Logger.critical(f"The filters cannot be None.")
attrs, bit_widths, filters = self._expand_to_list(filters, bit_widths, attrs)
for attr, bit_width, filter in zip (attrs, bit_widths, filters):
self.manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection(filter, bit_width, attr)]

def get_nodes_to_manipulate_activation_bit_widths(self, graph: Graph) -> Dict:
"""
Retrieve nodes from the graph that need their bit-widths for activation changed according to the manual bit-width selections.

Args:
graph (Graph): The graph containing the nodes to be filtered and manipulated.

Returns:
Dict: A dictionary mapping nodes to their new bit-widths.
"""
activation_nodes_to_change_bit_width = self._construct_node_to_new_activation_bit_mapping(graph)
return activation_nodes_to_change_bit_width

def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict:
"""
Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections.
Retrieve nodes from the graph that need their bit-widths for weights changed according to the manual bit-width selections.

Args:
graph (Graph): The graph containing the nodes to be filtered and manipulated.

Returns:
Dict: A dictionary mapping nodes to their new bit-widths.
"""
nodes_to_change_bit_width = {}
weights_nodes_to_change_bit_width = self._construct_node_to_new_weights_bit_mapping(graph)
return weights_nodes_to_change_bit_width

@staticmethod
def _expand_to_list_core(
filters: Union[List[BaseNodeMatcher], BaseNodeMatcher],
vals: Union[List[Union[WeightAttrT, int]], Union[WeightAttrT, int]]) -> list:
"""
Extend the length of vals to match the length of filters.

Args:
filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
vals Union[List[Union[WeightAttrT, int], Union[WeightAttrT, int]]]): The bit widths or The filtered node's attributes.

Returns:
list: Extended vals to match the length of filters.
"""
vals = [vals] if not isinstance(vals, list) else vals
if len(vals) > 1 and len(vals) != len(filters):
Logger.critical(f"Configuration Error: The number of provided bit_width values {len(vals)} "
f"must match the number of filters {len(filters)}, or a single bit_width value "
f"should be provided for all filters.")
elif len(vals) == 1 and len(filters) > 1:
vals = [vals[0] for f in filters]
return vals

@staticmethod
def _expand_to_list(
filters: Union[List[BaseNodeMatcher]],
bit_widths: Union[List[int], int],
attrs: Union[List[WeightAttrT], WeightAttrT] = None) -> [List]:
"""
Extend the length of filters, bit-widths and The filtered node's attributes to match the length of filters.

Args:
filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation.
bit_widths (Union[List[int], int]): The bit widths for specified by attrs to be applied to the selected nodes.
attrs (Union[List[WeightAttrT], WeightAttrT]): The filtered node's attributes to apply bit-width manipulation to.

Returns:
[List]: A List of extended input arguments.
"""
filters = [filters] if not isinstance(filters, list) else filters
bit_widths = BitWidthConfig._expand_to_list_core(filters, bit_widths)
if attrs is not None:
attrs = BitWidthConfig._expand_to_list_core(filters, attrs)
return attrs, bit_widths, filters

def _construct_node_to_new_activation_bit_mapping(self, graph) -> Dict:
"""
Retrieve nodes from the graph that need their activation bit-widths changed according to the manual bit-width selections.

Args:
graph (Graph): The graph containing the nodes to be filtered and manipulated.

Returns:
Dict: A dictionary retrieved nodes from the graph.
"""
unit_nodes_to_change_bit_width = {}
for manual_bit_width_selection in self.manual_activation_bit_width_selection_list:
filtered_nodes = graph.filter(manual_bit_width_selection.filter)
if len(filtered_nodes) == 0:
Logger.critical(f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
Logger.critical(
f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
f"to change their bit width to {manual_bit_width_selection.bit_width}.")
for n in filtered_nodes:
# check if a manual configuration exists for this node
if n in nodes_to_change_bit_width:
if n in unit_nodes_to_change_bit_width:
Logger.info(
f"Node {n} has an existing manual bit width configuration of {nodes_to_change_bit_width.get(n)}. A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
return nodes_to_change_bit_width
f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}."
f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")
unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width})
return unit_nodes_to_change_bit_width

def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict:
"""
Retrieve nodes from the graph that need their weights bit-widths changed according to the manual bit-width selections.

Args:
graph (Graph): The graph containing the nodes to be filtered and manipulated.

Returns:
Dict: A dictionary retrieved nodes from the graph.
"""
unit_nodes_to_change_bit_width = {}

for manual_bit_width_selection in self.manual_weights_bit_width_selection_list:
filtered_nodes = graph.filter(manual_bit_width_selection.filter)
if len(filtered_nodes) == 0:
Logger.critical(
f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} "
f"to change their bit width to {manual_bit_width_selection.bit_width}.")

for n in filtered_nodes:
attr_to_change_bit_width = []

attrs_str = n.get_node_weights_attributes()
if len(attrs_str) == 0:
Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')

attr = []
for attr_str in attrs_str:
if isinstance(attr_str, str) and isinstance(manual_bit_width_selection.attr, str):
if attr_str.find(manual_bit_width_selection.attr) != -1:
attr.append(attr_str)
elif isinstance(attr_str, int) and isinstance(manual_bit_width_selection.attr, int):
if attr_str == manual_bit_width_selection.attr:
attr.append(attr_str)
if len(attr) == 0:
Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.')

if n in unit_nodes_to_change_bit_width:
attr_to_change_bit_width = unit_nodes_to_change_bit_width[n]
for i, attr_to_bitwidth in enumerate(attr_to_change_bit_width):
if attr_to_bitwidth[1] == manual_bit_width_selection.attr:
del attr_to_change_bit_width[i]
Logger.info(
f"Node {n} has an existing manual bit width configuration of {manual_bit_width_selection.attr}."
f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.")

attr_to_change_bit_width.append([manual_bit_width_selection.bit_width, manual_bit_width_selection.attr])
unit_nodes_to_change_bit_width.update({n: attr_to_change_bit_width})

return unit_nodes_to_change_bit_width
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def set_quantization_configuration_to_graph(graph: Graph,
Logger.warning("Using the HMSE error method for weights quantization parameters search. "
"Note: This method may significantly increase runtime during the parameter search process.")

nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_bit_widths(graph)
nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)

for n in graph.nodes:
set_quantization_configs_to_node(node=n,
Expand Down
14 changes: 14 additions & 0 deletions tests_pytest/common_tests/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14 changes: 14 additions & 0 deletions tests_pytest/common_tests/core/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14 changes: 14 additions & 0 deletions tests_pytest/common_tests/core/common/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Loading