1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414# ==============================================================================
15+ import itertools
1516from copy import deepcopy
1617
1718import pytest
18- from unittest .mock import Mock
19+ from unittest .mock import Mock , PropertyMock
1920
2021from mct_quantizers import QuantizationMethod
2122from model_compression_toolkit .core .common import Graph
2425from model_compression_toolkit .target_platform_capabilities .schema .mct_current_schema import Signedness
2526from tests .common_tests .helpers .generate_test_tpc import generate_test_attr_configs , generate_test_op_qc
2627from model_compression_toolkit .core .common .quantization .node_quantization_config import ActivationQuantizationMode , NodeActivationQuantizationConfig
27- from model_compression_toolkit .core .common .quantization .candidate_node_quantization_config import CandidateNodeQuantizationConfig
28+ from model_compression_toolkit .core .common .quantization .candidate_node_quantization_config import \
29+ CandidateNodeQuantizationConfig , NodeQuantizationConfig
2830from model_compression_toolkit .core .common .quantization .quantization_params_generation .power_of_two_selection import power_of_two_selection_histogram
2931from model_compression_toolkit .core .common .quantization .quantization_params_generation .symmetric_selection import symmetric_selection_histogram
3032from model_compression_toolkit .core import QuantizationErrorMethod
33+ from tests_pytest ._test_util .graph_builder_utils import build_node , build_nbits_qc
34+
3135
3236def build_mock_fusing_info (nodes , idx ):
3337 """
@@ -55,36 +59,26 @@ def build_mock_fusing_info(nodes, idx):
5559
5660 return fusing_info
5761
58- def build_mock_node (name , layer_class ):
62+ def build_mock_node (name , layer_class , w_cfgs ):
5963 """
6064 Creates mock nodes representing a simple neural network structure.
6165 """
62- node = Mock (spec = BaseNode )
63- node .name = name
64- node .layer_class = layer_class
65-
66- activation_quantization_cfg = Mock (spec = NodeActivationQuantizationConfig )
67- activation_quantization_cfg .quant_mode = Mock ()
68- activation_quantization_cfg .activation_quantization_fn = symmetric_selection_histogram
69- activation_quantization_cfg .activation_quantization_params_fn = power_of_two_selection_histogram
70-
71- candidate_quantization_config = Mock (spec = CandidateNodeQuantizationConfig )
72- candidate_quantization_config .activation_quantization_cfg = activation_quantization_cfg
73- candidate_quantization_config .activation_error_method = QuantizationErrorMethod .MSE
74- candidate_quantization_config .relu_bound_to_power_of_2 = 0
75- candidate_quantization_config .activation_channel_equalization = False
76- candidate_quantization_config .input_scaling = False
77- candidate_quantization_config .min_threshold = 0
78- candidate_quantization_config .l_p_value = 0
79- candidate_quantization_config .shift_negative_activation_correction = 0
80- candidate_quantization_config .z_threshold = 0
81- candidate_quantization_config .shift_negative_ratio = 0
82- candidate_quantization_config .shift_negative_threshold_recalculation = 0
83- candidate_quantization_config .concat_threshold_update = 0
84- candidate_quantization_config .weights_quantization_cfg = 0
85-
86- node .candidates_quantization_cfg = [candidate_quantization_config ]
66+ node = build_node (name , layer_class = layer_class )
67+
68+ def eq (self_ , other ):
69+ return self_ .activation_n_bits == other .activation_n_bits and self_ ._quant_mode == other .quant_mode
70+ a_cfgs = [Mock (spec = NodeActivationQuantizationConfig ,
71+ quant_mode = Mock (),
72+ activation_n_bits = b ,
73+ activation_quantization_fn = symmetric_selection_histogram ,
74+ activation_quantization_params_fn = power_of_two_selection_histogram ,
75+ __eq__ = eq ) for b in [5 , 6 ]]
76+
77+ qcs = [CandidateNodeQuantizationConfig (a_cfg , w_cfg ) for a_cfg , w_cfg in itertools .product (a_cfgs , w_cfgs )]
8778
79+ node .quantization_cfg = NodeQuantizationConfig (base_quantization_cfg = qcs [0 ],
80+ candidates_quantization_cfg = qcs ,
81+ validate = False )
8882 return node
8983
9084
@@ -95,14 +89,15 @@ class TestGraph:
9589 2 ,
9690 3 ,
9791 ])
98- def test_override_fused_node_activation_quantization_candidates (self , idx ):
92+ def test_override_fused_node_activation_quantization_candidates (self , idx , patch_fw_info ):
9993 """
10094 Test the override_fused_node_activation_quantization_candidates function for a graph with multiple nodes and configurations.
10195 """
10296 ### Create Test Nodes
10397 mock_nodes = []
104- mock_nodes .append (build_mock_node (name = 'conv' , layer_class = 'Conv2d' ))
105- mock_nodes .append (build_mock_node (name = 'fc' , layer_class = 'Linear' ))
98+ w_cfgs = [Mock (), Mock ()]
99+ mock_nodes .append (build_mock_node (name = 'conv' , layer_class = 'Conv2d' , w_cfgs = w_cfgs ))
100+ mock_nodes .append (build_mock_node (name = 'fc' , layer_class = 'Linear' , w_cfgs = w_cfgs [:1 ]))
106101
107102 ### Create a mock graph
108103 ### Note: Generate the graph first because fusing_info cannot be set without it.
@@ -120,14 +115,33 @@ def test_override_fused_node_activation_quantization_candidates(self, idx):
120115 nodes = list (graph .nodes )
121116
122117 if idx == 1 :
123- ### Check if the first node ActivationQuantization settings match the expected values
124- assert nodes [0 ].candidates_quantization_cfg [0 ].activation_quantization_cfg .quant_mode == ActivationQuantizationMode .FLN_QUANT
125- assert nodes [0 ].candidates_quantization_cfg [0 ].activation_quantization_cfg .activation_n_bits == 16
126- assert nodes [0 ].candidates_quantization_cfg [0 ].activation_quantization_cfg .signedness == Signedness .AUTO
127- assert nodes [0 ].candidates_quantization_cfg [0 ].activation_quantization_cfg .activation_quantization_method == QuantizationMethod .POWER_OF_TWO
128- assert nodes [0 ].candidates_quantization_cfg [0 ].activation_quantization_cfg .activation_quantization_params_fn == power_of_two_selection_histogram
118+ # Check if the first node ActivationQuantization settings match the expected values
119+ # Weight mp configs are preserved, all candidates have the new activation config and duplicates are removed
120+ qcs0 = nodes [0 ].quantization_cfg .candidates_quantization_cfg
121+ assert len (qcs0 ) == 2
122+ for i , qc in enumerate (qcs0 ):
123+ assert qc .activation_quantization_cfg .quant_mode == ActivationQuantizationMode .FLN_QUANT
124+ assert qc .activation_quantization_cfg .activation_n_bits == 16
125+ assert qc .activation_quantization_cfg .signedness == Signedness .AUTO
126+ assert qc .activation_quantization_cfg .activation_quantization_method == QuantizationMethod .POWER_OF_TWO
127+ assert qc .activation_quantization_cfg .activation_quantization_params_fn == power_of_two_selection_histogram
128+ assert qc .weights_quantization_cfg == w_cfgs [i ]
129+ base_cfg0 = nodes [0 ].quantization_cfg .base_quantization_cfg
130+ assert base_cfg0 .activation_quantization_cfg .activation_n_bits == 16
131+ assert base_cfg0 .activation_quantization_cfg .quant_mode == ActivationQuantizationMode .FLN_QUANT
129132 ### Check if the second node ActivationQuantization settings match the expected values
130- assert nodes [1 ].candidates_quantization_cfg [0 ].activation_quantization_cfg .quant_mode == ActivationQuantizationMode .FLN_NO_QUANT
133+ # activations are fln-disabled, duplicates are removed even though orig activation configs differ in nbits
134+ qcs1 = nodes [1 ].quantization_cfg .candidates_quantization_cfg
135+ assert len (qcs1 ) == 1
136+ assert qcs1 [0 ].activation_quantization_cfg .quant_mode == ActivationQuantizationMode .FLN_NO_QUANT
137+ assert qcs1 [0 ].weights_quantization_cfg == w_cfgs [0 ]
138+ assert (nodes [1 ].quantization_cfg .base_quantization_cfg .
139+ activation_quantization_cfg .quant_mode == ActivationQuantizationMode .FLN_NO_QUANT )
140+
131141 else :
132142 ### Check if the first node ActivationQuantization settings match the expected values
133- assert nodes [0 ].candidates_quantization_cfg [0 ].activation_quantization_cfg .quant_mode == ActivationQuantizationMode .FLN_NO_QUANT
143+ qcs0 = nodes [0 ].quantization_cfg .candidates_quantization_cfg
144+ assert len (qcs0 ) == 2
145+ for i , qc in enumerate (qcs0 ):
146+ assert qc .activation_quantization_cfg .quant_mode == ActivationQuantizationMode .FLN_NO_QUANT
147+ assert qc .weights_quantization_cfg == w_cfgs [i ]
0 commit comments