Skip to content

Commit fbd260e

Browse files
Apply activation quantization parameters selection.
1 parent 47bf2dd commit fbd260e

File tree

3 files changed

+119
-2
lines changed

3 files changed

+119
-2
lines changed

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def set_activation_quantization_param(self,
119119
activation_params: Dictionary that contains weight quantization params.
120120
121121
"""
122-
assert self.quant_mode == ActivationQuantizationMode.QUANT
122+
assert self.quant_mode == ActivationQuantizationMode.QUANT or self.quant_mode == ActivationQuantizationMode.FLN_QUANT
123123
for param_name, param_value in activation_params.items():
124124
self.activation_quantization_params[param_name] = param_value
125125

model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def calculate_quantization_params(graph: Graph,
105105
attr_cfg.weights_channels_axis = ChannelAxisMapping(output_channels_axis, attr_cfg.weights_channels_axis.input)
106106
attr_cfg.set_weights_quantization_param(weights_params)
107107

108-
if n.is_activation_quantization_enabled():
108+
if n.is_activation_quantization_enabled() or n.is_fln_quantization():
109109
# If node's activations should be quantized as well, we compute its activation quantization parameters
110110
activation_params = compute_activation_qparams(quant_cfg=quant_cfg,
111111
node_activation_quant_cfg=candidate_qc.activation_quantization_cfg,
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import pytest
16+
import numpy as np
17+
18+
from unittest.mock import Mock
19+
from model_compression_toolkit.core.common import Graph, BaseNode
20+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.qparams_computation import \
21+
calculate_quantization_params
22+
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
23+
CandidateNodeQuantizationConfig
24+
from model_compression_toolkit.core.common.quantization.node_quantization_config import \
25+
ActivationQuantizationMode, NodeActivationQuantizationConfig, NodeWeightsQuantizationConfig
26+
from model_compression_toolkit.target_platform_capabilities import OpQuantizationConfig
27+
from model_compression_toolkit.core import QuantizationConfig
28+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
29+
from model_compression_toolkit.core.common.collectors.statistics_collector import StatsCollector
30+
from mct_quantizers import QuantizationMethod
31+
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
32+
33+
34+
class TestCalculateQuantizationParams:
    """Tests for calculate_quantization_params over the different ActivationQuantizationMode values.

    Verifies that activation quantization parameters (threshold / is_signed) are
    computed for QUANT and FLN_QUANT nodes, and NOT computed for NO_QUANT,
    FLN_NO_QUANT and PRESERVE_QUANT nodes.
    """

    def build_op_cfg(self):
        """Return a mocked OpQuantizationConfig with a 16-bit power-of-two activation setup."""
        op_cfg = Mock(spec=OpQuantizationConfig)
        op_cfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
        op_cfg.activation_n_bits = 16
        op_cfg.enable_activation_quantization = True
        op_cfg.quantization_preserving = False
        op_cfg.signedness = Signedness.AUTO
        return op_cfg

    def build_node(self, name='node', q_mode=ActivationQuantizationMode.QUANT):
        """Build a mocked BaseNode whose quantization predicates match the requested mode.

        Args:
            name: Node name.
            q_mode: ActivationQuantizationMode controlling which predicate mocks return True.

        Returns:
            A Mock(spec=BaseNode) carrying one candidate quantization config whose
            activation config is a real NodeActivationQuantizationConfig in mode q_mode.
        """
        node = Mock(spec=BaseNode)
        node.name = name
        node.get_node_weights_attributes.return_value = []

        # Only one of the two predicates is True, and only for the two "real quant" modes.
        node.is_activation_quantization_enabled.return_value = (q_mode == ActivationQuantizationMode.QUANT)
        node.is_fln_quantization.return_value = (q_mode == ActivationQuantizationMode.FLN_QUANT)

        # A real (non-mock) activation config so set_activation_quantization_param runs for real.
        activation_quantization_cfg = NodeActivationQuantizationConfig(op_cfg=self.build_op_cfg())
        activation_quantization_cfg.quant_mode = q_mode

        candidate_quantization_config = Mock(spec=CandidateNodeQuantizationConfig)
        candidate_quantization_config.activation_quantization_cfg = activation_quantization_cfg
        candidate_quantization_config.weights_quantization_cfg = Mock(spec=NodeWeightsQuantizationConfig)

        node.candidates_quantization_cfg = [candidate_quantization_config]
        return node

    def get_test_graph(self, node_name, q_mode, data):
        """Build a single-node Graph with a hand-seeded histogram statistics collector.

        Args:
            node_name: Name for the single node.
            q_mode: ActivationQuantizationMode for the node.
            data: Histogram bin edges to seed the collector with (3 edges -> 2 counts).

        Returns:
            A Graph ready for calculate_quantization_params.
        """
        node = self.build_node(node_name, q_mode=q_mode)
        graph = Graph('graph_name', input_nodes=[node], nodes=[node], output_nodes=[node], edge_list=[])

        graph.node_to_out_stats_collector = dict()
        for n in graph.nodes():
            n.prior_info = NodePriorInfo()

            collector = StatsCollector(init_min_value=0.0, init_max_value=1.0, out_channel_axis=0)
            # Seed the internal histogram directly so qparams computation has stats to work on.
            # NOTE(review): pokes private fields of HistogramCollector — brittle if its internals change.
            collector.hc._n_bins = 3
            collector.hc._bins = np.array(data)
            collector.hc._counts = np.array([1, 1])
            graph.node_to_out_stats_collector[n] = collector

        return graph

    ### test pattern for ActivationQuantizationMode
    @pytest.mark.parametrize(["node_name", "q_mode", "input_data", "expects"], [
        # node_name, q_mode, input data, expected [threshold, is_signed] (None -> no params expected)
        ['node_quant', ActivationQuantizationMode.QUANT, [0.4, 0.8, 1.2], [1.0, False]],
        ['node_fln_quant', ActivationQuantizationMode.FLN_QUANT, [0.7, 1.4, 2.1], [2.0, False]],
        ['node_fln_no_quant', ActivationQuantizationMode.FLN_NO_QUANT, [0.7, 1.4, 2.1], [None, None]],
        ['node_no_quant', ActivationQuantizationMode.NO_QUANT, [0.7, 1.4, 2.1], [None, None]],
        ['node_preserve_quant', ActivationQuantizationMode.PRESERVE_QUANT, [0.7, 1.4, 2.1], [None, None]],
    ])
    def test_calculate_quantization_params_for_activation(self, node_name, q_mode, input_data, expects):
        """
        Tests that calculate quantization params for activation quantization method.
        """
        graph = self.get_test_graph(node_name, q_mode, input_data)

        calculate_quantization_params(graph, QuantizationConfig(), Mock(), Mock())

        expected_threshold, expected_signed = expects
        node = list(graph.nodes)[0]
        for candidate_qc in node.candidates_quantization_cfg:
            params = candidate_qc.activation_quantization_cfg.activation_quantization_params
            # isinstance is the idiomatic type check (was: type(...) == dict).
            assert isinstance(params, dict)
            if expected_threshold is not None:
                ### QUANT or FLN_QUANT: params must have been selected
                assert 'threshold' in params
                assert 'is_signed' in params
                assert params['threshold'] == expected_threshold
                assert params['is_signed'] == expected_signed
            else:
                ### all other modes: no activation params should be produced
                assert 'threshold' not in params
                assert 'is_signed' not in params

0 commit comments

Comments
 (0)