Skip to content

Commit 4bb8be9

Browse files
irenab
authored and committed
align setting fln op config with the new code
1 parent b3f690a commit 4bb8be9

4 files changed

Lines changed: 95 additions & 56 deletions

File tree

model_compression_toolkit/core/common/graph/base_graph.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -860,32 +860,36 @@ def _find_intermediate_and_exit_nodes(self, entry_node: BaseNode, fw_impl: Any)
860860

861861
return intermediate_nodes, next_node
862862

863+
# TODO irena move to load_fqc and clean up tests (currently tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py)
863864
def override_fused_node_activation_quantization_candidates(self):
864865
"""
865866
Override fused node activation quantization candidates for all nodes in fused operations,
866867
except for the last node in each fused group.
867868
Update the value of quantization_config with the value of op_quaitization_cfg from FusingInfo.
868869
"""
869-
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
870-
871870
nodes_in_fln = self.fusing_info.get_inner_fln_nodes()
872871
for node in nodes_in_fln:
873872
fused_node_op_id = self.fusing_info.get_fused_op_id_for_node(node.name)
874-
fusiong_op_quaitization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
875-
org_candidate = node.candidates_quantization_cfg[0]
876-
if fusiong_op_quaitization_cfg is not None and fusiong_op_quaitization_cfg.enable_activation_quantization:
877-
# Set ActivationQuantizationMode to FLN_QUANT and update the value of quantization_config
878-
activation_quantization_cfg = NodeActivationQuantizationConfig(qc=org_candidate,
879-
op_cfg=fusiong_op_quaitization_cfg,
880-
activation_quantization_fn=org_candidate.activation_quantization_cfg.activation_quantization_fn,
881-
activation_quantization_params_fn=org_candidate.activation_quantization_cfg.activation_quantization_params_fn)
882-
activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
883-
for qc in node.candidates_quantization_cfg:
884-
qc.activation_quantization_cfg = activation_quantization_cfg
873+
fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id)
874+
if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization:
875+
def update(qc):
876+
qc.activation_quantization_cfg = NodeActivationQuantizationConfig(
877+
fusing_op_quantization_cfg,
878+
qc.activation_quantization_cfg.activation_quantization_fn,
879+
qc.activation_quantization_cfg.activation_quantization_params_fn
880+
)
881+
qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
882+
node.quantization_cfg.update_all(update)
883+
node.quantization_cfg.remove_duplicates()
885884
else:
886-
# Set ActivationQuantizationMode to FLN_NO_QUANT
885+
node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.FLN_NO_QUANT)
886+
# Remove duplicate candidates. We cannot compare whole candidates since activation configs might not
887+
# be identical, but we do want to treat them as such. So we only check duplication by weight configs.
888+
uniq_qcs = []
887889
for qc in node.candidates_quantization_cfg:
888-
qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_NO_QUANT
890+
if not any(qc.weights_quantization_cfg == uqc.weights_quantization_cfg for uqc in uniq_qcs):
891+
uniq_qcs.append(qc)
892+
node.quantization_cfg.candidates_quantization_cfg = uniq_qcs
889893

