Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def retrieve_preserved_quantization_node(self, node: BaseNode) -> BaseNode:
"""
while node.is_quantization_preserving():
prev_nodes = self.get_prev_nodes(node)
assert len(prev_nodes) == 1, "Activation preserving node should have only 1 input."
assert len(prev_nodes) == 1, f"Activation preserving node should have only 1 input, but node {node.name} has {len(prev_nodes)} inputs."
node = prev_nodes[0]
return node

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy as np
from pulp import *
from typing import Dict, Tuple, Any
from typing import Dict, Tuple, Any, List

from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def set_quantization_configuration_to_graph(graph: Graph,
nodes_to_manipulate_activation_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_activation_bit_widths(graph)
nodes_to_manipulate_weights_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_weights_bit_widths(graph)

for n in graph.nodes:
for n in graph.get_topo_sorted_nodes():
manual_bit_width_override = {ACTIVATION: nodes_to_manipulate_activation_bit_widths.get(n),
WEIGHTS: nodes_to_manipulate_weights_bit_widths.get(n)}
set_quantization_configs_to_node(node=n,
Expand Down Expand Up @@ -199,6 +199,16 @@ def set_quantization_configs_to_node(node: BaseNode,
if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
not node.get_has_activation():
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
elif candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT:
prev_nodes = graph.get_prev_nodes(node)
if len(prev_nodes) != 1:
# Preserving the quantization of more than 1 previous node is ambiguous, so disable it.
Logger.info(f"Disabling Quantization-Preserving for node {node.name} because it has more than 1 input activations.")
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
elif not prev_nodes[0].is_quantization_preserving() or not prev_nodes[0].is_activation_quantization_enabled():
# Preserving the quantization of an unquantized node isn't possible, so disable it.
Logger.info(f"Disabling Quantization-Preserving for node {node.name} because previous node activation quantization is disabled.")
candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT


def create_node_activation_qc(qc: QuantizationConfig,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.graph.edge import Edge

from unittest.mock import Mock
from tests_pytest._test_util.graph_builder_utils import build_node, build_nbits_qc, DummyLayer
from model_compression_toolkit.core import FrameworkInfo
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configs_to_node
from model_compression_toolkit.core import QuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
OpQuantizationConfig, AttributeQuantizationConfig, Signedness
from mct_quantizers import QuantizationMethod


class TestSetNodeQuantizationConfig:
    """Tests for set_quantization_configs_to_node handling of quantization-preserving
    activation nodes (PRESERVE_QUANT mode)."""

    @staticmethod
    def _get_op_config():
        """Build an OpQuantizationConfig with activation quantization disabled and
        quantization-preserving enabled (7-bit power-of-two activation settings)."""
        aqc = AttributeQuantizationConfig()
        return OpQuantizationConfig(default_weight_attr_config=aqc,
                                    attr_weights_configs_mapping={'w': aqc},
                                    activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
                                    activation_n_bits=7,
                                    supported_input_activation_n_bits=7,
                                    enable_activation_quantization=False,
                                    quantization_preserving=True,
                                    signedness=Signedness.AUTO)

    def test_activation_preserving_with_2_inputs(self, fw_info_mock):
        """Tests that quantization-preserving is disabled for a node with two input
        activations, and consequently also for its quantization-preserving successor."""
        # n3 has two incoming edges (from n1 and n2), so preserving a single previous
        # quantization is ambiguous and should be turned off by
        # set_quantization_configs_to_node; n4 then follows a non-preserving,
        # non-quantized node and should be turned off as well.
        n1 = build_node('in1_node')
        n2 = build_node('in2_node')
        n3 = build_node('qp_node')
        n4 = build_node('qp2_node')
        graph = Graph('g', input_nodes=[n1, n2], nodes=[n3], output_nodes=[n4],
                      edge_list=[Edge(n1, n3, 0, 0), Edge(n2, n3, 0, 0),
                                 Edge(n3, n4, 0, 0)])

        # Both filtered and plain layer->config mappings resolve DummyLayer to the
        # preserving op config built above.
        fqc = Mock(filterlayer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])},
                   layer2qco={DummyLayer: QuantizationConfigOptions(quantization_configurations=[self._get_op_config()])})
        # NOTE(review): the fw_info_mock fixture parameter is immediately shadowed by
        # this local Mock — confirm whether the fixture argument is actually needed.
        fw_info_mock = Mock(spec=FrameworkInfo, kernel_channels_mapping={DummyLayer: 0},
                            activation_quantizer_mapping={QuantizationMethod.POWER_OF_TWO: lambda x: 0},
                            get_kernel_op_attributes=lambda x: [None])
        set_quantization_configs_to_node(n3, graph, QuantizationConfig(), fw_info_mock, fqc)
        set_quantization_configs_to_node(n4, graph, QuantizationConfig(), fw_info_mock, fqc)
        # n3: multi-input -> preserving disabled; n4: its only input is no longer
        # preserving/quantized -> preserving disabled too.
        assert not n3.is_quantization_preserving() and not n3.is_activation_quantization_enabled()
        assert not n4.is_quantization_preserving() and not n4.is_activation_quantization_enabled()

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import pytest

from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.graph.edge import Edge

Expand All @@ -35,3 +37,16 @@ def test_activation_preserving_candidate(self):
assert graph.retrieve_preserved_quantization_node(n3) is n1
assert graph.retrieve_preserved_quantization_node(n4) is n4
assert graph.retrieve_preserved_quantization_node(n5) is n4

def test_activation_preserving_disable_for_multi_input_node(self):
    """Tests that retrieve_preserved_quantization_node raises an assertion error
    if a quantization-preserving node has more than 1 input."""
    n1 = build_node('qact_node', qcs=[build_nbits_qc()])
    n2 = build_node('qp1a_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
    n3 = build_node('qact1b_node', qcs=[build_nbits_qc()])
    n4 = build_node('qp2_node', qcs=[build_nbits_qc(a_enable=False, q_preserving=True)])
    # n4 is wired with two parallel edges, giving the preserving node two inputs.
    # NOTE(review): both edges into n4 originate from n2 — presumably one was meant
    # to come from n3; confirm the duplicate edge still yields two prev nodes.
    graph = Graph('g', input_nodes=[n1], nodes=[n2, n3], output_nodes=[n4],
                  edge_list=[Edge(n1, n2, 0, 0), Edge(n1, n3, 0, 0),
                             Edge(n2, n4, 0, 0), Edge(n2, n4, 0, 0)])

    with pytest.raises(AssertionError, match="Activation preserving node should have only 1 input"):
        graph.retrieve_preserved_quantization_node(n4)
Loading