Commit 0d678a5

gouda-youichi, KazunoriSumiya, and kkawa14 authored

Apply FusingInfo quantization config to the activation quantization config. (#1467)

Co-authored-by: KazunoriSumiya <Sumiya.kazunori@jp.panasonic.com>
Co-authored-by: kawasaki.kenta <kawasaki.kenta@miraxia.com>
1 parent 60b9bc6 commit 0d678a5

11 files changed: 286 additions & 25 deletions
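In brief, the commit renames `Graph.disable_fused_nodes_activation_quantization` to `override_fused_node_activation_quantization_candidates` and changes its behavior: instead of marking every inner node of a fused operation as `FLN_QUANT`, the method now looks up the fused operation's quantization config in `FusingInfo`. If a config exists and its activation quantization is enabled, the node's candidates are rebuilt from that config and marked `FLN_QUANT`; otherwise they are marked with the new `FLN_NO_QUANT` mode. A minimal, self-contained sketch of that decision rule (the types below are illustrative stand-ins, not the MCT classes):

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

class Mode(Enum):
    FLN_QUANT = auto()
    FLN_NO_QUANT = auto()

@dataclass
class OpQuantCfg:  # illustrative stand-in for the fused op's quantization config
    enable_activation_quantization: bool
    activation_n_bits: int = 16

def resolve_fln_mode(op_cfg: Optional[OpQuantCfg]) -> Mode:
    # Mirrors the new branch in override_fused_node_activation_quantization_candidates:
    # an inner fused node keeps FLN quantization only when the fused op supplies an
    # enabled activation config; otherwise its activation is left un-quantized.
    if op_cfg is not None and op_cfg.enable_activation_quantization:
        return Mode.FLN_QUANT
    return Mode.FLN_NO_QUANT

assert resolve_fln_mode(OpQuantCfg(enable_activation_quantization=True)) is Mode.FLN_QUANT
assert resolve_fln_mode(OpQuantCfg(enable_activation_quantization=False)) is Mode.FLN_NO_QUANT
assert resolve_fln_mode(None) is Mode.FLN_NO_QUANT
```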

model_compression_toolkit/core/common/graph/base_graph.py

Lines changed: 27 additions & 10 deletions
@@ -32,13 +32,13 @@
 from model_compression_toolkit.core.common.collectors.statistics_collector import scale_statistics, shift_statistics
 from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
 from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
+from model_compression_toolkit.core.common.quantization.node_quantization_config import \
+    NodeActivationQuantizationConfig, ActivationQuantizationMode
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
     FrameworkQuantizationCapabilities
 
-
 def validate_graph_after_change(method: Callable) -> Callable:
     """
     Decorator for graph-mutating methods. After the decorated method executes,
@@ -876,15 +876,32 @@ def _find_intermediate_and_exit_nodes(self, entry_node: BaseNode, fw_impl: Any)
 
         return intermediate_nodes, next_node
 
-    def disable_fused_nodes_activation_quantization(self):
+    def override_fused_node_activation_quantization_candidates(self):
         """
-        Disable activation quantization for all nodes in fused operations,
+        Override the activation quantization candidates of all nodes in fused operations,
         except for the last node in each fused group.
-        """
-        nodes_to_disable = self.fusing_info.get_inner_fln_nodes()
-        for node in nodes_to_disable:
-            for qc in node.candidates_quantization_cfg:
-                qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
+        Updates each node's quantization config with the op_quantization_cfg value from FusingInfo.
+        """
+        from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
+
+        nodes_in_fln = self.fusing_info.get_inner_fln_nodes()
+        for node in nodes_in_fln:
+            fused_node_op_id = self.fusing_info.get_fused_op_id_for_node(node.name)
+            fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
+            org_candidate = node.candidates_quantization_cfg[0]
+            if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization:
+                # Set ActivationQuantizationMode to FLN_QUANT and rebuild the activation config
+                # from the fused operation's quantization config.
+                activation_quantization_cfg = NodeActivationQuantizationConfig(qc=org_candidate,
+                                                                               op_cfg=fusing_op_quantization_cfg,
+                                                                               activation_quantization_fn=org_candidate.activation_quantization_cfg.activation_quantization_fn,
+                                                                               activation_quantization_params_fn=org_candidate.activation_quantization_cfg.activation_quantization_params_fn)
+                activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
+                for qc in node.candidates_quantization_cfg:
+                    qc.activation_quantization_cfg = activation_quantization_cfg
+            else:
+                # Set ActivationQuantizationMode to FLN_NO_QUANT
+                for qc in node.candidates_quantization_cfg:
+                    qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_NO_QUANT
 
     def validate(self):
         """
@@ -908,4 +925,4 @@ def remove_edge(self, *args, **kwargs):
         """
         Wrap networkx functions (that modifies the graph) with our validate decorator.
         """
-        return super().remove_edge(*args, **kwargs)
\ No newline at end of file
+        return super().remove_edge(*args, **kwargs)
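The new method chains three `FusingInfo` accessors: `get_inner_fln_nodes`, `get_fused_op_id_for_node`, and `get_fused_op_quantization_config`. A rough sketch of that lookup chain, with plain dicts standing in for `FusingInfo` (method names taken from the diff; the dict-backed storage is an assumption for illustration only):

```python
class FusingInfoSketch:
    """Hypothetical stand-in for FusingInfo, not the real class."""
    def __init__(self, node_to_op, op_to_cfg):
        self._node_to_op = node_to_op  # node name -> fused op id
        self._op_to_cfg = op_to_cfg    # fused op id -> op quantization cfg (or None)

    def get_inner_fln_nodes(self):
        return list(self._node_to_op)

    def get_fused_op_id_for_node(self, name):
        return self._node_to_op[name]

    def get_fused_op_quantization_config(self, op_id):
        return self._op_to_cfg.get(op_id)

fi = FusingInfoSketch({"conv": "fused_op_0", "relu": "fused_op_0"},
                      {"fused_op_0": None})
for name in fi.get_inner_fln_nodes():
    cfg = fi.get_fused_op_quantization_config(fi.get_fused_op_id_for_node(name))
    # cfg is None here, so both nodes would fall into the FLN_NO_QUANT branch.
    print(name, "-> FLN_NO_QUANT" if cfg is None else "-> candidate for FLN_QUANT")
```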

model_compression_toolkit/core/common/graph/base_node.py

Lines changed: 17 additions & 5 deletions
@@ -216,19 +216,31 @@ def is_activation_quantization_enabled(self) -> bool:
         Returns: Whether node activation quantization is enabled or not.
         """
         return self._is_single_quant_mode(ActivationQuantizationMode.QUANT)
 
-    def is_fln_quantization(self) -> bool:
+    def is_fln_no_quantization(self) -> bool:
         """
-        Returns: Whether the node's activation quantization is FLN
+        Returns: Whether the node's activation quantization mode is FLN_NO_QUANT.
         """
-        return self._is_single_quant_mode(ActivationQuantizationMode.FLN_QUANT)
-
+        return self._is_single_quant_mode(ActivationQuantizationMode.FLN_NO_QUANT)
+
     def is_quantization_preserving(self) -> bool:
         """
         Returns: Whether node activation quantization information is preserved from its inputs.
         """
         return self._is_single_quant_mode(ActivationQuantizationMode.PRESERVE_QUANT)
 
+    def is_no_quantization(self) -> bool:
+        """
+        Returns: Whether the node's activation quantization mode is NO_QUANT.
+        """
+        return self._is_single_quant_mode(ActivationQuantizationMode.NO_QUANT)
+
+    def is_fln_quantization(self) -> bool:
+        """
+        Returns: Whether the node's activation quantization mode is FLN_QUANT.
+        """
+        return self._is_single_quant_mode(ActivationQuantizationMode.FLN_QUANT)
+
     def is_weights_quantization_enabled(self, attr_name: str) -> bool:
         """
         Checks whether a node's weights attribute quantization is enabled.
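All of these predicates funnel through `BaseNode._is_single_quant_mode`, which, by its use here, holds only when every candidate config agrees on the given mode. A compact sketch of that contract, assuming that semantics (the classes below are abridged stand-ins, not MCT code):

```python
from enum import Enum, auto

class ActivationQuantizationMode(Enum):  # abridged mirror of the real enum
    NO_QUANT = auto()
    FLN_NO_QUANT = auto()

class NodeSketch:
    """Illustrative stand-in for BaseNode: one quant_mode per candidate config."""
    def __init__(self, modes):
        self.candidate_modes = list(modes)

    def _is_single_quant_mode(self, mode):
        # Assumed semantics: holds only when every candidate carries `mode`.
        return all(m is mode for m in self.candidate_modes)

    def is_no_quantization(self):
        return self._is_single_quant_mode(ActivationQuantizationMode.NO_QUANT)

    def is_fln_no_quantization(self):
        return self._is_single_quant_mode(ActivationQuantizationMode.FLN_NO_QUANT)

node = NodeSketch([ActivationQuantizationMode.FLN_NO_QUANT])
assert node.is_fln_no_quantization() and not node.is_no_quantization()
```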

model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py

Lines changed: 11 additions & 3 deletions
@@ -21,7 +21,6 @@
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
 
-
 def filter_nodes_candidates(graph: Graph):
     """
     Filters the graph's nodes candidates configuration list.
@@ -87,7 +86,7 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf
     filtered_candidates = copy.deepcopy(node.candidates_quantization_cfg)
     final_candidates = copy.deepcopy(node.candidates_quantization_cfg)
 
-    if (node.kernel_attr is None or not node.is_weights_quantization_enabled(node.kernel_attr)) and not node.is_activation_quantization_enabled():
+    if (node.kernel_attr is None or not node.is_weights_quantization_enabled(node.kernel_attr)) and node.is_no_quantization():
         # If activation quantization is disabled and the node doesn't have a kernel or doesn't quantize the kernel,
         # but for some reason the node has multiple candidates then replace it with a single dummy candidate with
         # default bit-width values.
@@ -102,9 +101,10 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf
 
         final_candidates = [single_dummy_candidate]
 
-    elif not node.is_activation_quantization_enabled():
+    elif node.is_no_quantization():
         # Remove candidates that have duplicated weights candidates for node with disabled activation quantization.
         # Replacing the activation n_bits in the remained configurations with default value to prevent confusion.
+        # Set the config of the non-quantized FLN node to POWER_OF_TWO.
         seen_candidates = set()
         filtered_candidates = [candidate for candidate in filtered_candidates if
                                candidate.weights_quantization_cfg not in seen_candidates
@@ -116,6 +116,14 @@ def filter_node_candidates(node: BaseNode) -> List[CandidateNodeQuantizationConf
 
         final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr)
 
+    elif node.is_fln_no_quantization() or node.is_fln_quantization():
+        # Remove candidates that have duplicated weights configs for FLN nodes.
+        seen_candidates = set()
+        filtered_candidates = [candidate for candidate in filtered_candidates if
+                               candidate.weights_quantization_cfg not in seen_candidates
+                               and not seen_candidates.add(candidate.weights_quantization_cfg)]
+        final_candidates = _filter_bit_method_dups(filtered_candidates, node.kernel_attr)
+
     elif node.kernel_attr is None or not node.is_weights_quantization_enabled(node.kernel_attr):
         # TODO:
         # To allow MP on positional weights we need to modify this to consider all weights not only kernel.
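The comprehension reused in the new FLN branch relies on a compact dedup idiom: `set.add` returns `None` (falsy), so `x not in seen and not seen.add(x)` keeps only the first candidate with each `weights_quantization_cfg` while recording it, in order, in a single pass. A standalone illustration:

```python
# set.add() returns None, so "not seen.add(c)" is always True but has the
# side effect of recording c; the membership test then filters later duplicates.
seen = set()
weights_cfgs = ["w8", "w4", "w8", "w2", "w4"]
unique = [c for c in weights_cfgs if c not in seen and not seen.add(c)]
assert unique == ["w8", "w4", "w2"]
```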

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ class ActivationQuantizationMode(Enum):
     FLN_QUANT = auto()
     PRESERVE_QUANT = auto()
     NO_QUANT = auto()
-
+    FLN_NO_QUANT = auto()
 
 class BaseNodeQuantizationConfig(object):
     """

model_compression_toolkit/core/graph_prep_runner.py

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ def get_finalized_graph(initial_graph: Graph,
     ######################################
     fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(transformed_graph)
     transformed_graph.fusing_info = fusing_info
-    transformed_graph.disable_fused_nodes_activation_quantization()
+    transformed_graph.override_fused_node_activation_quantization_candidates()
 
     ######################################
     # Channel equalization

tests/keras_tests/function_tests/test_activation_weights_composition_substitution.py

Lines changed: 2 additions & 2 deletions
@@ -124,7 +124,7 @@ def prepare_graph(in_model, keras_impl, mixed_precision_candidates_list, base_co
 
     fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(graph)
     graph.fusing_info = fusing_info
-    graph.disable_fused_nodes_activation_quantization()
+    graph.override_fused_node_activation_quantization_candidates()
 
     graph = filter_nodes_candidates(graph)
 
@@ -230,7 +230,7 @@ def test_two_conv_net_compose_after_split_activation_only(self):
 
         graph.skip_validation_check = False
 
-        self._verify_two_conv_with_split_test(graph, v_graph, 3, 3)
+        self._verify_two_conv_with_split_test(graph, v_graph, 9, 3)
 
     def test_all_weights_layers_composition(self):
         in_model = multiple_weights_nodes_model()

tests/keras_tests/function_tests/test_cfg_candidates_filter.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def prepare_graph(in_model, base_config, default_config, bitwidth_candidates):
 
     fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(graph)
    graph.fusing_info = fusing_info
-    graph.disable_fused_nodes_activation_quantization()
+    graph.override_fused_node_activation_quantization_candidates()
 
     return graph
 

tests_pytest/_fw_tests_common_base/fusing/base_graph_with_fusing_metadata_test.py

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ def test_disable_act_quantization(self, graph_with_fusion_metadata: Graph):
             for qc in node.candidates_quantization_cfg:
                 qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.QUANT
 
-        graph_with_fusion_metadata.disable_fused_nodes_activation_quantization()
+        graph_with_fusion_metadata.override_fused_node_activation_quantization_candidates()
         disabled_nodes = [
             node.name for node in graph_with_fusion_metadata.nodes
             if all(not qc.activation_quantization_cfg.enable_activation_quantization
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from copy import deepcopy
+
+import pytest
+from unittest.mock import Mock
+
+from mct_quantizers import QuantizationMethod
+from model_compression_toolkit.core.common import Graph
+from model_compression_toolkit.core.common.graph.base_node import BaseNode
+from model_compression_toolkit.core.common.fusion.fusing_info import FusingInfo
+from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
+from tests.common_tests.helpers.generate_test_tpc import generate_test_attr_configs, generate_test_op_qc
+from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode, NodeActivationQuantizationConfig
+from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
+from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import power_of_two_selection_histogram
+from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import symmetric_selection_histogram
+from model_compression_toolkit.core import QuantizationErrorMethod
+
+
+def build_mock_fusing_info(nodes, idx):
+    """
+    Creates a mock FusingInfo object that simulates the behavior of fusing information in a graph.
+    """
+    OpQCfg = Mock(spec=NodeActivationQuantizationConfig)
+    OpQCfg.activation_n_bits = 16
+    OpQCfg.signedness = Signedness.AUTO
+    OpQCfg.activation_quantization_method = QuantizationMethod.POWER_OF_TWO
+    OpQCfg.activation_quantization_params_fn = power_of_two_selection_histogram
+    OpQCfg.quantization_preserving = False
+
+    fusing_info = Mock(spec=FusingInfo)
+    fusing_info.get_inner_fln_nodes.return_value = [nodes[0], nodes[1]]
+
+    if idx == 1:
+        OpQCfg.enable_activation_quantization = True
+        fusing_info.get_fused_op_quantization_config.side_effect = [OpQCfg, None]
+    elif idx == 2:
+        fusing_info.get_fused_op_quantization_config.side_effect = [None, None]
+    else:
+        OpQCfg.enable_activation_quantization = False
+        fusing_info.get_fused_op_quantization_config.side_effect = [OpQCfg, None]
+
+    return fusing_info
+
+
+def build_mock_node(name, layer_class):
+    """
+    Creates mock nodes representing a simple neural network structure.
+    """
+    node = Mock(spec=BaseNode)
+    node.name = name
+    node.layer_class = layer_class
+
+    activation_quantization_cfg = Mock(spec=NodeActivationQuantizationConfig)
+    activation_quantization_cfg.quant_mode = Mock()
+    activation_quantization_cfg.activation_quantization_fn = symmetric_selection_histogram
+    activation_quantization_cfg.activation_quantization_params_fn = power_of_two_selection_histogram
+
+    candidate_quantization_config = Mock(spec=CandidateNodeQuantizationConfig)
+    candidate_quantization_config.activation_quantization_cfg = activation_quantization_cfg
+    candidate_quantization_config.activation_error_method = QuantizationErrorMethod.MSE
+    candidate_quantization_config.relu_bound_to_power_of_2 = 0
+    candidate_quantization_config.activation_channel_equalization = False
+    candidate_quantization_config.input_scaling = False
+    candidate_quantization_config.min_threshold = 0
+    candidate_quantization_config.l_p_value = 0
+    candidate_quantization_config.shift_negative_activation_correction = 0
+    candidate_quantization_config.z_threshold = 0
+    candidate_quantization_config.shift_negative_ratio = 0
+    candidate_quantization_config.shift_negative_threshold_recalculation = 0
+    candidate_quantization_config.concat_threshold_update = 0
+    candidate_quantization_config.weights_quantization_cfg = 0
+
+    node.candidates_quantization_cfg = [candidate_quantization_config]
+
+    return node
+
+
+class TestGraph:
+
+    @pytest.mark.parametrize("idx", [1, 2, 3])
+    def test_override_fused_node_activation_quantization_candidates(self, idx):
+        """
+        Test the override_fused_node_activation_quantization_candidates function
+        for a graph with multiple nodes and configurations.
+        """
+        # Create test nodes.
+        mock_nodes = []
+        mock_nodes.append(build_mock_node(name='conv', layer_class='Conv2d'))
+        mock_nodes.append(build_mock_node(name='fc', layer_class='Linear'))
+
+        # Create a mock graph.
+        # Note: generate the real graph first because fusing_info cannot be set without it.
+        # In the following Mock, use wraps to mock everything except fusing_info.
+        real_graph = Graph("dummy", [], [], [], [])
+        real_graph.fusing_info = build_mock_fusing_info(mock_nodes, idx)
+
+        graph = Mock(spec=Graph, wraps=real_graph)
+        graph.nodes = mock_nodes
+
+        # Call override_fused_node_activation_quantization_candidates.
+        graph.override_fused_node_activation_quantization_candidates()
+
+        # Check that the activation quantization settings on the graph nodes match the expected values.
+        nodes = list(graph.nodes)
+
+        if idx == 1:
+            # The first node should take the fused op's config and be marked FLN_QUANT.
+            assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.quant_mode == ActivationQuantizationMode.FLN_QUANT
+            assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.activation_n_bits == 16
+            assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.signedness == Signedness.AUTO
+            assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO
+            assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params_fn == power_of_two_selection_histogram
+            # The second node has no fused op config, so it should be marked FLN_NO_QUANT.
+            assert nodes[1].candidates_quantization_cfg[0].activation_quantization_cfg.quant_mode == ActivationQuantizationMode.FLN_NO_QUANT
+        else:
+            # With no config, or activation quantization disabled, the first node should be FLN_NO_QUANT.
+            assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.quant_mode == ActivationQuantizationMode.FLN_NO_QUANT
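One detail of this new test worth noting: `Mock(spec=Graph, wraps=real_graph)` produces a mock whose attribute access is restricted to `Graph`'s API but whose calls are delegated to the real instance, so the real `override_fused_node_activation_quantization_candidates` body executes against the mocked `fusing_info`. A minimal standalone illustration of the pattern (the `Greeter` class is a made-up example):

```python
from unittest.mock import Mock

class Greeter:
    def hello(self, name):
        return f"hello {name}"

real = Greeter()
proxy = Mock(spec=Greeter, wraps=real)

# The call is recorded on the mock but executed by the wrapped object.
assert proxy.hello("graph") == "hello graph"
proxy.hello.assert_called_once_with("graph")
```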

tests_pytest/common_tests/unit_tests/core/mixed_precision/resource_utilization_tools/test_resource_utilization_calculator.py

Lines changed: 1 addition & 1 deletion
@@ -591,7 +591,7 @@ def test_compute_cuts_random_fusion_valid_utilization(self, seed, disable_quanti
         graph.fusing_info = fusing_info
 
         if disable_quantization:
-            graph.disable_fused_nodes_activation_quantization()
+            graph.override_fused_node_activation_quantization_candidates()
 
         graph.find_node_by_name = MethodType(Graph.find_node_by_name, graph)
 
