Skip to content

Commit 4230a8b

Browse files
Apply activation quantization parameter selection (3rd PR internal review) (#19)
Addressed PR comments: simplified a test, removed unnecessary code, and reverted the set_activation_quantization_param assert change.
1 parent ca3508a commit 4230a8b

4 files changed

Lines changed: 127 additions & 240 deletions

File tree

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from __future__ import annotations
16-
1715
from typing import Any, List, Dict, TYPE_CHECKING
1816
from enum import Enum, auto
1917

@@ -25,7 +23,6 @@
2523
AttributeQuantizationConfig, OpQuantizationConfig
2624

2725
if TYPE_CHECKING:
28-
from model_compression_toolkit.core.common import BaseNode
2926
from model_compression_toolkit.core.common.graph.base_node import WeightAttrT
3027

3128
##########################################
@@ -114,17 +111,15 @@ def fln_quantization(self):
114111
return self.quant_mode == ActivationQuantizationMode.FLN_QUANT
115112

116113
def set_activation_quantization_param(self,
                                      activation_params: dict):
    """
    Set quantization parameters for the node's activation.

    Args:
        activation_params: Dictionary of activation quantization parameters
            (e.g. 'threshold', 'is_signed') to merge into this config.
    """
    # Parameters are only meaningful for nodes whose activation is actually
    # quantized — either regular quantization or FLN quantization.
    assert self.quant_mode in (ActivationQuantizationMode.QUANT,
                               ActivationQuantizationMode.FLN_QUANT), \
        f'Cannot set activation params for quant_mode {self.quant_mode}.'
    # Merge (and overwrite on key collision) into the existing params dict.
    self.activation_quantization_params.update(activation_params)
130125

model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,4 +112,4 @@ def calculate_quantization_params(graph: Graph,
112112
node_prior_info=n.prior_info,
113113
out_stats_container=graph.get_out_stats_collector(n))
114114
# Create a NodeQuantizationConfig containing all quantization params and attach it to the node
115-
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params, n)
115+
candidate_qc.activation_quantization_cfg.set_activation_quantization_param(activation_params)
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import pytest
16+
import numpy as np
17+
18+
from typing import Generator
19+
from unittest.mock import Mock
20+
from model_compression_toolkit.core.common import Graph, BaseNode
21+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \
22+
calculate_quantization_params
23+
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
24+
CandidateNodeQuantizationConfig
25+
from model_compression_toolkit.core.common.quantization.node_quantization_config import \
26+
ActivationQuantizationMode, NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig
27+
from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig
28+
from model_compression_toolkit.core import QuantizationConfig
29+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
30+
from model_compression_toolkit.core.common.collectors.statistics_collector import StatsCollector
31+
from mct_quantizers import QuantizationMethod
32+
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
33+
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
34+
35+
36+
class TestCalculateQuantizationParams:
    """Tests for calculate_quantization_params activation-parameter selection."""

    def build_op_cfg(self):
        """Return a mocked OpQuantizationConfig for a 16-bit power-of-two activation."""
        op_cfg = Mock(spec=OpQuantizationConfig)
        op_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
        op_cfg.activation_n_bits = 16
        op_cfg.enable_activation_quantization = True
        op_cfg.quantization_preserving = False
        op_cfg.signedness = Signedness.AUTO
        return op_cfg

    def build_node(self, name='node', q_mode=ActivationQuantizationMode.QUANT):
        """Build a mocked BaseNode with a single candidate config in the given quant mode."""
        node = Mock(spec=BaseNode)
        node.name = name
        node.get_node_weights_attributes.return_value = []

        # The two query methods mirror the quant mode: exactly one is True for
        # QUANT / FLN_QUANT, and both are False for every other mode.
        node.is_activation_quantization_enabled.return_value = \
            q_mode == ActivationQuantizationMode.QUANT
        node.is_fln_quantization.return_value = \
            q_mode == ActivationQuantizationMode.FLN_QUANT

        activation_quantization_cfg = NodeActivationQuantizationConfig(op_cfg=self.build_op_cfg())
        activation_quantization_cfg.set_qc(QuantizationConfig())
        activation_quantization_cfg.quant_mode = q_mode

        candidate_quantization_config = Mock(spec=CandidateNodeQuantizationConfig)
        candidate_quantization_config.activation_quantization_cfg = activation_quantization_cfg
        candidate_quantization_config.weights_quantization_cfg = Mock(spec=NodeWeightsQuantizationConfig)

        node.candidates_quantization_cfg = [candidate_quantization_config]
        return node

    def get_test_graph(self, node_name, q_mode, data):
        """Build a single-node graph whose output statistics histogram holds `data`."""
        node = self.build_node(node_name, q_mode=q_mode)
        graph = Graph('graph_name', input_nodes=[node], nodes=[node], output_nodes=[node], edge_list=[])

        graph.node_to_out_stats_collector = dict()
        for n in graph.nodes():
            n.prior_info = NodePriorInfo()

            # Seed the histogram collector directly so the param computation
            # sees deterministic statistics without running inference.
            stats_collector = StatsCollector(init_min_value=0.0, init_max_value=1.0, out_channel_axis=0)
            stats_collector.hc._n_bins = 3
            stats_collector.hc._bins = np.array(data)
            stats_collector.hc._counts = np.array([1, 1])
            graph.node_to_out_stats_collector[n] = stats_collector

        return graph

    ### test pattern for ActivationQuantizationMode
    @pytest.mark.parametrize(["node_name", "q_mode", "input_data", "expects"], [
        # node_name, q_mode, input data, expected [threshold, is_signed]
        ['node_quant', ActivationQuantizationMode.QUANT, [0.4, 0.8, 1.2], [1.0, False]],
        ['node_fln_quant', ActivationQuantizationMode.FLN_QUANT, [0.7, 1.4, 2.1], [2.0, False]],
        ['node_fln_no_quant', ActivationQuantizationMode.FLN_NO_QUANT, [0.7, 1.4, 2.1], [None, None]],
        ['node_no_quant', ActivationQuantizationMode.NO_QUANT, [0.7, 1.4, 2.1], [None, None]],
        ['node_preserve_quant', ActivationQuantizationMode.PRESERVE_QUANT, [0.7, 1.4, 2.1], [None, None]],
    ])
    def test_calculate_quantization_params_for_activation(self, node_name, q_mode, input_data, expects, mocker):
        """
        Tests that activation quantization params are computed only for
        QUANT / FLN_QUANT modes and left empty for all other modes.
        """
        graph = self.get_test_graph(node_name, q_mode, input_data)

        # HMSE selection is out of scope for this test — disable it.
        mocker.patch(
            'model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation._collect_nodes_for_hmse',
            return_value=[])

        calculate_quantization_params(graph, Mock(spec=FrameworkImplementation), Mock(spec=Generator))

        node = list(graph.nodes)[0]
        for candidate_qc in node.candidates_quantization_cfg:
            params = candidate_qc.activation_quantization_cfg.activation_quantization_params
            assert isinstance(params, dict)
            if expects[0] is not None:
                ### QUANT or FLN_QUANT: params must be populated as expected
                assert 'threshold' in params
                assert 'is_signed' in params
                assert params['threshold'] == expects[0]
                assert params['is_signed'] == expects[1]
            else:
                ### all other modes: no activation params are set
                assert 'threshold' not in params
                assert 'is_signed' not in params

0 commit comments

Comments
 (0)