|
| 1 | +# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +import pytest |
| 16 | +import numpy as np |
| 17 | + |
| 18 | +from unittest.mock import Mock |
| 19 | +from model_compression_toolkit.core.common import Graph, BaseNode |
| 20 | +from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \ |
| 21 | + calculate_quantization_params |
| 22 | +from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \ |
| 23 | + CandidateNodeQuantizationConfig |
| 24 | +from model_compression_toolkit.core.common.quantization.node_quantization_config import \ |
| 25 | + ActivationQuantizationMode, NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig |
| 26 | +from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig |
| 27 | +from model_compression_toolkit.core import QuantizationConfig |
| 28 | +from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness |
| 29 | +from model_compression_toolkit.core.common.collectors.statistics_collector import StatsCollector |
| 30 | +from mct_quantizers import QuantizationMethod |
| 31 | +from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo |
| 32 | + |
| 33 | + |
| 34 | +class TestCalculateQuantizationParams: |
| 35 | + def build_op_cfg(self): |
| 36 | + op_cfg = Mock(spec=OpQuantizationConfig) |
| 37 | + op_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO |
| 38 | + op_cfg.activation_n_bits = 16 |
| 39 | + op_cfg.enable_activation_quantization = True |
| 40 | + op_cfg.quantization_preserving = False |
| 41 | + op_cfg.signedness = Signedness.AUTO |
| 42 | + |
| 43 | + return op_cfg |
| 44 | + |
| 45 | + def build_node(self, name='node', q_mode=ActivationQuantizationMode.QUANT): |
| 46 | + node = Mock(spec=BaseNode) |
| 47 | + node.name = name |
| 48 | + node.get_node_weights_attributes.return_value = [] |
| 49 | + |
| 50 | + if q_mode == ActivationQuantizationMode.QUANT: |
| 51 | + node.is_activation_quantization_enabled.return_value = True |
| 52 | + node.is_fln_quantization.return_value = False |
| 53 | + elif q_mode == ActivationQuantizationMode.FLN_QUANT: |
| 54 | + node.is_activation_quantization_enabled.return_value = False |
| 55 | + node.is_fln_quantization.return_value = True |
| 56 | + else: |
| 57 | + node.is_activation_quantization_enabled.return_value = False |
| 58 | + node.is_fln_quantization.return_value = False |
| 59 | + |
| 60 | + activation_quantization_cfg = NodeActivationQuantizationConfig(op_cfg=self.build_op_cfg()) |
| 61 | + activation_quantization_cfg.quant_mode = q_mode |
| 62 | + |
| 63 | + candidate_quantization_config = Mock(spec=CandidateNodeQuantizationConfig) |
| 64 | + candidate_quantization_config.activation_quantization_cfg = activation_quantization_cfg |
| 65 | + candidate_quantization_config.weights_quantization_cfg = Mock(spec=NodeWeightsQuantizationConfig) |
| 66 | + |
| 67 | + node.candidates_quantization_cfg = [candidate_quantization_config] |
| 68 | + |
| 69 | + return node |
| 70 | + |
| 71 | + def get_test_graph(self, node_name, q_mode, data): |
| 72 | + node = self.build_node(node_name, q_mode=q_mode) |
| 73 | + graph = Graph('graph_name', input_nodes=[node], nodes=[node], output_nodes=[node], edge_list=[]) |
| 74 | + |
| 75 | + graph.node_to_out_stats_collector = dict() |
| 76 | + for n in graph.nodes(): |
| 77 | + n.prior_info = NodePriorInfo() |
| 78 | + |
| 79 | + graph.node_to_out_stats_collector[n] = StatsCollector(init_min_value=0.0, init_max_value=1.0, out_channel_axis=0) |
| 80 | + graph.node_to_out_stats_collector[n].hc._n_bins = 3 |
| 81 | + graph.node_to_out_stats_collector[n].hc._bins = np.array(data) |
| 82 | + graph.node_to_out_stats_collector[n].hc._counts = np.array([1, 1]) |
| 83 | + |
| 84 | + return graph |
| 85 | + |
| 86 | + ### test pattern for ActivationQuantizationMode |
| 87 | + @pytest.mark.parametrize(["node_name", "q_mode", "input_data", "expects"], [ |
| 88 | + # node_name, q_mode, input data, expected value |
| 89 | + ['node_quant', ActivationQuantizationMode.QUANT, [0.4, 0.8, 1.2], [1.0, False]], |
| 90 | + ['node_fln_quant', ActivationQuantizationMode.FLN_QUANT, [0.7, 1.4, 2.1], [2.0, False]], |
| 91 | + ['node_fln_no_quant', ActivationQuantizationMode.FLN_NO_QUANT, [0.7, 1.4, 2.1], [None, None]], |
| 92 | + ['node_no_quant', ActivationQuantizationMode.NO_QUANT, [0.7, 1.4, 2.1], [None, None]], |
| 93 | + ['node_preserve_quant', ActivationQuantizationMode.PRESERVE_QUANT, [0.7, 1.4, 2.1], [None, None]], |
| 94 | + ]) |
| 95 | + def test_calculate_quantization_params_for_activation(self, node_name, q_mode, input_data, expects): |
| 96 | + """ |
| 97 | + Tests that calculate quantization params for activation quantization method. |
| 98 | + """ |
| 99 | + graph = self.get_test_graph(node_name, q_mode, input_data) |
| 100 | + |
| 101 | + calculate_quantization_params(graph, QuantizationConfig(), Mock(), Mock()) |
| 102 | + |
| 103 | + node = list(graph.nodes)[0] |
| 104 | + for candidate_qc in node.candidates_quantization_cfg: |
| 105 | + assert type(candidate_qc.activation_quantization_cfg.activation_quantization_params) == dict |
| 106 | + if expects[0] is not None: |
| 107 | + ### QUANT or FLN_QUANT |
| 108 | + assert 'threshold' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() |
| 109 | + assert 'is_signed' in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() |
| 110 | + |
| 111 | + threshold = candidate_qc.activation_quantization_cfg.activation_quantization_params['threshold'] |
| 112 | + is_signed = candidate_qc.activation_quantization_cfg.activation_quantization_params['is_signed'] |
| 113 | + assert threshold == expects[0] |
| 114 | + assert is_signed == expects[1] |
| 115 | + else: |
| 116 | + assert 'threshold' not in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() |
| 117 | + assert 'is_signed' not in candidate_qc.activation_quantization_cfg.activation_quantization_params.keys() |
0 commit comments