Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def retrieve_preserved_quantization_node(self, node: BaseNode) -> BaseNode:
"""
while node.is_quantization_preserving():
prev_nodes = self.get_prev_nodes(node)
assert len(prev_nodes) == 1, "Activation preserving node should have only 1 input."
assert len(prev_nodes) == 1, f"Activation preserving node should have only 1 input, but node {node.name} has {len(prev_nodes)} inputs."
node = prev_nodes[0]
return node

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy as np
from pulp import *
from typing import Dict, Tuple, Any
from typing import Dict, Tuple, Any, List

from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def set_quantization_configuration_to_graph(graph: Graph,
nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)

for n in graph.nodes:
for n in graph.get_topo_sorted_nodes():
manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
set_quantization_configs_to_node(node=n,
Expand Down Expand Up @@ -199,6 +199,16 @@ def set_quantization_configs_to_node(node: BaseNode,
if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
not node.get_has_activation():
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
elif candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT:
prev_nodes = graph.get_prev_nodes(node)
if len(prev_nodes) != 1:
# Preserving the quantization of more than 1 previous node is ambiguous, so disable it.
Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than 1 input activations.")
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
elif not prev_nodes[0].is_quantization_preserving() or not prev_nodes[0].is_activation_quantization_enabled():
# Preserving the quantization of an unquantized node isn't possible, so disable it.
Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.")
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT


def create_node_activation_qc(qc: QuantizationConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ def test_invalid_bit_width_selection(self):
ManualBitWidthSelectionTest(self, NodeTypeFilter(layers.Add), 3).run_test()
# Check that the correct exception message was raised
self.assertEqual(str(context.exception),
"Manually selected activation bit-width 3 is invalid for node Add:add2.")
"Manually selected activation bit-width 3 is invalid for node Add:add1.")

with self.assertRaises(Exception) as context:
ManualBitWidthSelectionTest(self, NodeNameFilter('relu1'), 3).run_test()
Expand All @@ -880,7 +880,7 @@ def test_invalid_bit_width_selection(self):

def test_mul_16_bit_manual_selection(self):
"""
This test checks the execptions in the manual bit-width selection feature.
This test checks the exceptions in the manual bit-width selection feature.
"""
# This "mul" can be configured to 16 bit
Manual16BitWidthSelectionTest(self, NodeNameFilter('mul1'), 16).run_test()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.graph.edge import Edge

from unittest.mock import Mock
from tests_pytest._test_util.graph_builder_utils import build_node, build_nbits_qc, DummyLayer
from model_compression_toolkit.core import FrameworkInfo
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configs_to_node
from model_compression_toolkit.core import QuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
OpQuantizationConfig, AttributeQuantizationConfig, Signedness
from mct_quantizers import QuantizationMethod


class TestSetNodeQuantizationConfig:
    """Tests for set_quantization_configs_to_node handling of quantization-preserving nodes."""

    @staticmethod
    def _get_op_config():
        # Op config with activation quantization disabled but quantization_preserving=True,
        # so nodes built from it start out in quantization-preserving mode.
        aqc = AttributeQuantizationConfig()
        return OpQuantizationConfig(default_weight_attr_config=aqc,
                                    attr_weights_configs_mapping={'w': aqc},
                                    activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
                                    activation_n_bits=7,
                                    supported_input_activation_n_bits=7,
                                    enable_activation_quantization=False,
                                    quantization_preserving=True,
                                    signedness=Signedness.AUTO)

    def test_activation_preserving_with_2_inputs(self, fw_info_mock):
        """ Tests that quantization preserving is disabled for a node with two input
        activations, and that the subsequent preserving node is disabled as well. """
        n1 = build_node('in1_node')
        n2 = build_node('in2_node')
        n3 = build_node('qp_node')
        n4 = build_node('qp2_node')
        # n3 receives two input activations (from n1 and n2); n4 follows n3.
        graph = Graph('g', input_nodes=[n1, n2], nodes=[n3], output_nodes=[n4],
                      edge_list=[Edge(n1, n3, 0, 0), Edge(n2, n3, 0, 0),
                                 Edge(n3, n4, 0, 0)])

        fqc = Mock(filterlayer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])},
                   layer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])})
        # NOTE(review): this rebinding shadows the fw_info_mock fixture parameter with a
        # local Mock — presumably intentional; confirm the fixture argument is not needed.
        fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
                            activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
                            get_kernel_op_attributes=lambda x: [None])
        set_quantization_configs_to_node(n3, graph, QuantizationConfig(), fw_info_mock, fqc)
        set_quantization_configs_to_node(n4, graph, QuantizationConfig(), fw_info_mock, fqc)
        # n3 (two inputs) must end up neither preserving nor quantized; n4 then has an
        # unquantized predecessor, so it must be disabled as well.
        assert not n3.is_quantization_preserving() and not n3.is_activation_quantization_enabled()
        assert not n4.is_quantization_preserving() and not n4.is_activation_quantization_enabled()

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest

from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.graph.edge import Edge

Expand All @@ -35,3 +37,16 @@ def test_activation_preserving_candidate(self):
assert graph.retrieve_preserved_quantization_node(n3) is n1
assert graph.retrieve_preserved_quantization_node(n4) is n4
assert graph.retrieve_preserved_quantization_node(n5) is n4

def test_activation_preserving_disable_for_multi_input_node(self):
""" Tests that the retrieve_preserved_quantization_node raises an assertion error if node has more than 1 input. """
n1 = build_node('qact_node', qcs=[build_nbits_qc()])
n2 = build_node('qp1a_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
n3 = build_node('qact1b_node', qcs=[build_nbits_qc()])
n4 = build_node('qp2_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
graph = Graph('g', input_nodes=[n1], nodes=[n2, n3], output_nodes=[n4],
edge_list=[Edge(n1, n2, 0, 0), Edge(n1, n3, 0, 0),
Edge(n2, n4, 0, 0), Edge(n2, n4, 0, 0)])

with pytest.raises(AssertionError, match="Activation preserving node should have only 1 input"):
graph.retrieve_preserved_quantization_node(n4)