1+ # Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+ # ==============================================================================
15+ import math
16+ from unittest .mock import Mock , MagicMock
17+
18+ import numpy as np
19+ import pytest
20+ import torch
21+
22+ from mct_quantizers import QuantizationMethod
23+ from torch import nn
24+
25+ from model_compression_toolkit .core import QuantizationErrorMethod , QuantizationConfig , CoreConfig
26+ from model_compression_toolkit .core .common import StatsCollector
27+ from model_compression_toolkit .core .common .collectors .histogram_collector import HistogramCollector
28+ from model_compression_toolkit .core .common .collectors .weighted_histogram_collector import WeightedHistogramCollector
29+ from model_compression_toolkit .core .pytorch .utils import to_torch_tensor
30+ from model_compression_toolkit .ptq import pytorch_post_training_quantization
31+ from model_compression_toolkit .target_platform_capabilities .schema .mct_current_schema import OpQuantizationConfig , \
32+ AttributeQuantizationConfig , Signedness
33+ from model_compression_toolkit .target_platform_capabilities .constants import KERNEL_ATTR , BIAS_ATTR
34+ from tests .common_tests .helpers .tpcs_for_tests .v4 .tpc import generate_tpc
35+
36+ INPUT_SHAPE = (1 , 8 , 12 , 8 )
37+ torch .manual_seed (42 )
38+
39+
def build_model(input_shape, out_chan, kernel, const):
    """
    Create a small Conv2d -> ReLU -> add-constant network for quantization tests.

    The convolution kernel is overwritten with the sequence 1..N (N being the number of
    kernel elements) divided by the sum of that sequence, so the kernel is deterministic
    and sums to one. Only the kernel is overwritten; the bias keeps the value given to it
    by nn.Conv2d.

    Parameters:
        input_shape: Input shape; index 1 is used as the number of input channels.
        out_chan: Number of output channels of the convolution.
        kernel: Kernel size passed to nn.Conv2d.
        const: Value converted with to_torch_tensor and added to the ReLU output.

    Returns:
        A freshly constructed nn.Module instance.
    """

    def _normalized_ramp(like):
        # Values 1..numel, laid out like `like`, scaled so that they sum to one.
        ramp = torch.arange(1, like.numel() + 1, dtype=like.dtype)
        ramp = ramp.reshape(like.shape)
        return ramp / ramp.sum()

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(in_channels=input_shape[1], out_channels=out_chan, kernel_size=kernel)
            self.relu = nn.ReLU()

            # Replace the kernel in place without recording the copy in autograd.
            with torch.no_grad():
                self.conv.weight.copy_(_normalized_ramp(self.conv.weight))

        def forward(self, x):
            activated = self.relu(self.conv(x))
            return torch.add(activated, to_torch_tensor(const))

    return Model()
68+
69+
@pytest.fixture
def rep_data_gen():
    """
    Provide a representative-dataset callable for post-training quantization.

    Each invocation of the returned callable lazily yields two batches. A batch is a
    single-element list holding a standard-normal array of shape INPUT_SHAPE.
    """
    # The global NumPy RNG is seeded once, when the fixture is set up.
    # NOTE(review): the seed is not re-applied per invocation, so successive invocations
    # of the returned callable keep consuming the same global stream and therefore draw
    # different samples.
    np.random.seed(42)
    num_batches = 2

    def representative_dataset():
        produced = 0
        while produced < num_batches:
            yield [np.random.randn(*INPUT_SHAPE)]
            produced += 1

    return representative_dataset
84+
85+
def get_tpc():
    """
    Build the target platform capabilities (TPC) object used by the tests in this file.

    A single operator configuration serves as default config, base config and the only
    mixed-precision candidate. It requests 2-bit UNIFORM activation quantization and uses
    a default-constructed AttributeQuantizationConfig for the kernel, the bias and the
    default weight attribute.
    NOTE(review): the default AttributeQuantizationConfig presumably leaves weight
    quantization disabled - confirm against the schema.

    Returns:
        The TPC produced by generate_tpc, named "test_tpc".
    """
    default_attr_cfg = AttributeQuantizationConfig()
    weights_mapping = {
        KERNEL_ATTR: default_attr_cfg,
        BIAS_ATTR: default_attr_cfg,
    }

    shared_op_cfg = OpQuantizationConfig(
        default_weight_attr_config=default_attr_cfg,
        attr_weights_configs_mapping=weights_mapping,
        activation_quantization_method=QuantizationMethod.UNIFORM,
        activation_n_bits=2,
        supported_input_activation_n_bits=2,
        enable_activation_quantization=True,
        quantization_preserving=True,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=32,
        signedness=Signedness.AUTO,
    )

    return generate_tpc(
        default_config=shared_op_cfg,
        base_config=shared_op_cfg,
        mixed_precision_cfg_list=[shared_op_cfg],
        name="test_tpc",
    )
110+
111+
def get_core_config(quant_error_method):
    """
    Wrap an activation error method into a CoreConfig.

    Parameters:
        quant_error_method: QuantizationErrorMethod forwarded to QuantizationConfig as
            its activation_error_method.

    Returns:
        CoreConfig whose quantization_config carries the given activation error method.
    """
    return CoreConfig(
        quantization_config=QuantizationConfig(activation_error_method=quant_error_method),
    )
124+
125+
126+
@pytest.fixture(
    params=[
        QuantizationErrorMethod.MSE,
        QuantizationErrorMethod.LP,
        QuantizationErrorMethod.HMSE,
        QuantizationErrorMethod.MAE,
        QuantizationErrorMethod.NOCLIPPING,
    ]
)
def quant_error_method(request):
    """
    Parametrized fixture that supplies one QuantizationErrorMethod per test instance.

    A test requesting this fixture runs once for each of MSE, LP, HMSE, MAE and NOCLIPPING.
    """
    selected_method = request.param
    return selected_method
