Skip to content

Commit 25853f2

Browse files
author
liord
committed
Add quantization error method E2E test. Add test to check the use of correct histogram/weighted histogram
1 parent 52761ba commit 25853f2

File tree

2 files changed

+244
-0
lines changed

2 files changed

+244
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
import math
16+
from unittest.mock import Mock, MagicMock
17+
18+
import numpy as np
19+
import pytest
20+
import torch
21+
22+
from mct_quantizers import QuantizationMethod
23+
from torch import nn
24+
25+
from model_compression_toolkit.core import QuantizationErrorMethod, QuantizationConfig, CoreConfig
26+
from model_compression_toolkit.core.common import StatsCollector
27+
from model_compression_toolkit.core.common.collectors.histogram_collector import HistogramCollector
28+
from model_compression_toolkit.core.common.collectors.weighted_histogram_collector import WeightedHistogramCollector
29+
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
30+
from model_compression_toolkit.ptq import pytorch_post_training_quantization
31+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
32+
AttributeQuantizationConfig, Signedness
33+
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR
34+
from tests.common_tests.helpers.tpcs_for_tests.v4.tpc import generate_tpc
35+
36+
INPUT_SHAPE = (1, 8, 12, 8)
37+
torch.manual_seed(42)
38+
39+
40+
def build_model(input_shape, out_chan, kernel, const):
    """
    Build a simple CNN model with a convolution layer and ReLU activation.

    The convolution kernel is initialized deterministically: a ramp of values
    1..numel(weight), normalized so the whole kernel sums to 1.

    Args:
        input_shape: Input tensor shape (N, C, H, W); only the channel dim is used.
        out_chan: Number of convolution output channels.
        kernel: Convolution kernel size.
        const: Constant (numpy array) added to the activation output in forward.

    Returns:
        An instantiated nn.Module.
    """
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(in_channels=input_shape[1],
                                  out_channels=out_chan,
                                  kernel_size=kernel)
            self.relu = nn.ReLU()

            # Deterministic init: ramp 1..N reshaped to the kernel shape,
            # then normalized to sum to 1.
            n_params = self.conv.weight.numel()
            ramp = torch.arange(1, n_params + 1,
                                dtype=self.conv.weight.dtype)
            ramp = ramp.reshape(self.conv.weight.shape)
            ramp = ramp / ramp.sum()
            with torch.no_grad():
                self.conv.weight.copy_(ramp)

        def forward(self, x):
            out = self.relu(self.conv(x))
            # Add the constant (moved to a torch tensor) after the activation.
            return torch.add(out, to_torch_tensor(const))

    return Model()
68+
69+
70+
@pytest.fixture
def rep_data_gen():
    """
    Fixture providing a representative-dataset generator for PTQ calibration.

    The generator yields 2 random batches shaped per INPUT_SHAPE; numpy is
    seeded so every test sees the same data.
    """
    np.random.seed(42)

    def _gen():
        for _ in range(2):
            yield [np.random.randn(*INPUT_SHAPE)]

    return _gen
84+
85+
86+
def get_tpc():
    """
    Create target platform capabilities (TPC) for the quantization tests.

    Weights are left unquantized (default attribute config) while activations
    use 2-bit uniform quantization.

    Returns:
        A TPC object built from a single op quantization config.
    """
    no_quant_attr = AttributeQuantizationConfig()

    op_config = OpQuantizationConfig(
        default_weight_attr_config=no_quant_attr,
        attr_weights_configs_mapping={KERNEL_ATTR: no_quant_attr,
                                      BIAS_ATTR: no_quant_attr},
        activation_quantization_method=QuantizationMethod.UNIFORM,
        activation_n_bits=2,
        supported_input_activation_n_bits=2,
        enable_activation_quantization=True,
        quantization_preserving=True,
        fixed_scale=None,
        fixed_zero_point=None,
        simd_size=32,
        signedness=Signedness.AUTO)

    return generate_tpc(default_config=op_config,
                        base_config=op_config,
                        mixed_precision_cfg_list=[op_config],
                        name="test_tpc")
110+
111+
112+
def get_core_config(quant_error_method):
    """
    Create a core configuration with a specified activation error method.

    Parameters:
        quant_error_method: QuantizationErrorMethod used for activation
            threshold selection.

    Returns:
        CoreConfig wrapping a QuantizationConfig with the requested method.
    """
    return CoreConfig(
        quantization_config=QuantizationConfig(
            activation_error_method=quant_error_method))
124+
125+
126+
127+
# All activation error methods exercised by the parametrized fixture below.
_ACTIVATION_ERROR_METHODS = [
    QuantizationErrorMethod.MSE,
    QuantizationErrorMethod.LP,
    QuantizationErrorMethod.HMSE,
    QuantizationErrorMethod.MAE,
    QuantizationErrorMethod.NOCLIPPING,
]


