diff --git a/model_compression_toolkit/core/common/graph/base_graph.py b/model_compression_toolkit/core/common/graph/base_graph.py
index 26b449f19..e16cf34eb 100644
--- a/model_compression_toolkit/core/common/graph/base_graph.py
+++ b/model_compression_toolkit/core/common/graph/base_graph.py
@@ -754,7 +754,7 @@ def retrieve_preserved_quantization_node(self, node: BaseNode) -> BaseNode:
         """
         while node.is_quantization_preserving():
             prev_nodes = self.get_prev_nodes(node)
-            assert len(prev_nodes) == 1, "Activation preserving node should have only 1 input."
+            assert len(prev_nodes) == 1, f"Activation preserving node should have only 1 input, but node {node.name} has {len(prev_nodes)} inputs."
             node = prev_nodes[0]
         return node
diff --git a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py
index 7faf91772..bb9f26bce 100644
--- a/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py
+++ b/model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py
@@ -16,7 +16,7 @@
 import numpy as np
 from pulp import *
-from typing import Dict, Tuple, Any
+from typing import Dict, Tuple, Any, List
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget
diff --git a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
index 31fe10110..810023322 100644
--- a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
+++ b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py
@@ -67,7 +67,7 @@ def set_quantization_configuration_to_graph(graph: Graph,
     nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
     nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)
-    for n in graph.nodes:
+    for n in graph.get_topo_sorted_nodes():
         manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
                                      WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
         set_quantization_configs_to_node(node=n,
@@ -199,6 +199,16 @@ def set_quantization_configs_to_node(node: BaseNode,
         if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
                 not node.get_has_activation():
             candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
+        elif candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT:
+            prev_nodes = graph.get_prev_nodes(node)
+            if len(prev_nodes) != 1:
+                # Preserving the quantization of more than one previous node is ambiguous, so disable it.
+                Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than one input activation.")
+                candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
+            elif not prev_nodes[0].is_quantization_preserving() and not prev_nodes[0].is_activation_quantization_enabled():
+                # Preserving the quantization of an unquantized node isn't possible, so disable it.
+                Logger.info(f"Disabling Quantization-Preserving for node {node.name} because the previous node's activation quantization is disabled.")
+                candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
 
 
 def create_node_activation_qc(qc: QuantizationConfig,
diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py
index 9e5166bbd..4c54343d8 100644
--- a/tests/keras_tests/feature_networks_tests/test_features_runner.py
+++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py
@@ -870,7 +870,7 @@ def test_invalid_bit_width_selection(self):
             ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 3).run_test()
         # Check that the correct exception message was raised
         self.assertEqual(str(context.exception),
-                         "Manually selected activation bit-width 3 is invalid for node Add:add2.")
+                         "Manually selected activation bit-width 3 is invalid for node Add:add1.")
 
         with self.assertRaises(Exception) as context:
             ManualBitWidthSelectionTest(self, NodeNameFilter('relu1'), 3).run_test()
@@ -880,7 +880,7 @@ def test_mul_16_bit_manual_selection(self):
         """
-        This test checks the execptions in the manual bit-width selection feature.
+        This test checks the exceptions in the manual bit-width selection feature.
         """
         # This "mul" can be configured to 16 bit
         Manual16BitWidthSelectionTest(self, NodeNameFilter('mul1'), 16).run_test()
diff --git a/tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py b/tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py
new file mode 100644
index 000000000..3bbea4d73
--- /dev/null
+++ b/tests_pytest/common_tests/unit_tests/core/graph/test_node_quantization.py
@@ -0,0 +1,61 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.core.common.graph.edge import Edge
+
+from unittest.mock import Mock
+from tests_pytest._test_util.graph_builder_utils import build_node, build_nbits_qc, DummyLayer
+from model_compression_toolkit.core import FrameworkInfo
+from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configs_to_node
+from model_compression_toolkit.core import QuantizationConfig
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
+    OpQuantizationConfig, AttributeQuantizationConfig, Signedness
+from mct_quantizers import QuantizationMethod
+
+
+class TestSetNodeQuantizationConfig:
+
+    @staticmethod
+    def _get_op_config():
+        aqc = AttributeQuantizationConfig()
+        return OpQuantizationConfig(default_weight_attr_config=aqc,
+                                    attr_weights_configs_mapping={'w': aqc},
+                                    activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
+                                    activation_n_bits=7,
+                                    supported_input_activation_n_bits=7,
+                                    enable_activation_quantization=False,
+                                    quantization_preserving=True,
+                                    signedness=Signedness.AUTO)
+
+    def test_activation_preserving_with_2_inputs(self, fw_info_mock):
+        """ Tests that quantization preserving is disabled for a node with more than one input activation, and for the quantization-preserving node that follows it. """
+        n1 = build_node('in1_node')
+        n2 = build_node('in2_node')
+        n3 = build_node('qp_node')
+        n4 = build_node('qp2_node')
+        graph = Graph('g', input_nodes=[n1, n2], nodes=[n3], output_nodes=[n4],
+                      edge_list=[Edge(n1, n3, 0, 0), Edge(n2, n3, 0, 1),
+                                 Edge(n3, n4, 0, 0)])
+
+        fqc = Mock(filterlayer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])},
+                   layer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])})
+        fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
+                            activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
+                            get_kernel_op_attributes=lambda x: [None])
+        set_quantization_configs_to_node(n3, graph, QuantizationConfig(), fw_info_mock, fqc)
+        set_quantization_configs_to_node(n4, graph, QuantizationConfig(), fw_info_mock, fqc)
+        assert not n3.is_quantization_preserving() and not n3.is_activation_quantization_enabled()
+        assert not n4.is_quantization_preserving() and not n4.is_activation_quantization_enabled()
+
diff --git a/tests_pytest/common_tests/unit_tests/core/graph/test_quantization_preserving_node.py b/tests_pytest/common_tests/unit_tests/core/graph/test_quantization_preserving_node.py
index ab6bc37a3..2e41dca0c 100644
--- a/tests_pytest/common_tests/unit_tests/core/graph/test_quantization_preserving_node.py
+++ b/tests_pytest/common_tests/unit_tests/core/graph/test_quantization_preserving_node.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import pytest
+
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.core.common.graph.edge import Edge
 
@@ -35,3 +37,16 @@ def test_activation_preserving_candidate(self):
         assert graph.retrieve_preserved_quantization_node(n3) is n1
         assert graph.retrieve_preserved_quantization_node(n4) is n4
         assert graph.retrieve_preserved_quantization_node(n5) is n4
+
+    def test_activation_preserving_disable_for_multi_input_node(self):
+        """ Tests that retrieve_preserved_quantization_node raises an assertion error if a node has more than one input. """
+        n1 = build_node('qact_node', qcs=[build_nbits_qc()])
+        n2 = build_node('qp1a_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
+        n3 = build_node('qact1b_node', qcs=[build_nbits_qc()])
+        n4 = build_node('qp2_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
+        graph = Graph('g', input_nodes=[n1], nodes=[n2, n3], output_nodes=[n4],
+                      edge_list=[Edge(n1, n2, 0, 0), Edge(n1, n3, 0, 0),
+                                 Edge(n2, n4, 0, 0), Edge(n3, n4, 0, 1)])
+
+        with pytest.raises(AssertionError, match="Activation preserving node should have only 1 input"):
+            graph.retrieve_preserved_quantization_node(n4)
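Illustration (not part of the patch): a minimal, self-contained sketch of the fallback rule that the set_node_quantization_config.py change enforces, using simplified stand-in Node/QuantMode classes rather than MCT's BaseNode and ActivationQuantizationMode. A PRESERVE_QUANT node keeps its status only when it has exactly one predecessor and that predecessor is itself quantization-preserving or activation-quantized; anything else falls back to NO_QUANT. Visiting nodes in topological order (the reason the patch switches from graph.nodes to graph.get_topo_sorted_nodes()) guarantees each predecessor's final mode is already known when a node is resolved.

# Simplified stand-ins for illustration only; names do not match MCT's API.
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List


class QuantMode(Enum):
    QUANT = auto()           # node quantizes its activation
    PRESERVE_QUANT = auto()  # node passes through its input's quantization
    NO_QUANT = auto()        # activation is left unquantized


@dataclass
class Node:
    name: str
    mode: QuantMode
    inputs: List["Node"] = field(default_factory=list)


def resolve_preserving(node: Node) -> QuantMode:
    """Return the node's final mode, disabling PRESERVE_QUANT when it cannot be honored."""
    if node.mode is not QuantMode.PRESERVE_QUANT:
        return node.mode
    if len(node.inputs) != 1:
        # Preserving the quantization of more than one input is ambiguous.
        return QuantMode.NO_QUANT
    prev = node.inputs[0]
    if prev.mode not in (QuantMode.QUANT, QuantMode.PRESERVE_QUANT):
        # The single predecessor carries no quantization to preserve.
        return QuantMode.NO_QUANT
    return QuantMode.PRESERVE_QUANT


# Nodes are resolved in topological order, so a preserving node that was
# already demoted to NO_QUANT also demotes the preserving node after it.
in1 = Node("in1", QuantMode.QUANT)
in2 = Node("in2", QuantMode.QUANT)
merge = Node("merge", QuantMode.PRESERVE_QUANT, inputs=[in1, in2])
merge.mode = resolve_preserving(merge)        # two inputs -> NO_QUANT
reshape = Node("reshape", QuantMode.PRESERVE_QUANT, inputs=[merge])
reshape.mode = resolve_preserving(reshape)    # unquantized predecessor -> NO_QUANT
assert merge.mode is QuantMode.NO_QUANT and reshape.mode is QuantMode.NO_QUANT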