134+
135+
def compute_max_range(model, tpc, rep_data_gen, quant_error_method):
    """
    Quantize the given model and report the max_range chosen for its ReLU activation.

    Parameters:
        model: Float model handed to pytorch_post_training_quantization.
        tpc: Target platform capabilities used for the quantization.
        rep_data_gen: Representative-dataset callable.
        quant_error_method: Activation error method placed into the core configuration.

    Returns:
        The max_range attribute of the quantizer held by the quantized model's
        relu_activation_holder_quantizer.
    """
    core_config = get_core_config(quant_error_method)
    ptq_result = pytorch_post_training_quantization(
        in_module=model,
        core_config=core_config,
        representative_data_gen=rep_data_gen,
        target_platform_capabilities=tpc,
    )
    # The second element of the result is not needed here.
    quantized_model, _ = ptq_result

    relu_holder = quantized_model.relu_activation_holder_quantizer
    return relu_holder.activation_holder_quantizer.max_range
148+
149+
class TestPTQWithActivationQuantizationErrorMethods:
    """End-to-end PTQ tests covering the supported activation quantization error methods."""

    # Every unordered pair of distinct error methods is listed exactly once.
    @pytest.mark.parametrize("method1, method2", [
        (QuantizationErrorMethod.MSE, QuantizationErrorMethod.LP),
        (QuantizationErrorMethod.MSE, QuantizationErrorMethod.HMSE),
        (QuantizationErrorMethod.MSE, QuantizationErrorMethod.MAE),
        (QuantizationErrorMethod.MSE, QuantizationErrorMethod.NOCLIPPING),
        (QuantizationErrorMethod.LP, QuantizationErrorMethod.HMSE),
        (QuantizationErrorMethod.LP, QuantizationErrorMethod.MAE),
        (QuantizationErrorMethod.LP, QuantizationErrorMethod.NOCLIPPING),
        (QuantizationErrorMethod.HMSE, QuantizationErrorMethod.MAE),
        (QuantizationErrorMethod.HMSE, QuantizationErrorMethod.NOCLIPPING),
        (QuantizationErrorMethod.MAE, QuantizationErrorMethod.NOCLIPPING),
    ])
    def test_ptq_quantization_error_methods(self, rep_data_gen, method1, method2):
        """
        Check that two different error methods lead to different activation ranges.

        One model and one TPC are built, post-training quantization is run once per
        method of the pair, and the max_range values obtained for the ReLU activation
        quantizer are required to differ.
        """
        model = build_model(input_shape=INPUT_SHAPE, out_chan=16, kernel=1, const=np.array([5]))
        tpc = get_tpc()

        # Quantize with method1 first and method2 second, reusing the same model and TPC.
        max_range1, max_range2 = [
            compute_max_range(model, tpc, rep_data_gen, error_method)
            for error_method in (method1, method2)
        ]

        assert max_range1 != max_range2, (
            f"Methods {method1} and {method2} produced the same max_range value."
        )

    @pytest.mark.parametrize("method, use_weighted", [
        (QuantizationErrorMethod.HMSE, True),
        (QuantizationErrorMethod.MSE, False),
        (QuantizationErrorMethod.LP, False),
        (QuantizationErrorMethod.MAE, False),
        (QuantizationErrorMethod.NOCLIPPING, False),
    ])
    def test_ptq_use_of_histogram_collector_for_quantization_error_methods(self, monkeypatch, rep_data_gen, method,
                                                                           use_weighted):
        """
        Check which histogram collector is queried for a given error method.

        get_histogram is replaced by a spy on both WeightedHistogramCollector and
        HistogramCollector before running post-training quantization. Afterwards the
        spy matching `use_weighted` must have been called at least three times and the
        other spy must not have been called at all.
        """
        # Each spy returns a (bins, counts) pair made of small 1D arrays.
        spy_weighted_get_histogram = MagicMock(
            return_value=(np.array([0.0, 1.0, 2.0, 3.0]), np.array([10, 20, 15])))
        spy_regular_get_histogram = MagicMock(
            return_value=(np.array([0.0, 1.0, 2.0, 3.0]), np.array([5, 25, 10])))

        # Install the spies on the collector classes; monkeypatch restores them afterwards.
        patch_targets = (
            (WeightedHistogramCollector, spy_weighted_get_histogram),
            (HistogramCollector, spy_regular_get_histogram),
        )
        for collector_cls, spy in patch_targets:
            monkeypatch.setattr(collector_cls, "get_histogram", spy)

        model = build_model(input_shape=INPUT_SHAPE, out_chan=16, kernel=1, const=np.array([5]))
        tpc = get_tpc()

        # Run the full PTQ flow; only the calls recorded by the spies are inspected.
        _q_model, _ = pytorch_post_training_quantization(
            in_module=model,
            core_config=get_core_config(method),
            representative_data_gen=rep_data_gen,
            target_platform_capabilities=tpc,
        )

        # Exactly one of the two collectors is expected to have been queried.
        if not use_weighted:
            # Regular histogram expected: the weighted collector must stay untouched.
            spy_weighted_get_histogram.assert_not_called()
            assert spy_regular_get_histogram.call_count >= 3, "Expected at least three calls to get_histogram"
        else:
            # Weighted histogram expected: the regular collector must stay untouched.
            assert spy_weighted_get_histogram.call_count >= 3, "Expected at least three calls to weighted get_histogram"
            spy_regular_get_histogram.assert_not_called()
0 commit comments