diff --git a/model_compression_toolkit/constants.py b/model_compression_toolkit/constants.py index ec335a608..7dc7ca366 100644 --- a/model_compression_toolkit/constants.py +++ b/model_compression_toolkit/constants.py @@ -57,6 +57,8 @@ # In Mixed-Precision, a node can have multiple candidates for weights and activations quantization configuration. # In order to display a single view of a node (for example, for logging in TensorBoard) we need to track the attributes # that are shared among different candidates: +WEIGHTS_ATTRIBUTE = 'weights' +ACTIVATION_ATTRIBUTE = 'activation' WEIGHTS_NBITS_ATTRIBUTE = 'weights_n_bits' CORRECTED_BIAS_ATTRIBUTE = 'corrected_bias' ACTIVATION_N_BITS_ATTRIBUTE = 'activation_n_bits' diff --git a/model_compression_toolkit/core/common/quantization/bit_width_config.py b/model_compression_toolkit/core/common/quantization/bit_width_config.py index 887d828e1..77af61ce1 100644 --- a/model_compression_toolkit/core/common/quantization/bit_width_config.py +++ b/model_compression_toolkit/core/common/quantization/bit_width_config.py @@ -19,19 +19,31 @@ from model_compression_toolkit.core.common.matchers.node_matcher import BaseNodeMatcher from model_compression_toolkit.logger import Logger +from model_compression_toolkit.core.common.graph.base_node import WeightAttrT @dataclass class ManualBitWidthSelection: """ - Class to encapsulate the manual bit width selection configuration for a specific filter. + Class to encapsulate the manual bit width selection configuration for a specific filter. - Attributes: + Attributes: filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation. bit_width (int): The bit width to be applied to the selected nodes. - """ + """ filter: BaseNodeMatcher bit_width: int +@dataclass +class ManualWeightsBitWidthSelection(ManualBitWidthSelection): + """ + Class to encapsulate the manual weights bit width selection configuration for a specific filter. + + Attributes: + filter (BaseNodeMatcher): The filter used to select nodes for bit width manipulation. + bit_width (int): The bit width to be applied to the selected nodes. + attr (str): The filtered node's attributes to apply bit-width manipulation to. + """ + attr: WeightAttrT @dataclass class BitWidthConfig: @@ -39,35 +51,64 @@ class BitWidthConfig: Class to manage manual bit-width configurations. Attributes: - manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects defining manual bit-width configurations. + manual_activation_bit_width_selection_list (List[ManualBitWidthSelection]): A list of ManualBitWidthSelection objects for activation defining manual bit-width configurations. + manual_weights_bit_width_selection_list (List[ManualWeightsBitWidthSelection]): A list of ManualWeightsBitWidthSelection for weights objects defining manual bit-width configurations. """ manual_activation_bit_width_selection_list: List[ManualBitWidthSelection] = field(default_factory=list) + manual_weights_bit_width_selection_list: List[ManualWeightsBitWidthSelection] = field(default_factory=list) def set_manual_activation_bit_width(self, - filters: Union[List[BaseNodeMatcher], BaseNodeMatcher], - bit_widths: Union[List[int], int]): + filters: Union[List[BaseNodeMatcher], BaseNodeMatcher], + bit_widths: Union[List[int], int]): """ - Add a manual bit-width selection to the configuration. + Add a manual bit-width selection for activation to the configuration. Args: - filter (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation. - bit_width (Union[List[int], int]): The bit widths to be applied to the selected nodes. + filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation. + bit_widths (Union[List[int], int]): The bit widths to be applied to the selected nodes. If a single value is given it will be applied to all the filters """ - filters = [filters] if not isinstance(filters, list) else filters - bit_widths = [bit_widths] if not isinstance(bit_widths, list) else bit_widths - if len(bit_widths) > 1 and len(bit_widths) != len(filters): - Logger.critical(f"Configuration Error: The number of provided bit_width values {len(bit_widths)} " - f"must match the number of filters {len(filters)}, or a single bit_width value " - f"should be provided for all filters.") - elif len(bit_widths) == 1 and len(filters) > 1: - bit_widths = [bit_widths[0] for f in filters] + if filters is None: + Logger.critical(f"The filters cannot be None.") + _, bit_widths, filters = self._expand_to_list(filters, bit_widths) for bit_width, filter in zip (bit_widths, filters): self.manual_activation_bit_width_selection_list += [ManualBitWidthSelection(filter, bit_width)] - def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict: + def set_manual_weights_bit_width(self, + filters: Union[List[BaseNodeMatcher], BaseNodeMatcher], + bit_widths: Union[List[int], int], + attrs: Union[List[WeightAttrT], WeightAttrT]): + """ + Add a manual bit-width selection for weights to the configuration. + + Args: + filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation. + bit_widths (Union[List[int], int]): The bit widths for specified by attrs to be applied to the selected nodes. + attrs (Union[List[WeightAttrT], WeightAttrT]): The filtered node's attributes to apply bit-width manipulation to. + If a single value is given it will be applied to all the filters + """ + if filters is None: + Logger.critical(f"The filters cannot be None.") + attrs, bit_widths, filters = self._expand_to_list(filters, bit_widths, attrs) + for attr, bit_width, filter in zip (attrs, bit_widths, filters): + self.manual_weights_bit_width_selection_list += [ManualWeightsBitWidthSelection(filter, bit_width, attr)] + + def get_nodes_to_manipulate_activation_bit_widths(self, graph: Graph) -> Dict: + """ + Retrieve nodes from the graph that need their bit-widths for activation changed according to the manual bit-width selections. + + Args: + graph (Graph): The graph containing the nodes to be filtered and manipulated. + + Returns: + Dict: A dictionary mapping nodes to their new bit-widths. + """ + activation_nodes_to_change_bit_width = self._construct_node_to_new_activation_bit_mapping(graph) + return activation_nodes_to_change_bit_width + + def get_nodes_to_manipulate_weights_bit_widths(self, graph: Graph) -> Dict: """ - Retrieve nodes from the graph that need their bit-widths changed according to the manual bit-width selections. + Retrieve nodes from the graph that need their bit-widths for weights changed according to the manual bit-width selections. Args: graph (Graph): The graph containing the nodes to be filtered and manipulated. @@ -75,16 +116,127 @@ def get_nodes_to_manipulate_bit_widths(self, graph: Graph) -> Dict: Returns: Dict: A dictionary mapping nodes to their new bit-widths. """ - nodes_to_change_bit_width = {} + weights_nodes_to_change_bit_width = self._construct_node_to_new_weights_bit_mapping(graph) + return weights_nodes_to_change_bit_width + + @staticmethod + def _expand_to_list_core( + filters: Union[List[BaseNodeMatcher], BaseNodeMatcher], + vals: Union[List[Union[WeightAttrT, int]], Union[WeightAttrT, int]]) -> list: + """ + Extend the length of vals to match the length of filters. + + Args: + filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation. + vals Union[List[Union[WeightAttrT, int], Union[WeightAttrT, int]]]): The bit widths or The filtered node's attributes. + + Returns: + list: Extended vals to match the length of filters. + """ + vals = [vals] if not isinstance(vals, list) else vals + if len(vals) > 1 and len(vals) != len(filters): + Logger.critical(f"Configuration Error: The number of provided bit_width values {len(vals)} " + f"must match the number of filters {len(filters)}, or a single bit_width value " + f"should be provided for all filters.") + elif len(vals) == 1 and len(filters) > 1: + vals = [vals[0] for f in filters] + return vals + + @staticmethod + def _expand_to_list( + filters: Union[List[BaseNodeMatcher]], + bit_widths: Union[List[int], int], + attrs: Union[List[WeightAttrT], WeightAttrT] = None) -> [List]: + """ + Extend the length of filters, bit-widths and The filtered node's attributes to match the length of filters. + + Args: + filters (Union[List[BaseNodeMatcher], BaseNodeMatcher]): The filters used to select nodes for bit-width manipulation. + bit_widths (Union[List[int], int]): The bit widths for specified by attrs to be applied to the selected nodes. + attrs (Union[List[WeightAttrT], WeightAttrT]): The filtered node's attributes to apply bit-width manipulation to. + + Returns: + [List]: A List of extended input arguments. + """ + filters = [filters] if not isinstance(filters, list) else filters + bit_widths = BitWidthConfig._expand_to_list_core(filters, bit_widths) + if attrs is not None: + attrs = BitWidthConfig._expand_to_list_core(filters, attrs) + return attrs, bit_widths, filters + + def _construct_node_to_new_activation_bit_mapping(self, graph) -> Dict: + """ + Retrieve nodes from the graph that need their activation bit-widths changed according to the manual bit-width selections. + + Args: + graph (Graph): The graph containing the nodes to be filtered and manipulated. + + Returns: + Dict: A dictionary retrieved nodes from the graph. + """ + unit_nodes_to_change_bit_width = {} for manual_bit_width_selection in self.manual_activation_bit_width_selection_list: filtered_nodes = graph.filter(manual_bit_width_selection.filter) if len(filtered_nodes) == 0: - Logger.critical(f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} " - f"to change their bit width to {manual_bit_width_selection.bit_width}.") + Logger.critical( + f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} " + f"to change their bit width to {manual_bit_width_selection.bit_width}.") for n in filtered_nodes: # check if a manual configuration exists for this node - if n in nodes_to_change_bit_width: + if n in unit_nodes_to_change_bit_width: Logger.info( - f"Node {n} has an existing manual bit width configuration of {nodes_to_change_bit_width.get(n)}. A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.") - nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width}) - return nodes_to_change_bit_width \ No newline at end of file + f"Node {n} has an existing manual bit width configuration of {unit_nodes_to_change_bit_width.get(n)}." + f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.") + unit_nodes_to_change_bit_width.update({n: manual_bit_width_selection.bit_width}) + return unit_nodes_to_change_bit_width + + def _construct_node_to_new_weights_bit_mapping(self, graph) -> Dict: + """ + Retrieve nodes from the graph that need their weights bit-widths changed according to the manual bit-width selections. + + Args: + graph (Graph): The graph containing the nodes to be filtered and manipulated. + + Returns: + Dict: A dictionary retrieved nodes from the graph. + """ + unit_nodes_to_change_bit_width = {} + + for manual_bit_width_selection in self.manual_weights_bit_width_selection_list: + filtered_nodes = graph.filter(manual_bit_width_selection.filter) + if len(filtered_nodes) == 0: + Logger.critical( + f"Node Filtering Error: No nodes found in the graph for filter {manual_bit_width_selection.filter.__dict__} " + f"to change their bit width to {manual_bit_width_selection.bit_width}.") + + for n in filtered_nodes: + attr_to_change_bit_width = [] + + attrs_str = n.get_node_weights_attributes() + if len(attrs_str) == 0: + Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.') + + attr = [] + for attr_str in attrs_str: + if isinstance(attr_str, str) and isinstance(manual_bit_width_selection.attr, str): + if attr_str.find(manual_bit_width_selection.attr) != -1: + attr.append(attr_str) + elif isinstance(attr_str, int) and isinstance(manual_bit_width_selection.attr, int): + if attr_str == manual_bit_width_selection.attr: + attr.append(attr_str) + if len(attr) == 0: + Logger.critical(f'The requested attribute {manual_bit_width_selection.attr} to change the bit width for {n} does not exist.') + + if n in unit_nodes_to_change_bit_width: + attr_to_change_bit_width = unit_nodes_to_change_bit_width[n] + for i, attr_to_bitwidth in enumerate(attr_to_change_bit_width): + if attr_to_bitwidth[1] == manual_bit_width_selection.attr: + del attr_to_change_bit_width[i] + Logger.info( + f"Node {n} has an existing manual bit width configuration of {manual_bit_width_selection.attr}." + f"A new manual configuration request of {manual_bit_width_selection.bit_width} has been received, and the previous value is being overridden.") + + attr_to_change_bit_width.append([manual_bit_width_selection.bit_width, manual_bit_width_selection.attr]) + unit_nodes_to_change_bit_width.update({n: attr_to_change_bit_width}) + + return unit_nodes_to_change_bit_width \ No newline at end of file diff --git a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py index 7359cdf1c..5c79fe9a3 100644 --- a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py @@ -65,7 +65,7 @@ def set_quantization_configuration_to_graph(graph: Graph, Logger.warning("Using the HMSE error method for weights quantization parameters search. " "Note: This method may significantly increase runtime during the parameter search process.") - nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_bit_widths(graph) + nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph) for n in graph.nodes: set_quantization_configs_to_node(node=n, diff --git a/tests_pytest/common_tests/core/__init__.py b/tests_pytest/common_tests/core/__init__.py new file mode 100644 index 000000000..5397dea24 --- /dev/null +++ b/tests_pytest/common_tests/core/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/tests_pytest/common_tests/core/common/__init__.py b/tests_pytest/common_tests/core/common/__init__.py new file mode 100644 index 000000000..5397dea24 --- /dev/null +++ b/tests_pytest/common_tests/core/common/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/tests_pytest/common_tests/core/common/quantization/__init__.py b/tests_pytest/common_tests/core/common/quantization/__init__.py new file mode 100644 index 000000000..5397dea24 --- /dev/null +++ b/tests_pytest/common_tests/core/common/quantization/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/tests_pytest/common_tests/core/common/quantization/test_manual_bitwidth_selection.py b/tests_pytest/common_tests/core/common/quantization/test_manual_bitwidth_selection.py new file mode 100755 index 000000000..d5f97c3f3 --- /dev/null +++ b/tests_pytest/common_tests/core/common/quantization/test_manual_bitwidth_selection.py @@ -0,0 +1,225 @@ +import pytest + +from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter +from model_compression_toolkit.core.common.quantization.bit_width_config import ManualBitWidthSelection, ManualWeightsBitWidthSelection +from model_compression_toolkit.core import BitWidthConfig + +from model_compression_toolkit.core.common import Graph +from model_compression_toolkit.core.common.graph.edge import Edge +from tests_pytest._test_util.graph_builder_utils import build_node + + +TEST_KERNEL = 'kernel' +TEST_BIAS = 'bias' + +### dummy layer classes +class Conv2D: + pass +class InputLayer: + pass +class Add: + pass +class BatchNormalization: + pass +class ReLU: + pass +class Flatten: + pass +class Dense: + pass + + +### test model +def get_test_graph(): + n1 = build_node('input', layer_class=InputLayer) + conv1 = build_node('conv1', layer_class=Conv2D, canonical_weights={TEST_KERNEL: [1,2], TEST_BIAS: [3,4]}) + add1 = build_node('add1', layer_class=Add) + conv2 = build_node('conv2', layer_class=Conv2D) + bn1 = build_node('bn1', layer_class=BatchNormalization) + relu = build_node('relu1', layer_class=ReLU, canonical_weights={TEST_KERNEL: [1,2], TEST_BIAS: [3,4]}) + add2 = build_node('add2', layer_class=Add) + flatten = build_node('flatten', layer_class=Flatten) + fc = build_node('fc', layer_class=Dense) + + graph = Graph('g', input_nodes=[n1], + nodes=[conv1,add1, conv2, bn1, relu, add2, flatten], + output_nodes=[fc], + edge_list=[Edge(n1, conv1, 0, 0), + Edge(conv1, add1, 0, 0), + Edge(add1, conv2, 0, 0), + Edge(conv2, bn1, 0, 0), + Edge(bn1, relu, 0, 0), + Edge(relu, add2, 0, 0), + Edge(add1, add2, 0, 0), + Edge(add2, flatten, 0, 0), + Edge(flatten, fc, 0, 0), + ] + ) + return graph + + +class TestBitWidthConfig: + # test case for set_manual_activation_bit_width + test_input_0 = (None, None) + test_input_1 = (NodeTypeFilter(ReLU), 16) + test_input_2 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16]) + test_input_3 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16, 8]) + + test_expected_0 = ("The filters cannot be None.", None) + test_expected_1 = (NodeTypeFilter, ReLU, 16) + test_expected_2 = ([NodeTypeFilter, ReLU, 16], [NodeNameFilter, "conv1", 16]) + test_expected_3 = ([NodeTypeFilter, ReLU, 16], [NodeNameFilter, "conv1", 8]) + + @pytest.mark.parametrize(("inputs", "expected"), [ + (test_input_0, test_expected_0), + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + (test_input_3, test_expected_3), + ]) + def test_set_manual_activation_bit_width(self, inputs, expected): + def check_param_for_activation(mb_cfg, exp): + ### check setting config class (expected ManualBitWidthSelection) + assert type(mb_cfg) == ManualBitWidthSelection + + ### check setting filter for NodeFilter and NodeInfo + if mb_cfg.filter is not None: + assert isinstance(mb_cfg.filter, exp[0]) + + if isinstance(mb_cfg.filter, NodeTypeFilter): + assert mb_cfg.filter.node_type == exp[1] + elif isinstance(mb_cfg.filter, NodeNameFilter): + assert mb_cfg.filter.node_name == exp[1] + + ### check setting bit_width + assert mb_cfg.bit_width == exp[2] + else: + assert mb_cfg.filter is None + + manual_bit_cfg = BitWidthConfig() + try: + manual_bit_cfg.set_manual_activation_bit_width(inputs[0], inputs[1]) + ### check Activation + if len(manual_bit_cfg.manual_activation_bit_width_selection_list) == 1: + for a_mb_cfg in manual_bit_cfg.manual_activation_bit_width_selection_list: + print(a_mb_cfg, expected) + check_param_for_activation(a_mb_cfg, expected) + else: + for idx, a_mb_cfg in enumerate(manual_bit_cfg.manual_activation_bit_width_selection_list): + check_param_for_activation(a_mb_cfg, expected[idx]) + except Exception as e: + assert str(e) == expected[0] + + + # test case for set_manual_weights_bit_width + test_input_0 = (None, None, None) + test_input_1 = (NodeTypeFilter(ReLU), 16, TEST_KERNEL) + test_input_2 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16], [TEST_KERNEL]) + test_input_3 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16, 8], [TEST_KERNEL, TEST_BIAS]) + + test_expected_0 = ("The filters cannot be None.", None, None) + test_expected_1 = (NodeTypeFilter, ReLU, 16, TEST_KERNEL) + test_expected_2 = ([NodeTypeFilter, ReLU, 16, TEST_KERNEL], [NodeNameFilter, "conv1", 16, TEST_KERNEL]) + test_expected_3 = ([NodeTypeFilter, ReLU, 16, TEST_KERNEL], [NodeNameFilter, "conv1", 8, TEST_BIAS]) + + @pytest.mark.parametrize(("inputs", "expected"), [ + (test_input_0, test_expected_0), + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + (test_input_3, test_expected_3), + ]) + def test_set_manual_weights_bit_width(self, inputs, expected): + def check_param_weights(mb_cfg, exp): + ### check setting config class (expected ManualWeightsBitWidthSelection) + assert type(mb_cfg) == ManualWeightsBitWidthSelection + + ### check setting filter for NodeFilter and NodeInfo + if mb_cfg.filter is not None: + assert isinstance(mb_cfg.filter, exp[0]) + if isinstance(mb_cfg.filter, NodeTypeFilter): + assert mb_cfg.filter.node_type == exp[1] + elif isinstance(mb_cfg.filter, NodeNameFilter): + assert mb_cfg.filter.node_name == exp[1] + + ### check setting bit_width and attr + assert mb_cfg.bit_width == exp[2] + assert mb_cfg.attr == exp[3] + else: + assert mb_cfg.filter is None + + manual_bit_cfg = BitWidthConfig() + try: + manual_bit_cfg.set_manual_weights_bit_width(inputs[0], inputs[1], inputs[2]) + ### check weights + if len(manual_bit_cfg.manual_weights_bit_width_selection_list) == 1: + for a_mb_cfg in manual_bit_cfg.manual_weights_bit_width_selection_list: + print(a_mb_cfg, expected) + check_param_weights(a_mb_cfg, expected) + else: + for idx, a_mb_cfg in enumerate(manual_bit_cfg.manual_weights_bit_width_selection_list): + check_param_weights(a_mb_cfg, expected[idx]) + except Exception as e: + assert str(e) == expected[0] + + + # test case for get_nodes_to_manipulate_activation_bit_widths + test_input_0 = (NodeTypeFilter(ReLU), 16) + test_input_1 = (NodeNameFilter('relu1'), 16) + test_input_2 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16, 8]) + + test_expected_0 = ({"ReLU:relu1": 16}) + test_expected_1 = ({"ReLU:relu1": 16}) + test_expected_2 = ({"ReLU:relu1": 16, "Conv2D:conv1": 8}) + + @pytest.mark.parametrize(("inputs", "expected"), [ + (test_input_0, test_expected_0), + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + ]) + def test_get_nodes_to_manipulate_activation_bit_widths(self, inputs, expected): + fl_list = inputs[0] if isinstance(inputs[0], list) else [inputs[0]] + bw_list = inputs[1] if isinstance(inputs[1], list) else [inputs[1]] + + mbws_config = [] + for fl, bw in zip(fl_list, bw_list): + mbws_config.append(ManualBitWidthSelection(fl, bw)) + manual_bit_cfg = BitWidthConfig(manual_activation_bit_width_selection_list=mbws_config) + + graph = get_test_graph() + get_manual_bit_dict_activation = manual_bit_cfg.get_nodes_to_manipulate_activation_bit_widths(graph) + for idx, (key, val) in enumerate(get_manual_bit_dict_activation.items()): + assert str(key) == list(expected.keys())[idx] + assert val == list(expected.values())[idx] + + + # test case for get_nodes_to_manipulate_weights_bit_widths + test_input_0 = (NodeTypeFilter(ReLU), 16, TEST_KERNEL) + test_input_1 = (NodeNameFilter('relu1'), 16, TEST_BIAS) + test_input_2 = ([NodeTypeFilter(ReLU), NodeNameFilter("conv1")], [16, 8], [TEST_KERNEL, TEST_BIAS]) + test_input_3 = ([NodeNameFilter("conv1"), NodeNameFilter("conv1")], [4, 8], [TEST_KERNEL, TEST_BIAS]) + + test_expected_0 = ({"ReLU:relu1": [[16, TEST_KERNEL]]}) + test_expected_1 = ({"ReLU:relu1": [[16, TEST_BIAS]]}) + test_expected_2 = ({"ReLU:relu1": [[16, TEST_KERNEL]], "Conv2D:conv1": [[8, TEST_BIAS]]}) + test_expected_3 = ({"Conv2D:conv1": [[4, TEST_KERNEL], [8, TEST_BIAS]]}) + + @pytest.mark.parametrize(("inputs", "expected"), [ + (test_input_0, test_expected_0), + (test_input_1, test_expected_1), + (test_input_2, test_expected_2), + (test_input_3, test_expected_3), + ]) + def test_get_nodes_to_manipulate_weights_bit_widths(self, inputs, expected): + fl_list = inputs[0] if isinstance(inputs[0], list) else [inputs[0]] + bw_list = inputs[1] if isinstance(inputs[1], list) else [inputs[1]] + at_list = inputs[2] if isinstance(inputs[2], list) else [inputs[2]] + + manual_weights_bit_width_config = [] + for fl, bw, at in zip(fl_list, bw_list, at_list): + manual_weights_bit_width_config.append(ManualWeightsBitWidthSelection(fl, bw, at)) + manual_bit_cfg = BitWidthConfig(manual_weights_bit_width_selection_list=manual_weights_bit_width_config) + + graph = get_test_graph() + get_manual_bit_dict_weights = manual_bit_cfg.get_nodes_to_manipulate_weights_bit_widths(graph) + for idx, (key, val) in enumerate(get_manual_bit_dict_weights.items()): + assert str(key) == list(expected.keys())[idx] + assert val == list(expected.values())[idx]