Skip to content

Commit 321bd87

Browse files
Split and rename tests
1 parent d53e776 commit 321bd87

File tree

12 files changed

+1102
-301
lines changed

12 files changed

+1102
-301
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Sony Semiconductor Solutions. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================

tests_pytest/wrapper_tests/test_mct_wrapper_keras_e2e.py renamed to tests_pytest/keras_tests/e2e_tests/wrapper/test_mct_wrapper_keras_e2e.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def PTQ_Keras(float_model: keras.Model) -> Tuple[bool, keras.Model]:
211211
method = 'PTQ'
212212
framework = 'tensorflow'
213213
use_MCT_TPC = True
214-
use_MixP = False
214+
use_mixed_precision = False
215215

216216
# Configure quantization parameters for optimal model performance
217217
param_items = [['tpc_version', '1.0', 'The version of the TPC to use.'],
@@ -226,7 +226,7 @@ def PTQ_Keras(float_model: keras.Model) -> Tuple[bool, keras.Model]:
226226

227227
# Execute quantization using MCTWrapper
228228
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
229-
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_MixP, representative_dataset_gen, param_items)
229+
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_mixed_precision, representative_dataset_gen, param_items)
230230
return flag, quantized_model
231231

232232
#########################################################################
@@ -243,7 +243,7 @@ def PTQ_Keras_MixP(float_model: keras.Model) -> Tuple[bool, keras.Model]:
243243
method = 'PTQ'
244244
framework = 'tensorflow'
245245
use_MCT_TPC = True
246-
use_MixP = True
246+
use_mixed_precision = True
247247

248248
# Configure mixed precision parameters for optimal compression
249249
param_items = [['tpc_version', '1.0', 'The version of the TPC to use.'],
@@ -256,7 +256,7 @@ def PTQ_Keras_MixP(float_model: keras.Model) -> Tuple[bool, keras.Model]:
256256

257257
# Execute quantization with mixed precision using MCTWrapper
258258
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
259-
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_MixP, representative_dataset_gen, param_items)
259+
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_mixed_precision, representative_dataset_gen, param_items)
260260
return flag, quantized_model
261261

262262
#########################################################################
@@ -273,7 +273,7 @@ def GPTQ_Keras(float_model: keras.Model) -> Tuple[bool, keras.Model]:
273273
method = 'GPTQ'
274274
framework = 'tensorflow'
275275
use_MCT_TPC = True
276-
use_MixP = False
276+
use_mixed_precision = False
277277

278278
# Configure GPTQ-specific parameters for gradient-based optimization
279279
param_items = [['target_platform_version', 'v1', 'Target platform capabilities version.'],
@@ -285,7 +285,7 @@ def GPTQ_Keras(float_model: keras.Model) -> Tuple[bool, keras.Model]:
285285

286286
# Execute gradient-based quantization using MCTWrapper
287287
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
288-
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_MixP, representative_dataset_gen, param_items)
288+
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_mixed_precision, representative_dataset_gen, param_items)
289289
return flag, quantized_model
290290

291291
#########################################################################
@@ -295,7 +295,7 @@ def GPTQ_Keras_MixP(float_model: keras.Model) -> Tuple[bool, keras.Model]:
295295
method = 'GPTQ'
296296
framework = 'tensorflow'
297297
use_MCT_TPC = True
298-
use_MixP = True
298+
use_mixed_precision = True
299299

300300
param_items = [['target_platform_version', 'v1', 'Target platform capabilities version.'],
301301

@@ -310,7 +310,7 @@ def GPTQ_Keras_MixP(float_model: keras.Model) -> Tuple[bool, keras.Model]:
310310
['save_model_path', './qmodel_GPTQ_Keras_MixP.tflite', 'Path to save the model.']]
311311

312312
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
313-
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_MixP, representative_dataset_gen, param_items)
313+
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_mixed_precision, representative_dataset_gen, param_items)
314314
return flag, quantized_model
315315