890894
def validate(self):
891895
"""

model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ def get_activation_quant_mode(self) -> ActivationQuantizationMode:
8989
self._validate_consistent_activation_quant_mode()
9090
return self.base_quantization_cfg.activation_quantization_cfg.quant_mode
9191

92+
def remove_duplicates(self):
93+
"""
94+
Remove duplicate candidates. First candidate among duplicates is kept, and the order is preserved.
95+
"""
96+
uniq_qcs = []
97+
for qc in self.candidates_quantization_cfg:
98+
if qc not in uniq_qcs:
99+
uniq_qcs.append(qc)
100+
self.candidates_quantization_cfg = uniq_qcs
101+
92102
def __post_init__(self, validate=True):
93103
if validate:
94104
if not any(self.base_quantization_cfg == qc for qc in self.candidates_quantization_cfg):

model_compression_toolkit/quantization_preparation/load_fqc.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,22 @@ def _set_nodes_quantization_configuration(graph: Graph,
139139
return graph
140140

141141

142-
def _set_fusion_info(graph, fqc) -> Graph:
142+
def _set_fusion_info(graph: Graph, fqc: FrameworkQuantizationCapabilities) -> Graph:
143+
"""
144+
145+
Args:
146+
graph: graph.
147+
fqc: quantization capabilities with attached framework.
148+
149+
Returns:
150+
151+
"""
143152
# TODO fix the dict with const keys inside get_fusing_patterns. use named tuple or class
153+
# TODO irena instead of storing fusion inside graph (including tpc objects) and then let graph convert tpc op config to
154+
# node config, do it here and only store in graph whatever is relevant after this stage.
144155
fusing_info = FusingInfoGenerator(fqc.get_fusing_patterns()).generate_fusing_info(graph)
145156
graph.fusing_info = fusing_info
146-
graph.disable_fused_nodes_activation_quantization()
157+
graph.override_fused_node_activation_quantization_candidates()
147158
return graph
148159

149160

tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py

Lines changed: 53 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15+
import itertools
1516
from copy import deepcopy
1617

1718
import pytest
18-
from unittest.mock import Mock
19+
from unittest.mock import Mock, PropertyMock
1920

2021
from mct_quantizers import QuantizationMethod
2122
from model_compression_toolkit.core.common import Graph
@@ -24,10 +25,13 @@
2425
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
2526
from tests.common_tests.helpers.generate_test_tpc import generate_test_attr_configs, generate_test_op_qc
2627
from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode, NodeActivationQuantizationConfig
27-
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig
28+
from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
29+
CandidateNodeQuantizationConfig, NodeQuantizationConfig
2830
from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import power_of_two_selection_histogram
2931
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import symmetric_selection_histogram
3032
from model_compression_toolkit.core import QuantizationErrorMethod
33+
from tests_pytest._test_util.graph_builder_utils import build_node, build_nbits_qc
34+
3135

3236
def build_mock_fusing_info(nodes, idx):
3337
"""
@@ -55,36 +59,26 @@ def build_mock_fusing_info(nodes, idx):
5559

5660
return fusing_info
5761

58-
def build_mock_node(name, layer_class):
62+
def build_mock_node(name, layer_class, w_cfgs):
5963
"""
6064
Creates mock nodes representing a simple neural network structure.
6165
"""
62-
node = Mock(spec=BaseNode)
63-
node.name = name
64-
node.layer_class = layer_class
65-
66-
activation_quantization_cfg = Mock(spec=NodeActivationQuantizationConfig)
67-
activation_quantization_cfg.quant_mode = Mock()
68-
activation_quantization_cfg.activation_quantization_fn = symmetric_selection_histogram
69-
activation_quantization_cfg.activation_quantization_params_fn = power_of_two_selection_histogram
70-
71-
candidate_quantization_config = Mock(spec=CandidateNodeQuantizationConfig)
72-
candidate_quantization_config.activation_quantization_cfg = activation_quantization_cfg
73-
candidate_quantization_config.activation_error_method = QuantizationErrorMethod.MSE
74-
candidate_quantization_config.relu_bound_to_power_of_2 = 0
75-
candidate_quantization_config.activation_channel_equalization = False
76-
candidate_quantization_config.input_scaling = False
77-
candidate_quantization_config.min_threshold = 0
78-
candidate_quantization_config.l_p_value = 0
79-
candidate_quantization_config.shift_negative_activation_correction = 0
80-
candidate_quantization_config.z_threshold = 0
81-
candidate_quantization_config.shift_negative_ratio = 0
82-
candidate_quantization_config.shift_negative_threshold_recalculation = 0
83-
candidate_quantization_config.concat_threshold_update = 0
84-
candidate_quantization_config.weights_quantization_cfg = 0
85-
86-
node.candidates_quantization_cfg = [candidate_quantization_config]
66+
node = build_node(name, layer_class=layer_class)
67+
68+
def eq(self_, other):
69+
return self_.activation_n_bits == other.activation_n_bits and self_._quant_mode == other.quant_mode
70+
a_cfgs = [Mock(spec=NodeActivationQuantizationConfig,
71+
quant_mode=Mock(),
72+
activation_n_bits=b,
73+
activation_quantization_fn=symmetric_selection_histogram,
74+
activation_quantization_params_fn=power_of_two_selection_histogram,
75+
__eq__=eq) for b in [5, 6]]
76+
77+
qcs = [CandidateNodeQuantizationConfig(a_cfg, w_cfg) for a_cfg, w_cfg in itertools.product(a_cfgs, w_cfgs)]
8778