@pytest.fixture(params=_ACTIVATION_ERROR_METHODS)
def quant_error_method(request):
    """Parametrized fixture yielding each activation quantization error method."""
    return request.param
134+
135+
136+
def compute_max_range(model, tpc, rep_data_gen, quant_error_method):
    """
    Quantize `model` with the given activation error method and return the
    max_range selected for the ReLU activation quantizer.
    """
    quantized_model, _ = pytorch_post_training_quantization(
        in_module=model,
        core_config=get_core_config(quant_error_method),
        representative_data_gen=rep_data_gen,
        target_platform_capabilities=tpc)
    holder = quantized_model.relu_activation_holder_quantizer
    return holder.activation_holder_quantizer.max_range
148+
149+
150+
class TestPTQWithActivationQuantizationErrorMethods:
    """E2E tests verifying that the activation error method affects PTQ results
    and routes histogram collection to the correct collector."""

    # Parameterize over every distinct pair of quantization methods.
    @pytest.mark.parametrize("method1, method2", [
        (QuantizationErrorMethod.MSE, QuantizationErrorMethod.LP),
        (QuantizationErrorMethod.MSE, QuantizationErrorMethod.HMSE),
        (QuantizationErrorMethod.MSE, QuantizationErrorMethod.MAE),
        (QuantizationErrorMethod.MSE, QuantizationErrorMethod.NOCLIPPING),
        (QuantizationErrorMethod.LP, QuantizationErrorMethod.HMSE),
        (QuantizationErrorMethod.LP, QuantizationErrorMethod.MAE),
        (QuantizationErrorMethod.LP, QuantizationErrorMethod.NOCLIPPING),
        (QuantizationErrorMethod.HMSE, QuantizationErrorMethod.MAE),
        (QuantizationErrorMethod.HMSE, QuantizationErrorMethod.NOCLIPPING),
        (QuantizationErrorMethod.MAE, QuantizationErrorMethod.NOCLIPPING),
    ])
    def test_ptq_quantization_error_methods(self, rep_data_gen, method1, method2):
        """
        Verify that PTQ produces different quantization parameters for
        different activation error methods.

        Quantizes the same model twice — once per method — and asserts the
        resulting ReLU max_range values are distinct.
        """
        model = build_model(input_shape=INPUT_SHAPE, out_chan=16, kernel=1,
                            const=np.array([5]))
        tpc = get_tpc()
        range_a = compute_max_range(model, tpc, rep_data_gen, method1)
        range_b = compute_max_range(model, tpc, rep_data_gen, method2)
        assert range_a != range_b, \
            f"Methods {method1} and {method2} produced the same max_range value."

    @pytest.mark.parametrize("method, use_weighted", [
        (QuantizationErrorMethod.HMSE, True),
        (QuantizationErrorMethod.MSE, False),
        (QuantizationErrorMethod.LP, False),
        (QuantizationErrorMethod.MAE, False),
        (QuantizationErrorMethod.NOCLIPPING, False)
    ])
    def test_ptq_use_of_histogram_collector_for_quantization_error_methods(self, monkeypatch, rep_data_gen, method, use_weighted):
        """
        E2E test that the correct histogram collector is used for each
        quantization error method: HMSE should go through the weighted
        collector, every other method through the regular one. Both
        get_histogram() methods are replaced with spy mocks that return
        valid (bins, counts) arrays, and call counts are asserted after PTQ.
        """
        # Spy mocks returning valid 1D numpy (bin-edges, counts) pairs.
        weighted_spy = MagicMock(
            return_value=(np.array([0.0, 1.0, 2.0, 3.0]),
                          np.array([10, 20, 15])))
        regular_spy = MagicMock(
            return_value=(np.array([0.0, 1.0, 2.0, 3.0]),
                          np.array([5, 25, 10])))

        # Patch get_histogram on both collector classes before PTQ runs.
        monkeypatch.setattr(WeightedHistogramCollector, "get_histogram", weighted_spy)
        monkeypatch.setattr(HistogramCollector, "get_histogram", regular_spy)

        model = build_model(input_shape=INPUT_SHAPE, out_chan=16, kernel=1,
                            const=np.array([5]))

        # PTQ internally pulls histogram data from the collectors.
        pytorch_post_training_quantization(
            in_module=model,
            core_config=get_core_config(method),
            representative_data_gen=rep_data_gen,
            target_platform_capabilities=get_tpc())

        if use_weighted:
            # Weighted collection path: the regular collector must stay idle.
            assert weighted_spy.call_count >= 3, "Expected at least three calls to weighted get_histogram"
            regular_spy.assert_not_called()
        else:
            # Regular collection path: the weighted collector must stay idle.
            weighted_spy.assert_not_called()
            assert regular_spy.call_count >= 3, "Expected at least three calls to get_histogram"

0 commit comments

Comments
 (0)