316316
#########################################################################
@@ -320,7 +320,7 @@ def LQPTQ_Keras(float_model: keras.Model) -> Tuple[bool, keras.Model]:
320320
method = 'LQPTQ'
321321
framework = 'tensorflow'
322322
use_MCT_TPC = True
323-
use_MixP = False
323+
use_mixed_precision = False
324324

325325
param_items = [
326326

@@ -333,7 +333,7 @@ def LQPTQ_Keras(float_model: keras.Model) -> Tuple[bool, keras.Model]:
333333
# part as a NumPy array
334334
representative_dataset = dataset.take(1).get_single_element()[0].numpy()
335335
wrapper = mct.wrapper.wrap.MCTWrapper()
336-
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_MixP, representative_dataset, param_items)
336+
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_mixed_precision, representative_dataset, param_items)
337337
return flag, quantized_model
338338

339339
# Execute the selected quantization method
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Sony Semiconductor Solutions. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
# Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
"""
17+
Test cases for MCTWrapper class from model_compression_toolkit.wrapper.mct_wrapper
18+
"""
19+
20+
import pytest
21+
from unittest.mock import Mock, patch
22+
from typing import Any, List, Tuple
23+
from model_compression_toolkit.wrapper.mct_wrapper import MCTWrapper
24+
25+
26+
class TestMCTWrapperIntegration:
27+
"""
28+
Integration Tests for MCTWrapper Complete Workflows
29+
30+
This test class focuses on testing the complete quantization and export
31+
workflows by testing the main quantize_and_export method with different
32+
configurations and scenarios.
33+
34+
Test Categories:
35+
- PTQ Workflow: Complete Post-Training Quantization flow
36+
- GPTQ Mixed Precision: Gradient PTQ with mixed precision
37+
- LQ-PTQ TensorFlow: Low-bit quantization specific to TensorFlow
38+
"""
39+
40+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
41+
'MCTWrapper._get_TPC')
42+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
43+
'MCTWrapper._select_method')
44+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
45+
'MCTWrapper.select_argname')
46+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
47+
'MCTWrapper._setting_PTQ')
48+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
49+
'MCTWrapper._export_model')
50+
def test_quantize_and_export_PTQ_flow(
51+
self, mock_export: Mock, mock_setting_ptq: Mock,
52+
mock_select_argname: Mock, mock_select_method: Mock,
53+
mock_get_tpc: Mock) -> None:
54+
"""
55+
Test complete quantize_and_export workflow for Post-Training Quantization.
56+
57+
This integration test verifies the complete PTQ workflow from input
58+
validation through model export. It mocks internal methods to focus
59+
on workflow coordination and method call sequences.
60+
61+
Workflow Steps Tested:
62+
1. Input validation and initialization
63+
2. Parameter modification
64+
3. Method selection for framework and quantization type
65+
4. TPC (Target Platform Capabilities) configuration
66+
5. PTQ parameter setup
67+
6. Model quantization execution
68+
7. Model export
69+
70+
Mocked Components:
71+
- _get_TPC: TPC configuration
72+
- _select_method: Framework-specific method selection
73+
- _Setting_PTQ: PTQ parameter configuration
74+
- _export_model: Model export functionality
75+
- _post_training_quantization: Actual quantization process
76+
77+
Verification Points:
78+
- Correct method call sequence
79+
- Proper parameter passing between methods
80+
- Expected return values (success flag and quantized model)
81+
- Instance state consistency after workflow completion
82+
"""
83+
wrapper = MCTWrapper()
84+
85+
# Setup mocks
86+
mock_float_model = Mock()
87+
mock_representative_dataset = Mock()
88+
mock_quantized_model = Mock()
89+
mock_info = Mock()
90+
91+
# Mock the post_training_quantization method
92+
wrapper._post_training_quantization = Mock(
93+
return_value=(mock_quantized_model, mock_info))
94+
wrapper.export_model = Mock()
95+
wrapper._setting_PTQparam = mock_setting_ptq
96+
97+
mock_setting_ptq.return_value = {'mock': 'params'}
98+
99+
param_items = [('n_epochs', 10, 'Test parameter')]
100+
101+
# Call the method
102+
success, result_model = wrapper.quantize_and_export(
103+
float_model=mock_float_model,
104+
method='PTQ',
105+
framework='tensorflow',
106+
use_MCT_TPC=True,
107+
use_MixP=False,
108+
representative_dataset=mock_representative_dataset,
109+
param_items=param_items
110+
)
111+
112+
# Verify the flow
113+
assert wrapper.float_model == mock_float_model
114+
assert wrapper.framework == 'tensorflow'
115+
assert wrapper.representative_dataset == mock_representative_dataset
116+
117+
mock_get_tpc.assert_called_once_with()
118+
mock_select_method.assert_called_once_with()
119+
mock_select_argname.assert_called_once_with()
120+
mock_setting_ptq.assert_called_once()
121+
wrapper._post_training_quantization.assert_called_once_with(
122+
**{'mock': 'params'})
123+
mock_export.assert_called_once_with(mock_quantized_model)
124+
125+
assert success is True
126+
assert result_model == mock_quantized_model
127+
128+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
129+
'MCTWrapper._get_TPC')
130+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
131+
'MCTWrapper._select_method')
132+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
133+
'MCTWrapper.select_argname')
134+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
135+
'MCTWrapper._setting_GPTQ_MixP')
136+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
137+
'MCTWrapper._export_model')
138+
def test_quantize_and_export_GPTQ_MixP_flow(
139+
self, mock_export: Mock, mock_setting_gptq_mixp: Mock,
140+
mock_select_argname: Mock, mock_select_method: Mock,
141+
mock_get_tpc: Mock) -> None:
142+
"""Test complete quantize_and_export flow for GPTQ with MixP"""
143+
wrapper = MCTWrapper()
144+
145+
# Setup mocks
146+
mock_float_model = Mock()
147+
mock_representative_dataset = Mock()
148+
mock_quantized_model = Mock()
149+
mock_info = Mock()
150+
151+
wrapper._post_training_quantization = Mock(
152+
return_value=(mock_quantized_model, mock_info))
153+
wrapper.export_model = Mock()
154+
wrapper._setting_PTQparam = mock_setting_gptq_mixp
155+
156+
mock_setting_gptq_mixp.return_value = {'mock': 'gptq_params'}
157+
158+
# Call the method
159+
success, result_model = wrapper.quantize_and_export(
160+
float_model=mock_float_model,
161+
method='GPTQ',
162+
framework='tensorflow',
163+
use_MCT_TPC=True,
164+
use_MixP=True,
165+
representative_dataset=mock_representative_dataset,
166+
param_items=[]
167+
)
168+
169+
# Verify the flow
170+
mock_get_tpc.assert_called_once_with()
171+
mock_select_method.assert_called_once_with()
172+
mock_select_argname.assert_called_once_with()
173+
mock_setting_gptq_mixp.assert_called_once()
174+
wrapper._post_training_quantization.assert_called_once_with(
175+
**{'mock': 'gptq_params'})
176+
mock_export.assert_called_once_with(mock_quantized_model)
177+
178+
assert success is True
179+
assert result_model == mock_quantized_model
180+
181+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
182+
'MCTWrapper._exec_lq_ptq')
183+
def test_quantize_and_export_LQPTQ(self, mock_exec_lq_ptq: Mock) -> None:
184+
"""Test quantize_and_export flow for LQ-PTQ with TensorFlow"""
185+
wrapper = MCTWrapper()
186+
187+
mock_float_model = Mock()
188+
mock_representative_dataset = Mock()
189+
mock_quantized_model = Mock()
190+
191+
mock_exec_lq_ptq.return_value = mock_quantized_model
192+
193+
# Call the method
194+
success, result_model = wrapper.quantize_and_export(
195+
float_model=mock_float_model,
196+
method='LQPTQ',
197+
framework='tensorflow',
198+
use_MCT_TPC=True,
199+
use_MixP=False,
200+
representative_dataset=mock_representative_dataset,
201+
param_items=[]
202+
)
203+
204+
# Verify the flow
205+
mock_exec_lq_ptq.assert_called_once()
206+
assert success is True
207+
assert result_model == mock_quantized_model
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Sony Semiconductor Solutions. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================

0 commit comments

Comments
 (0)