Skip to content

Commit 0dbb760

Browse files
authored
Handle quantization preserving with 2 inputs (#1421)
Handle quantization preserving nodes that have more than 1 input, or come after an unquantized node. For these nodes, quantization preserving is disabled.
1 parent 3e78fce commit 0dbb760

File tree

6 files changed

+91
-5
lines changed

6 files changed

+91
-5
lines changed

model_compression_toolkit/core/common/graph/base_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ def retrieve_preserved_quantization_node(self, node: BaseNode) -> BaseNode:
754754
"""
755755
while node.is_quantization_preserving():
756756
prev_nodes = self.get_prev_nodes(node)
757-
assert len(prev_nodes) == 1, "Activation preserving node should have only 1 input."
757+
assert len(prev_nodes) == 1, f"Activation preserving node should have only 1 input, but node {node.name} has {len(prev_nodes)} inputs."
758758
node = prev_nodes[0]
759759
return node
760760

model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import numpy as np
1818
from pulp import *
19-
from typing import Dict, Tuple, Any
19+
from typing import Dict, Tuple, Any, List
2020

2121
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget
2222

model_compression_toolkit/core/common/quantization/set_node_quantization_config.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def set_quantization_configuration_to_graph(graph: Graph,
6767
nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
6868
nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)
6969

70-
for n in graph.nodes:
70+
for n in graph.get_topo_sorted_nodes():
7171
manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
7272
WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
7373
set_quantization_configs_to_node(node=n,
@@ -199,6 +199,16 @@ def set_quantization_configs_to_node(node: BaseNode,
199199
if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
200200
not node.get_has_activation():
201201
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
202+
elif candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT:
203+
prev_nodes = graph.get_prev_nodes(node)
204+
if len(prev_nodes) != 1:
205+
# Preserving the quantization of more than 1 previous node is ambiguous, so disable it.
206+
Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than 1 input activation.")
207+
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
208+
elif not prev_nodes[0].is_quantization_preserving() or not prev_nodes[0].is_activation_quantization_enabled():
209+
# Preserving the quantization of an unquantized node isn't possible, so disable it.
210+
Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.")
211+
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
202212

203213

204214
def create_node_activation_qc(qc: QuantizationConfig,

tests/keras_tests/feature_networks_tests/test_features_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ def test_invalid_bit_width_selection(self):
870870
ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 3).run_test()
871871
# Check that the correct exception message was raised
872872
self.assertEqual(str(context.exception),
873-
"Manually selected activation bit-width 3 is invalid for node Add:add2.")
873+
"Manually selected activation bit-width 3 is invalid for node Add:add1.")
874874

875875
with self.assertRaises(Exception) as context:
876876
ManualBitWidthSelectionTest(self, NodeNameFilter('relu1'), 3).run_test()
@@ -880,7 +880,7 @@ def test_invalid_bit_width_selection(self):
880880

881881
def test_mul_16_bit_manual_selection(self):
882882
"""
883-
This test checks the execptions in the manual bit-width selection feature.
883+
This test checks the exceptions in the manual bit-width selection feature.
884884
"""
885885
# This "mul" can be configured to 16 bit
886886
Manual16BitWidthSelectionTest(self, NodeNameFilter('mul1'), 16).run_test()
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from model_compression_toolkit.core.common import Graph
16+
from model_compression_toolkit.core.common.graph.edge import Edge
17+
18+
from unittest.mock import Mock
19+
from tests_pytest._test_util.graph_builder_utils import build_node, build_nbits_qc, DummyLayer
20+
from model_compression_toolkit.core import FrameworkInfo
21+
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configs_to_node
22+
from model_compression_toolkit.core import QuantizationConfig
23+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
24+
OpQuantizationConfig, AttributeQuantizationConfig, Signedness
25+
from mct_quantizers import QuantizationMethod
26+
27+
28+
class TestSetNodeQuantizationConfig:
29+
30+
@staticmethod
31+
def _get_op_config():
32+
aqc = AttributeQuantizationConfig()
33+
return OpQuantizationConfig(default_weight_attr_config=aqc,
34+
attr_weights_configs_mapping={'w': aqc},
35+
activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
36+
activation_n_bits=7,
37+
supported_input_activation_n_bits=7,
38+
enable_activation_quantization=False,
39+
quantization_preserving=True,
40+
signedness=Signedness.AUTO)
41+
42+
def test_activation_preserving_with_2_inputs(self, fw_info_mock):
43+
""" Tests that quantization preserving is disabled for a node with 2 input activations, and for a preserving node that follows it. """
44+
n1 = build_node('in1_node')
45+
n2 = build_node('in2_node')
46+
n3 = build_node('qp_node')
47+
n4 = build_node('qp2_node')
48+
graph = Graph('g', input_nodes=[n1, n2], nodes=[n3], output_nodes=[n4],
49+
edge_list=[Edge(n1, n3, 0, 0), Edge(n2, n3, 0, 0),
50+
Edge(n3, n4, 0, 0)])
51+
52+
fqc = Mock(filterlayer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])},
53+
layer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])})
54+
fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
55+
activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
56+
get_kernel_op_attributes=lambda x: [None])
57+
set_quantization_configs_to_node(n3, graph, QuantizationConfig(), fw_info_mock, fqc)
58+
set_quantization_configs_to_node(n4, graph, QuantizationConfig(), fw_info_mock, fqc)
59+
assert not n3.is_quantization_preserving() and not n3.is_activation_quantization_enabled()
60+
assert not n4.is_quantization_preserving() and not n4.is_activation_quantization_enabled()
61+

tests_pytest/common_tests/unit_tests/core/graph/test_quantization_preserving_node.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
import pytest
16+
1517
from model_compression_toolkit.core.common import Graph
1618
from model_compression_toolkit.core.common.graph.edge import Edge
1719

@@ -35,3 +37,16 @@ def test_activation_preserving_candidate(self):
3537
assert graph.retrieve_preserved_quantization_node(n3) is n1
3638
assert graph.retrieve_preserved_quantization_node(n4) is n4
3739
assert graph.retrieve_preserved_quantization_node(n5) is n4
40+
41+
def test_activation_preserving_disable_for_multi_input_node(self):
42+
""" Tests that retrieve_preserved_quantization_node raises an assertion error if a node has more than 1 input. """
43+
n1 = build_node('qact_node', qcs=[build_nbits_qc()])
44+
n2 = build_node('qp1a_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
45+
n3 = build_node('qact1b_node', qcs=[build_nbits_qc()])
46+
n4 = build_node('qp2_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
47+
graph = Graph('g', input_nodes=[n1], nodes=[n2, n3], output_nodes=[n4],
48+
edge_list=[Edge(n1, n2, 0, 0), Edge(n1, n3, 0, 0),
49+
Edge(n2, n4, 0, 0), Edge(n3, n4, 0, 0)])
50+
51+
with pytest.raises(AssertionError, match="Activation preserving node should have only 1 input"):
52+
graph.retrieve_preserved_quantization_node(n4)

0 commit comments

Comments
 (0)