79+
node.quantization_cfg = NodeQuantizationConfig(base_quantization_cfg=qcs[0],
80+
candidates_quantization_cfg=qcs,
81+
validate=False)
8882
return node
8983

9084

@@ -95,14 +89,15 @@ class TestGraph:
9589
2,
9690
3,
9791
])
98-
def test_override_fused_node_activation_quantization_candidates(self, idx):
92+
def test_override_fused_node_activation_quantization_candidates(self, idx, patch_fw_info):
9993
"""
10094
Test the override_fused_node_activation_quantization_candidates function for a graph with multiple nodes and configurations.
10195
"""
10296
### Create Test Nodes
10397
mock_nodes = []
104-
mock_nodes.append(build_mock_node(name='conv', layer_class='Conv2d'))
105-
mock_nodes.append(build_mock_node(name='fc', layer_class='Linear'))
98+
w_cfgs = [Mock(), Mock()]
99+
mock_nodes.append(build_mock_node(name='conv', layer_class='Conv2d', w_cfgs=w_cfgs))
100+
mock_nodes.append(build_mock_node(name='fc', layer_class='Linear', w_cfgs=w_cfgs[:1]))
106101

107102
### Create a mock graph
108103
### Note: Generate the graph first because fusing_info cannot be set without it.
@@ -120,14 +115,33 @@ def test_override_fused_node_activation_quantization_candidates(self, idx):
120115
nodes = list(graph.nodes)
121116

122117
if idx == 1:
123-
### Check if the first node ActivationQuantization settings match the expected values
124-
assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.quant_mode == ActivationQuantizationMode.FLN_QUANT
125-
assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.activation_n_bits == 16
126-
assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.signedness == Signedness.AUTO
127-
assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO
128-
assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params_fn == power_of_two_selection_histogram
118+
# Check if the first node ActivationQuantization settings match the expected values
119+
# Weight mp configs are preserved, all candidates have the new activation config and duplicates are removed
120+
qcs0 = nodes[0].quantization_cfg.candidates_quantization_cfg
121+
assert len(qcs0) == 2
122+
for i, qc in enumerate(qcs0):
123+
assert qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.FLN_QUANT
124+
assert qc.activation_quantization_cfg.activation_n_bits == 16
125+
assert qc.activation_quantization_cfg.signedness == Signedness.AUTO
126+
assert qc.activation_quantization_cfg.activation_quantization_method == QuantizationMethod.POWER_OF_TWO
127+
assert qc.activation_quantization_cfg.activation_quantization_params_fn == power_of_two_selection_histogram
128+
assert qc.weights_quantization_cfg == w_cfgs[i]
129+
base_cfg0 = nodes[0].quantization_cfg.base_quantization_cfg
130+
assert base_cfg0.activation_quantization_cfg.activation_n_bits == 16
131+
assert base_cfg0.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.FLN_QUANT
129132
### Check if the second node ActivationQuantization settings match the expected values
130-
assert nodes[1].candidates_quantization_cfg[0].activation_quantization_cfg.quant_mode == ActivationQuantizationMode.FLN_NO_QUANT
133+
# activations are fln-disabled, duplicates are removed even though orig activation configs differ in nbits
134+
qcs1 = nodes[1].quantization_cfg.candidates_quantization_cfg
135+
assert len(qcs1) == 1
136+
assert qcs1[0].activation_quantization_cfg.quant_mode == ActivationQuantizationMode.FLN_NO_QUANT
137+
assert qcs1[0].weights_quantization_cfg == w_cfgs[0]
138+
assert (nodes[1].quantization_cfg.base_quantization_cfg.
139+
activation_quantization_cfg.quant_mode == ActivationQuantizationMode.FLN_NO_QUANT)
140+
131141
else:
132142
### Check if the first node ActivationQuantization settings match the expected values
133-
assert nodes[0].candidates_quantization_cfg[0].activation_quantization_cfg.quant_mode == ActivationQuantizationMode.FLN_NO_QUANT
143+
qcs0 = nodes[0].quantization_cfg.candidates_quantization_cfg
144+
assert len(qcs0) == 2
145+
for i, qc in enumerate(qcs0):
146+
assert qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.FLN_NO_QUANT
147+
assert qc.weights_quantization_cfg == w_cfgs[i]

0 commit comments

Comments
 (0)