From aace81e7d42f3123e8b6fbcf5f7f33ebb2046512 Mon Sep 17 00:00:00 2001 From: gouda-youichi <123340875+gouda-youichi@users.noreply.github.com> Date: Wed, 2 Jul 2025 09:26:43 +0900 Subject: [PATCH 1/4] Apply activation quantization parameters selection (#16) Apply activation quantization parameters selection (#16) --- .../quantization/node_quantization_config.py | 9 +- .../qparams_computation.py | 4 +- .../test_calculate_quantization_params.py | 232 ++++++++++++++++++ 3 files changed, 241 insertions(+), 4 deletions(-) create mode 100755 tests_pytest/pytorch_tests/unit_tests/core/common/quantization/quantization_params_generation/test_calculate_quantization_params.py diff --git a/model_compression_toolkit/core/common/quantization/node_quantization_config.py b/model_compression_toolkit/core/common/quantization/node_quantization_config.py index c39223f01..6eda23d55 100644 --- a/model_compression_toolkit/core/common/quantization/node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/node_quantization_config.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from __future__ import annotations + from typing import Any, List, Dict, TYPE_CHECKING from enum import Enum, auto @@ -24,6 +26,7 @@ AttributeQuantizationConfig, OpQuantizationConfig if TYPE_CHECKING: + from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.core.common.graph.base_node import WeightAttrT ########################################## @@ -140,15 +143,17 @@ def fln_quantization(self): return self.quant_mode == ActivationQuantizationMode.FLN_QUANT def set_activation_quantization_param(self, - activation_params: dict): + activation_params: dict, + node: BaseNode): """ Set a quantization parameter for the node's activation. 
Args: activation_params: Dictionary that contains weight quantization params. + node: node in a graph that represents the model. """ - assert self.quant_mode == ActivationQuantizationMode.QUANT + assert node.is_activation_quantization_enabled() or node.is_fln_quantization() for param_name, param_value in activation_params.items(): self.activation_quantization_params[param_name] = param_value diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py index f97337b5d..c3d49d602 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py @@ -128,10 +128,10 @@ def calculate_quantization_params(graph: Graph, attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input) attr_cfg.set_weights_quantization_param(weights_params) - if n.is_activation_quantization_enabled(): + if n.is_activation_quantization_enabled() or n.is_fln_quantization(): # If node's activations should be quantized as well, we compute its activation quantization parameters activation_params = compute_activation_qparams( activation_quant_cfg=candidate_qc.activation_quantization_cfg, node_prior_info=n.prior_info, out_stats_container=graph.get_out_stats_collector(n)) # Create a NodeQuantizationConfig containing all quantization params and attach it to the node - candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params) + candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params, n) diff --git a/tests_pytest/pytorch_tests/unit_tests/core/common/quantization/quantization_params_generation/test_calculate_quantization_params.py 
b/tests_pytest/pytorch_tests/unit_tests/core/common/quantization/quantization_params_generation/test_calculate_quantization_params.py new file mode 100755 index 000000000..7ef44c599 --- /dev/null +++ b/tests_pytest/pytorch_tests/unit_tests/core/common/quantization/quantization_params_generation/test_calculate_quantization_params.py @@ -0,0 +1,232 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import pytest +import numpy as np +import torch +from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \ + calculate_quantization_params +from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \ + CandidateNodeQuantizationConfig, NodeQuantizationConfig +from model_compression_toolkit.core.common.quantization.node_quantization_config import \ + ActivationQuantizationMode, NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig +from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig +from model_compression_toolkit.core import QuantizationConfig, QuantizationErrorMethod +from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService +from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \ + 
AttachTpcToPytorch +import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema +from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness, \ + AttributeQuantizationConfig +from model_compression_toolkit.core.pytorch.default_framework_info import PyTorchInfo +from model_compression_toolkit.core.common.framework_info import set_fw_info, get_fw_info + +from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation +from model_compression_toolkit.core.common.collectors.statistics_collector import StatsCollector +from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, WEIGHTS_N_BITS +from mct_quantizers import QuantizationMethod + +from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping + +class TestCalculateQuantizationParams: + def get_op_qco(self): + # define a default quantization config for all non-specified weights attributes. + default_weight_attr_config = AttributeQuantizationConfig() + + # define a quantization config to quantize the kernel (for layers where there is a kernel attribute). 
+ kernel_base_config = AttributeQuantizationConfig( + weights_n_bits=8, + weights_per_channel_threshold=True, + enable_weights_quantization=True) + + base_cfg = schema.OpQuantizationConfig( + default_weight_attr_config=default_weight_attr_config, + attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config}, + activation_quantization_method=QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + supported_input_activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=False, + signedness=Signedness.AUTO) + + default_config = schema.OpQuantizationConfig( + default_weight_attr_config=default_weight_attr_config, + attr_weights_configs_mapping={}, + activation_quantization_method=QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + supported_input_activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=False, + signedness=Signedness.AUTO + ) + + mx_cfg_list = [base_cfg] + for n in [8, 4, 2]: + mx_cfg_list.append(base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: n}})) + + return base_cfg, mx_cfg_list, default_config + + def generate_tpc_local(self, default_config, base_config, mixed_precision_cfg_list): + default_configuration_options = schema.QuantizationConfigOptions( + quantization_configurations=tuple([default_config])) + mixed_precision_configuration_options = schema.QuantizationConfigOptions( + quantization_configurations=tuple(mixed_precision_cfg_list), + base_config=base_config) + + operator_set = [] + + conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV, qc_options=mixed_precision_configuration_options) + relu = schema.OperatorsSet(name=schema.OperatorSetNames.RELU) + operator_set.extend([conv, relu]) + + generated_tpc = schema.TargetPlatformCapabilities( + default_qco=default_configuration_options, + operator_set=tuple(operator_set)) + + return generated_tpc + + def get_tpc(self): + base_cfg, mx_cfg_list, default_config = self.get_op_qco() + tpc = 
self.generate_tpc_local(default_config, base_cfg, mx_cfg_list) + return tpc + + def get_float_model(self): + class BaseModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) + self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) + self.conv3 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) + self.relu = torch.nn.ReLU() + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.relu(x) + x = self.conv3(x) + return x + + return BaseModel() + + def _create_weights_attr_quantization_config(self, weights_n_bits: int) -> AttributeQuantizationConfig: + weights_attr_config = AttributeQuantizationConfig(weights_n_bits=weights_n_bits) + return weights_attr_config + + def _create_node_weights_op_cfg(self, + def_weight_attr_config: AttributeQuantizationConfig) -> OpQuantizationConfig: + # define a quantization config to quantize the kernel (for layers where there is a kernel attribute). + kernel_base_config = AttributeQuantizationConfig( + weights_quantization_method=QuantizationMethod.SYMMETRIC, + enable_weights_quantization=False, + weights_n_bits=8) + + # define a quantization config to quantize the bias (for layers where there is a bias attribute). 
+ bias_config = AttributeQuantizationConfig() + + attr_weights_configs_mapping = {'weight': kernel_base_config, 'bias': bias_config} + op_cfg = OpQuantizationConfig( + default_weight_attr_config=def_weight_attr_config, + attr_weights_configs_mapping=attr_weights_configs_mapping, + activation_quantization_method=QuantizationMethod.POWER_OF_TWO, + activation_n_bits=8, + supported_input_activation_n_bits=8, + enable_activation_quantization=True, + quantization_preserving=False, + signedness=Signedness.AUTO + ) + + return op_cfg + + def get_test_graph(self, qem: QuantizationErrorMethod): + float_model = self.get_float_model() + set_fw_info(PyTorchInfo) + + fw_impl = PytorchImplementation() + graph = fw_impl.model_reader(float_model, + self.representative_data_gen) + + quantization_config = QuantizationConfig(weights_error_method=qem) + + tpc = self.get_tpc() + attach2pytorch = AttachTpcToPytorch() + fqc = attach2pytorch.attach( + tpc, quantization_config.custom_tpc_opset_to_layer) + graph.set_fqc(fqc) + + def_weight_attr_config = self._create_weights_attr_quantization_config(weights_n_bits=8) + op_cfg = self._create_node_weights_op_cfg(def_weight_attr_config=def_weight_attr_config) + + graph.node_to_out_stats_collector = dict() + for id, n in enumerate(graph.nodes): + n.prior_info = fw_impl.get_node_prior_info(node=n, graph=graph) + + activation_quantization_cfg = NodeActivationQuantizationConfig(op_cfg=op_cfg) + activation_quantization_cfg.set_qc(quantization_config) + weights_quantization_cfg = NodeWeightsQuantizationConfig(op_cfg=op_cfg, + weights_channels_axis=ChannelAxisMapping(0, 1), + node_attrs_list=['weight', 'bias']) + candidate_qc_a = CandidateNodeQuantizationConfig( + activation_quantization_cfg=activation_quantization_cfg, + weights_quantization_cfg=weights_quantization_cfg) + if n.name in ['conv3']: + candidate_qc_a.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT + else: + candidate_qc_a.activation_quantization_cfg.quant_mode 
= ActivationQuantizationMode.QUANT + n.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=candidate_qc_a, candidates_quantization_cfg=[candidate_qc_a, candidate_qc_a]) + + graph.node_to_out_stats_collector[n] = StatsCollector(init_min_value=0.0, init_max_value=1.0, out_channel_axis=get_fw_info().out_channel_axis_mapping.get(n.type)) + graph.node_to_out_stats_collector[n].hc._n_bins = 3 + if n.name in ['conv1']: + graph.node_to_out_stats_collector[n].hc._bins = np.array([0.4, 0.8, 1.2]) + elif n.name in ['conv2']: + graph.node_to_out_stats_collector[n].hc._bins = np.array([0.7, 1.4, 2.1]) + elif n.name in ['conv3']: + graph.node_to_out_stats_collector[n].hc._bins = np.array([-32, -24, -1]) + elif n.name in ['relu']: + graph.node_to_out_stats_collector[n].hc._bins = np.array([2.0, 4.0, 6.0]) + else: + graph.node_to_out_stats_collector[n].hc._bins = np.array([0.1, 0.2, 0.3]) + graph.node_to_out_stats_collector[n].hc._counts = np.array([1, 1]) + + return graph, fw_impl + + def representative_data_gen(self, shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=10): + for _ in range(num_iter): + yield [torch.randn(batch_size, *shape)] * num_inputs + + def test_calculate_quantization_params(self): + graph, fw_impl = self.get_test_graph(QuantizationErrorMethod.MSE) + + calculate_quantization_params(graph, fw_impl, self.representative_data_gen) + + for node in graph.nodes: + for candidate_qc in node.candidates_quantization_cfg: + assert type(candidate_qc.activation_quantization_cfg.activation_quantization_params) == dict + assert 'threshold' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() + assert 'is_signed' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() + + threshold = candidate_qc.activation_quantization_cfg.activation_quantization_params['threshold'] + is_signed = candidate_qc.activation_quantization_cfg.activation_quantization_params['is_signed'] + if node.name in 'conv1': + assert 
threshold == 1.0 + assert is_signed == False + elif node.name in 'conv2': + assert threshold == 2.0 + assert is_signed == False + elif node.name in 'conv3': + assert threshold == 64.0 + assert is_signed == True + elif node.name in 'relu': + assert threshold == 16.0 + assert is_signed == False From 4230a8b1f1f26770bfdaca6a558ef3351877981d Mon Sep 17 00:00:00 2001 From: gouda-youichi <123340875+gouda-youichi@users.noreply.github.com> Date: Mon, 7 Jul 2025 12:07:30 +0900 Subject: [PATCH 2/4] Apply activation quantization parameters selection(3rd PR internal review) (#19) Fixed PR comments. - simplified a test. - removed unnecessary codes. - reverted about set_activation_quantization_param assert. --- .../quantization/node_quantization_config.py | 9 +- .../qparams_computation.py | 2 +- .../test_calculate_quantization_params.py | 124 ++++++++++ .../test_calculate_quantization_params.py | 232 ------------------ 4 files changed, 127 insertions(+), 240 deletions(-) create mode 100644 tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py delete mode 100755 tests_pytest/pytorch_tests/unit_tests/core/common/quantization/quantization_params_generation/test_calculate_quantization_params.py diff --git a/model_compression_toolkit/core/common/quantization/node_quantization_config.py b/model_compression_toolkit/core/common/quantization/node_quantization_config.py index 920e4e018..e1df6e4e8 100644 --- a/model_compression_toolkit/core/common/quantization/node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/node_quantization_config.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -from __future__ import annotations - from typing import Any, List, Dict, TYPE_CHECKING from enum import Enum, auto @@ -25,7 +23,6 @@ AttributeQuantizationConfig, OpQuantizationConfig if TYPE_CHECKING: - from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.core.common.graph.base_node import WeightAttrT ########################################## @@ -114,17 +111,15 @@ def fln_quantization(self): return self.quant_mode == ActivationQuantizationMode.FLN_QUANT def set_activation_quantization_param(self, - activation_params: dict, - node: BaseNode): + activation_params: dict): """ Set a quantization parameter for the node's activation. Args: activation_params: Dictionary that contains weight quantization params. - node: node in a graph that represents the model. """ - assert node.is_activation_quantization_enabled() or node.is_fln_quantization() + assert self.quant_mode == ActivationQuantizationMode.QUANT or self.quant_mode == ActivationQuantizationMode.FLN_QUANT for param_name, param_value in activation_params.items(): self.activation_quantization_params[param_name] = param_value diff --git a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py index 4eb8baa5e..3ad6132d0 100644 --- a/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +++ b/model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py @@ -112,4 +112,4 @@ def calculate_quantization_params(graph: Graph, node_prior_info=n.prior_info, out_stats_container=graph.get_out_stats_collector(n)) # Create a NodeQuantizationConfig containing all quantization params and attach it to the node - 
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params, n) + candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params) diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py new file mode 100644 index 000000000..3ea94fd60 --- /dev/null +++ b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py @@ -0,0 +1,124 @@ +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import pytest +import numpy as np + +from typing import Generator +from unittest.mock import Mock +from model_compression_toolkit.core.common import Graph, BaseNode +from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \ + calculate_quantization_params +from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \ + CandidateNodeQuantizationConfig +from model_compression_toolkit.core.common.quantization.node_quantization_config import \ + ActivationQuantizationMode, NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig +from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig +from model_compression_toolkit.core import QuantizationConfig +from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness +from model_compression_toolkit.core.common.collectors.statistics_collector import StatsCollector +from mct_quantizers import QuantizationMethod +from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo +from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation + + +class TestCalculateQuantizationParams: + def build_op_cfg(self): + op_cfg = Mock(spec=OpQuantizationConfig) + op_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO + op_cfg.activation_n_bits = 16 + op_cfg.enable_activation_quantization = True + op_cfg.quantization_preserving = False + op_cfg.signedness = Signedness.AUTO + + return op_cfg + + def build_node(self, name='node', q_mode=ActivationQuantizationMode.QUANT): + node = Mock(spec=BaseNode) + node.name = name + node.get_node_weights_attributes.return_value = [] + + if q_mode == ActivationQuantizationMode.QUANT: + node.is_activation_quantization_enabled.return_value = True + node.is_fln_quantization.return_value = 
False + elif q_mode == ActivationQuantizationMode.FLN_QUANT: + node.is_activation_quantization_enabled.return_value = False + node.is_fln_quantization.return_value = True + else: + node.is_activation_quantization_enabled.return_value = False + node.is_fln_quantization.return_value = False + + activation_quantization_cfg = NodeActivationQuantizationConfig(op_cfg=self.build_op_cfg()) + activation_quantization_cfg.set_qc(QuantizationConfig()) + activation_quantization_cfg.quant_mode = q_mode + + candidate_quantization_config = Mock(spec=CandidateNodeQuantizationConfig) + candidate_quantization_config.activation_quantization_cfg = activation_quantization_cfg + candidate_quantization_config.weights_quantization_cfg = Mock(spec=NodeWeightsQuantizationConfig) + + node.candidates_quantization_cfg = [candidate_quantization_config] + + return node + + def get_test_graph(self, node_name, q_mode, data): + node = self.build_node(node_name, q_mode=q_mode) + graph = Graph('graph_name', input_nodes=[node], nodes=[node], output_nodes=[node], edge_list=[]) + + graph.node_to_out_stats_collector = dict() + for n in graph.nodes(): + n.prior_info = NodePriorInfo() + + graph.node_to_out_stats_collector[n] = StatsCollector(init_min_value=0.0, init_max_value=1.0, out_channel_axis=0) + graph.node_to_out_stats_collector[n].hc._n_bins = 3 + graph.node_to_out_stats_collector[n].hc._bins = np.array(data) + graph.node_to_out_stats_collector[n].hc._counts = np.array([1, 1]) + + return graph + + ### test pattern for ActivationQuantizationMode + @pytest.mark.parametrize(["node_name", "q_mode", "input_data", "expects"], [ + # node_name, q_mode, input data, expected value + ['node_quant', ActivationQuantizationMode.QUANT, [0.4, 0.8, 1.2], [1.0, False]], + ['node_fln_quant', ActivationQuantizationMode.FLN_QUANT, [0.7, 1.4, 2.1], [2.0, False]], + ['node_fln_no_quant', ActivationQuantizationMode.FLN_NO_QUANT, [0.7, 1.4, 2.1], [None, None]], + ['node_no_quant', ActivationQuantizationMode.NO_QUANT, [0.7, 
1.4, 2.1], [None, None]], + ['node_preserve_quant', ActivationQuantizationMode.PRESERVE_QUANT, [0.7, 1.4, 2.1], [None, None]], + ]) + def test_calculate_quantization_params_for_activation(self, node_name, q_mode, input_data, expects, mocker): + """ + Tests that calculate quantization params for activation quantization method. + """ + graph = self.get_test_graph(node_name, q_mode, input_data) + + mocker.patch( + 'model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation._collect_nodes_for_hmse', + return_value=[]) + + calculate_quantization_params(graph, Mock(spec=FrameworkImplementation), Mock(spec=Generator)) + + node = list(graph.nodes)[0] + for candidate_qc in node.candidates_quantization_cfg: + assert type(candidate_qc.activation_quantization_cfg.activation_quantization_params) == dict + if expects[0] is not None: + ### QUANT or FLN_QUANT + assert 'threshold' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() + assert 'is_signed' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() + + threshold = candidate_qc.activation_quantization_cfg.activation_quantization_params['threshold'] + is_signed = candidate_qc.activation_quantization_cfg.activation_quantization_params['is_signed'] + assert threshold == expects[0] + assert is_signed == expects[1] + else: + assert 'threshold' not in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() + assert 'is_signed' not in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() diff --git a/tests_pytest/pytorch_tests/unit_tests/core/common/quantization/quantization_params_generation/test_calculate_quantization_params.py b/tests_pytest/pytorch_tests/unit_tests/core/common/quantization/quantization_params_generation/test_calculate_quantization_params.py deleted file mode 100755 index 7ef44c599..000000000 --- 
a/tests_pytest/pytorch_tests/unit_tests/core/common/quantization/quantization_params_generation/test_calculate_quantization_params.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import pytest -import numpy as np -import torch -from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \ - calculate_quantization_params -from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \ - CandidateNodeQuantizationConfig, NodeQuantizationConfig -from model_compression_toolkit.core.common.quantization.node_quantization_config import \ - ActivationQuantizationMode, NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig -from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig -from model_compression_toolkit.core import QuantizationConfig, QuantizationErrorMethod -from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService -from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \ - AttachTpcToPytorch -import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema -from 
model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness, \ - AttributeQuantizationConfig -from model_compression_toolkit.core.pytorch.default_framework_info import PyTorchInfo -from model_compression_toolkit.core.common.framework_info import set_fw_info, get_fw_info - -from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation -from model_compression_toolkit.core.common.collectors.statistics_collector import StatsCollector -from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, WEIGHTS_N_BITS -from mct_quantizers import QuantizationMethod - -from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping - -class TestCalculateQuantizationParams: - def get_op_qco(self): - # define a default quantization config for all non-specified weights attributes. - default_weight_attr_config = AttributeQuantizationConfig() - - # define a quantization config to quantize the kernel (for layers where there is a kernel attribute). 
- kernel_base_config = AttributeQuantizationConfig( - weights_n_bits=8, - weights_per_channel_threshold=True, - enable_weights_quantization=True) - - base_cfg = schema.OpQuantizationConfig( - default_weight_attr_config=default_weight_attr_config, - attr_weights_configs_mapping={KERNEL_ATTR: kernel_base_config}, - activation_quantization_method=QuantizationMethod.POWER_OF_TWO, - activation_n_bits=8, - supported_input_activation_n_bits=8, - enable_activation_quantization=True, - quantization_preserving=False, - signedness=Signedness.AUTO) - - default_config = schema.OpQuantizationConfig( - default_weight_attr_config=default_weight_attr_config, - attr_weights_configs_mapping={}, - activation_quantization_method=QuantizationMethod.POWER_OF_TWO, - activation_n_bits=8, - supported_input_activation_n_bits=8, - enable_activation_quantization=True, - quantization_preserving=False, - signedness=Signedness.AUTO - ) - - mx_cfg_list = [base_cfg] - for n in [8, 4, 2]: - mx_cfg_list.append(base_cfg.clone_and_edit(attr_to_edit={KERNEL_ATTR: {WEIGHTS_N_BITS: n}})) - - return base_cfg, mx_cfg_list, default_config - - def generate_tpc_local(self, default_config, base_config, mixed_precision_cfg_list): - default_configuration_options = schema.QuantizationConfigOptions( - quantization_configurations=tuple([default_config])) - mixed_precision_configuration_options = schema.QuantizationConfigOptions( - quantization_configurations=tuple(mixed_precision_cfg_list), - base_config=base_config) - - operator_set = [] - - conv = schema.OperatorsSet(name=schema.OperatorSetNames.CONV, qc_options=mixed_precision_configuration_options) - relu = schema.OperatorsSet(name=schema.OperatorSetNames.RELU) - operator_set.extend([conv, relu]) - - generated_tpc = schema.TargetPlatformCapabilities( - default_qco=default_configuration_options, - operator_set=tuple(operator_set)) - - return generated_tpc - - def get_tpc(self): - base_cfg, mx_cfg_list, default_config = self.get_op_qco() - tpc = 
self.generate_tpc_local(default_config, base_cfg, mx_cfg_list) - return tpc - - def get_float_model(self): - class BaseModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) - self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) - self.conv3 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3) - self.relu = torch.nn.ReLU() - - def forward(self, x): - x = self.conv1(x) - x = self.conv2(x) - x = self.relu(x) - x = self.conv3(x) - return x - - return BaseModel() - - def _create_weights_attr_quantization_config(self, weights_n_bits: int) -> AttributeQuantizationConfig: - weights_attr_config = AttributeQuantizationConfig(weights_n_bits=weights_n_bits) - return weights_attr_config - - def _create_node_weights_op_cfg(self, - def_weight_attr_config: AttributeQuantizationConfig) -> OpQuantizationConfig: - # define a quantization config to quantize the kernel (for layers where there is a kernel attribute). - kernel_base_config = AttributeQuantizationConfig( - weights_quantization_method=QuantizationMethod.SYMMETRIC, - enable_weights_quantization=False, - weights_n_bits=8) - - # define a quantization config to quantize the bias (for layers where there is a bias attribute). 
- bias_config = AttributeQuantizationConfig() - - attr_weights_configs_mapping = {'weight': kernel_base_config, 'bias': bias_config} - op_cfg = OpQuantizationConfig( - default_weight_attr_config=def_weight_attr_config, - attr_weights_configs_mapping=attr_weights_configs_mapping, - activation_quantization_method=QuantizationMethod.POWER_OF_TWO, - activation_n_bits=8, - supported_input_activation_n_bits=8, - enable_activation_quantization=True, - quantization_preserving=False, - signedness=Signedness.AUTO - ) - - return op_cfg - - def get_test_graph(self, qem: QuantizationErrorMethod): - float_model = self.get_float_model() - set_fw_info(PyTorchInfo) - - fw_impl = PytorchImplementation() - graph = fw_impl.model_reader(float_model, - self.representative_data_gen) - - quantization_config = QuantizationConfig(weights_error_method=qem) - - tpc = self.get_tpc() - attach2pytorch = AttachTpcToPytorch() - fqc = attach2pytorch.attach( - tpc, quantization_config.custom_tpc_opset_to_layer) - graph.set_fqc(fqc) - - def_weight_attr_config = self._create_weights_attr_quantization_config(weights_n_bits=8) - op_cfg = self._create_node_weights_op_cfg(def_weight_attr_config=def_weight_attr_config) - - graph.node_to_out_stats_collector = dict() - for id, n in enumerate(graph.nodes): - n.prior_info = fw_impl.get_node_prior_info(node=n, graph=graph) - - activation_quantization_cfg = NodeActivationQuantizationConfig(op_cfg=op_cfg) - activation_quantization_cfg.set_qc(quantization_config) - weights_quantization_cfg = NodeWeightsQuantizationConfig(op_cfg=op_cfg, - weights_channels_axis=ChannelAxisMapping(0, 1), - node_attrs_list=['weight', 'bias']) - candidate_qc_a = CandidateNodeQuantizationConfig( - activation_quantization_cfg=activation_quantization_cfg, - weights_quantization_cfg=weights_quantization_cfg) - if n.name in ['conv3']: - candidate_qc_a.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT - else: - candidate_qc_a.activation_quantization_cfg.quant_mode 
= ActivationQuantizationMode.QUANT - n.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=candidate_qc_a, candidates_quantization_cfg=[candidate_qc_a, candidate_qc_a]) - - graph.node_to_out_stats_collector[n] = StatsCollector(init_min_value=0.0, init_max_value=1.0, out_channel_axis=get_fw_info().out_channel_axis_mapping.get(n.type)) - graph.node_to_out_stats_collector[n].hc._n_bins = 3 - if n.name in ['conv1']: - graph.node_to_out_stats_collector[n].hc._bins = np.array([0.4, 0.8, 1.2]) - elif n.name in ['conv2']: - graph.node_to_out_stats_collector[n].hc._bins = np.array([0.7, 1.4, 2.1]) - elif n.name in ['conv3']: - graph.node_to_out_stats_collector[n].hc._bins = np.array([-32, -24, -1]) - elif n.name in ['relu']: - graph.node_to_out_stats_collector[n].hc._bins = np.array([2.0, 4.0, 6.0]) - else: - graph.node_to_out_stats_collector[n].hc._bins = np.array([0.1, 0.2, 0.3]) - graph.node_to_out_stats_collector[n].hc._counts = np.array([1, 1]) - - return graph, fw_impl - - def representative_data_gen(self, shape=(3, 8, 8), num_inputs=1, batch_size=2, num_iter=10): - for _ in range(num_iter): - yield [torch.randn(batch_size, *shape)] * num_inputs - - def test_calculate_quantization_params(self): - graph, fw_impl = self.get_test_graph(QuantizationErrorMethod.MSE) - - calculate_quantization_params(graph, fw_impl, self.representative_data_gen) - - for node in graph.nodes: - for candidate_qc in node.candidates_quantization_cfg: - assert type(candidate_qc.activation_quantization_cfg.activation_quantization_params) == dict - assert 'threshold' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() - assert 'is_signed' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() - - threshold = candidate_qc.activation_quantization_cfg.activation_quantization_params['threshold'] - is_signed = candidate_qc.activation_quantization_cfg.activation_quantization_params['is_signed'] - if node.name in 'conv1': - assert 
threshold == 1.0 - assert is_signed == False - elif node.name in 'conv2': - assert threshold == 2.0 - assert is_signed == False - elif node.name in 'conv3': - assert threshold == 64.0 - assert is_signed == True - elif node.name in 'relu': - assert threshold == 16.0 - assert is_signed == False From 07e0b0988d29a5653784e7e2703d86234dc236b7 Mon Sep 17 00:00:00 2001 From: gouda-youichi <123340875+gouda-youichi@users.noreply.github.com> Date: Mon, 7 Jul 2025 13:20:48 +0900 Subject: [PATCH 3/4] Apply activation quantization parameters selection (solved test failure.) (#20) solved test failure. --- .../test_calculate_quantization_params.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py index 3ea94fd60..d0a51b7f7 100644 --- a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py +++ b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py @@ -15,7 +15,6 @@ import pytest import numpy as np -from typing import Generator from unittest.mock import Mock from model_compression_toolkit.core.common import Graph, BaseNode from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \ @@ -30,7 +29,6 @@ from model_compression_toolkit.core.common.collectors.statistics_collector import StatsCollector from mct_quantizers import QuantizationMethod from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo -from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation class TestCalculateQuantizationParams: @@ -60,7 +58,6 @@ def build_node(self, name='node', 
q_mode=ActivationQuantizationMode.QUANT): node.is_fln_quantization.return_value = False activation_quantization_cfg = NodeActivationQuantizationConfig(op_cfg=self.build_op_cfg()) - activation_quantization_cfg.set_qc(QuantizationConfig()) activation_quantization_cfg.quant_mode = q_mode candidate_quantization_config = Mock(spec=CandidateNodeQuantizationConfig) @@ -84,7 +81,8 @@ def get_test_graph(self, node_name, q_mode, data): graph.node_to_out_stats_collector[n].hc._bins = np.array(data) graph.node_to_out_stats_collector[n].hc._counts = np.array([1, 1]) - return graph + quant_config = QuantizationConfig() + return graph, quant_config ### test pattern for ActivationQuantizationMode @pytest.mark.parametrize(["node_name", "q_mode", "input_data", "expects"], [ @@ -99,13 +97,9 @@ def test_calculate_quantization_params_for_activation(self, node_name, q_mode, i """ Tests that calculate quantization params for activation quantization method. """ - graph = self.get_test_graph(node_name, q_mode, input_data) + graph, quant_config = self.get_test_graph(node_name, q_mode, input_data) - mocker.patch( - 'model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation._collect_nodes_for_hmse', - return_value=[]) - - calculate_quantization_params(graph, Mock(spec=FrameworkImplementation), Mock(spec=Generator)) + calculate_quantization_params(graph, quant_config, Mock(), Mock()) node = list(graph.nodes)[0] for candidate_qc in node.candidates_quantization_cfg: From 46367cffc931fa470685a740d6fc195ce2a5d19a Mon Sep 17 00:00:00 2001 From: gouda-youichi <123340875+gouda-youichi@users.noreply.github.com> Date: Mon, 7 Jul 2025 18:14:29 +0900 Subject: [PATCH 4/4] Apply activation quantization parameters selection (modify for quantconfig and remove unnecessary mocker argument. (#21)) modify for quantconfig and remove unnecessary mocker argument.
--- .../test_calculate_quantization_params.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py index d0a51b7f7..0a0292f67 100644 --- a/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py +++ b/tests_pytest/common_tests/unit_tests/core/quantization/quantization_params_generation/test_calculate_quantization_params.py @@ -81,8 +81,7 @@ def get_test_graph(self, node_name, q_mode, data): graph.node_to_out_stats_collector[n].hc._bins = np.array(data) graph.node_to_out_stats_collector[n].hc._counts = np.array([1, 1]) - quant_config = QuantizationConfig() - return graph, quant_config + return graph ### test pattern for ActivationQuantizationMode @pytest.mark.parametrize(["node_name", "q_mode", "input_data", "expects"], [ @@ -93,13 +92,13 @@ def get_test_graph(self, node_name, q_mode, data): ['node_no_quant', ActivationQuantizationMode.NO_QUANT, [0.7, 1.4, 2.1], [None, None]], ['node_preserve_quant', ActivationQuantizationMode.PRESERVE_QUANT, [0.7, 1.4, 2.1], [None, None]], ]) - def test_calculate_quantization_params_for_activation(self, node_name, q_mode, input_data, expects, mocker): + def test_calculate_quantization_params_for_activation(self, node_name, q_mode, input_data, expects): """ Tests that calculate quantization params for activation quantization method. """ - graph, quant_config = self.get_test_graph(node_name, q_mode, input_data) + graph = self.get_test_graph(node_name, q_mode, input_data) - calculate_quantization_params(graph, quant_config, Mock(), Mock()) + calculate_quantization_params(graph, QuantizationConfig(), Mock(), Mock()) node = list(graph.nodes)[0] for candidate_qc in node.candidates_quantization_cfg: