1+ # Copyright 2025 Sony Semiconductor Solutions, Inc. All rights reserved.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+ # ==============================================================================
15+
16+ """
17+ Test cases for MCTWrapper class from model_compression_toolkit.wrapper.mct_wrapper
18+ """
19+
20+ import pytest
21+ from unittest .mock import Mock , patch
22+ from typing import Any , List , Tuple
23+ from model_compression_toolkit .wrapper .mct_wrapper import MCTWrapper
24+
25+
26+ class TestMCTWrapperIntegration :
27+ """
28+ Integration Tests for MCTWrapper Complete Workflows
29+
30+ This test class focuses on testing the complete quantization and export
31+ workflows by testing the main quantize_and_export method with different
32+ configurations and scenarios.
33+
34+ Test Categories:
35+ - PTQ Workflow: Complete Post-Training Quantization flow
36+ - GPTQ Mixed Precision: Gradient PTQ with mixed precision
37+ - LQ-PTQ TensorFlow: Low-bit quantization specific to TensorFlow
38+ """
39+
40+ @patch ('model_compression_toolkit.wrapper.mct_wrapper.'
41+ 'MCTWrapper._get_TPC' )
42+ @patch ('model_compression_toolkit.wrapper.mct_wrapper.'
43+ 'MCTWrapper._select_method' )
44+ @patch ('model_compression_toolkit.wrapper.mct_wrapper.'
45+ 'MCTWrapper.select_argname' )
46+ @patch ('model_compression_toolkit.wrapper.mct_wrapper.'
47+ 'MCTWrapper._setting_PTQ' )
48+ @patch ('model_compression_toolkit.wrapper.mct_wrapper.'
49+ 'MCTWrapper._export_model' )
50+ def test_quantize_and_export_PTQ_flow (
51+ self , mock_export : Mock , mock_setting_ptq : Mock ,
52+ mock_select_argname : Mock , mock_select_method : Mock ,
53+ mock_get_tpc : Mock ) -> None :
54+ """
55+ Test complete quantize_and_export workflow for Post-Training Quantization.
56+
57+ This integration test verifies the complete PTQ workflow from input
58+ validation through model export. It mocks internal methods to focus
59+ on workflow coordination and method call sequences.
60+
61+ Workflow Steps Tested:
62+ 1. Input validation and initialization
63+ 2. Parameter modification
64+ 3. Method selection for framework and quantization type
65+ 4. TPC (Target Platform Capabilities) configuration
66+ 5. PTQ parameter setup
67+ 6. Model quantization execution
68+ 7. Model export
69+
70+ Mocked Components:
71+ - _get_TPC: TPC configuration
72+ - _select_method: Framework-specific method selection
73+ - _Setting_PTQ: PTQ parameter configuration
74+ - _export_model: Model export functionality
75+ - _post_training_quantization: Actual quantization process
76+
77+ Verification Points:
78+ - Correct method call sequence
79+ - Proper parameter passing between methods
80+ - Expected return values (success flag and quantized model)
81+ - Instance state consistency after workflow completion
82+ """
83+ wrapper = MCTWrapper ()
84+
85+ # Setup mocks
86+ mock_float_model = Mock ()
87+ mock_representative_dataset = Mock ()
88+ mock_quantized_model = Mock ()
89+ mock_info = Mock ()
90+
91+ # Mock the post_training_quantization method
92+ wrapper ._post_training_quantization = Mock (
93+ return_value = (mock_quantized_model , mock_info ))
94+ wrapper .export_model = Mock ()
95+ wrapper ._setting_PTQparam = mock_setting_ptq
96+
97+ mock_setting_ptq .return_value = {'mock' : 'params' }
98+
99+ param_items = [('n_epochs' , 10 , 'Test parameter' )]
100+
101+ # Call the method
102+ success , result_model = wrapper .quantize_and_export (
103+ float_model = mock_float_model ,
104+ method = 'PTQ' ,
105+ framework = 'tensorflow' ,
106+ use_MCT_TPC = True ,
107+ use_MixP = False ,
108+ representative_dataset = mock_representative_dataset ,
109+ param_items = param_items
110+ )
111+
112+ # Verify the flow
113+ assert wrapper .float_model == mock_float_model
114+ assert wrapper .framework == 'tensorflow'
115+ assert wrapper .representative_dataset == mock_representative_dataset
116+
117+ mock_get_tpc .assert_called_once_with ()
118+ mock_select_method .assert_called_once_with ()
119+ mock_select_argname .assert_called_once_with ()
120+ mock_setting_ptq .assert_called_once ()
121+ wrapper ._post_training_quantization .assert_called_once_with (
122+ ** {'mock' : 'params' })
123+ mock_export .assert_called_once_with (mock_quantized_model )
124+
125+ assert success is True
126+ assert result_model == mock_quantized_model
127+
128+ @patch ('model_compression_toolkit.wrapper.mct_wrapper.'
129+ 'MCTWrapper._get_TPC' )
130+ @patch ('model_compression_toolkit.wrapper.mct_wrapper.'
131+ 'MCTWrapper._select_method' )
132+ @patch ('model_compression_toolkit.wrapper.mct_wrapper.'
133+ 'MCTWrapper.select_argname' )
134+ @patch ('model_compression_toolkit.wrapper.mct_wrapper.'
135+ 'MCTWrapper._setting_GPTQ_MixP' )
136+ @patch ('model_compression_toolkit.wrapper.mct_wrapper.'
137+ 'MCTWrapper._export_model' )
138+ def test_quantize_and_export_GPTQ_MixP_flow (
139+ self , mock_export : Mock , mock_setting_gptq_mixp : Mock ,
140+ mock_select_argname : Mock , mock_select_method : Mock ,
141+ mock_get_tpc : Mock ) -> None :
142+ """Test complete quantize_and_export flow for GPTQ with MixP"""
143+ wrapper = MCTWrapper ()
144+
145+ # Setup mocks
146+ mock_float_model = Mock ()
147+ mock_representative_dataset = Mock ()
148+ mock_quantized_model = Mock ()
149+ mock_info = Mock ()
150+
151+ wrapper ._post_training_quantization = Mock (
152+ return_value = (mock_quantized_model , mock_info ))
153+ wrapper .export_model = Mock ()
154+ wrapper ._setting_PTQparam = mock_setting_gptq_mixp
155+
156+ mock_setting_gptq_mixp .return_value = {'mock' : 'gptq_params' }
157+
158+ # Call the method
159+ success , result_model = wrapper .quantize_and_export (
160+ float_model = mock_float_model ,
161+ method = 'GPTQ' ,
162+ framework = 'tensorflow' ,
163+ use_MCT_TPC = True ,
164+ use_MixP = True ,
165+ representative_dataset = mock_representative_dataset ,
166+ param_items = []
167+ )
168+
169+ # Verify the flow
170+ mock_get_tpc .assert_called_once_with ()
171+ mock_select_method .assert_called_once_with ()
172+ mock_select_argname .assert_called_once_with ()
173+ mock_setting_gptq_mixp .assert_called_once ()
174+ wrapper ._post_training_quantization .assert_called_once_with (
175+ ** {'mock' : 'gptq_params' })
176+ mock_export .assert_called_once_with (mock_quantized_model )
177+
178+ assert success is True
179+ assert result_model == mock_quantized_model
180+
181+ @patch ('model_compression_toolkit.wrapper.mct_wrapper.'
182+ 'MCTWrapper._exec_lq_ptq' )
183+ def test_quantize_and_export_LQPTQ (self , mock_exec_lq_ptq : Mock ) -> None :
184+ """Test quantize_and_export flow for LQ-PTQ with TensorFlow"""
185+ wrapper = MCTWrapper ()
186+
187+ mock_float_model = Mock ()
188+ mock_representative_dataset = Mock ()
189+ mock_quantized_model = Mock ()
190+
191+ mock_exec_lq_ptq .return_value = mock_quantized_model
192+
193+ # Call the method
194+ success , result_model = wrapper .quantize_and_export (
195+ float_model = mock_float_model ,
196+ method = 'LQPTQ' ,
197+ framework = 'tensorflow' ,
198+ use_MCT_TPC = True ,
199+ use_MixP = False ,
200+ representative_dataset = mock_representative_dataset ,
201+ param_items = []
202+ )
203+
204+ # Verify the flow
205+ mock_exec_lq_ptq .assert_called_once ()
206+ assert success is True
207+ assert result_model == mock_quantized_model
0 commit comments