Skip to content

Commit ab9c8a9

Browse files
change the filename to mct_wrapper.py
1 parent 7029767 commit ab9c8a9

File tree

5 files changed

+22
-22
lines changed

5 files changed

+22
-22
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from model_compression_toolkit.wrapper.mctwrapper import MCTWrapper
1+
from model_compression_toolkit.wrapper.mct_wrapper import MCTWrapper
File renamed without changes.

tests_pytest/wrapper_tests/test_mctwrapper.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Test cases for MCTWrapper class from model_compression_toolkit.wrapper.mctwrapper
2+
Test cases for MCTWrapper class from model_compression_toolkit.wrapper.mct_wrapper
33
"""
44

55
import pytest
@@ -13,7 +13,7 @@
1313
print(sys.path)
1414

1515
from model_compression_toolkit.core import QuantizationErrorMethod
16-
from model_compression_toolkit.wrapper.mctwrapper import MCTWrapper
16+
from model_compression_toolkit.wrapper.mct_wrapper import MCTWrapper
1717

1818

1919
class TestMCTWrapper:
@@ -106,7 +106,7 @@ def test_modify_params_non_existing_keys(self) -> None:
106106
assert 'non_existing_key' not in wrapper.params
107107
assert 'another_fake_key' not in wrapper.params
108108

109-
@patch('model_compression_toolkit.wrapper.mctwrapper.mct.get_target_platform_capabilities')
109+
@patch('model_compression_toolkit.wrapper.mct_wrapper.mct.get_target_platform_capabilities')
110110
def test_get_TPC_with_MCT_TPC(self, mock_mct_get_tpc: Mock) -> None:
111111
"""
112112
Test _get_TPC method when using MCT TPC.
@@ -134,7 +134,7 @@ def test_get_TPC_with_MCT_TPC(self, mock_mct_get_tpc: Mock) -> None:
134134
mock_mct_get_tpc.assert_called_once_with(**expected_params)
135135
assert wrapper.tpc == mock_tpc
136136

137-
@patch('model_compression_toolkit.wrapper.mctwrapper.'
137+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
138138
'edgemdt_tpc.get_target_platform_capabilities')
139139
def test_get_TPC_without_MCT_TPC(self, mock_edgemdt_get_tpc: Mock) -> None:
140140
"""
@@ -423,13 +423,13 @@ class TestMCTWrapperIntegration:
423423
- LQ-PTQ TensorFlow: Low-bit quantization specific to TensorFlow
424424
"""
425425

426-
@patch('model_compression_toolkit.wrapper.mctwrapper.'
426+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
427427
'MCTWrapper._get_TPC')
428-
@patch('model_compression_toolkit.wrapper.mctwrapper.'
428+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
429429
'MCTWrapper._select_method')
430-
@patch('model_compression_toolkit.wrapper.mctwrapper.'
430+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
431431
'MCTWrapper._setting_PTQ')
432-
@patch('model_compression_toolkit.wrapper.mctwrapper.'
432+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
433433
'MCTWrapper._export_model')
434434
def test_quantize_and_export_PTQ_flow(
435435
self, mock_export: Mock, mock_setting_ptq: Mock,
@@ -485,13 +485,13 @@ def test_quantize_and_export_PTQ_flow(
485485
assert success is True
486486
assert result_model == mock_quantized_model
487487

488-
@patch('model_compression_toolkit.wrapper.mctwrapper.'
488+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
489489
'MCTWrapper._get_TPC')
490-
@patch('model_compression_toolkit.wrapper.mctwrapper.'
490+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
491491
'MCTWrapper._select_method')
492-
@patch('model_compression_toolkit.wrapper.mctwrapper.'
492+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
493493
'MCTWrapper._setting_GPTQ_MixP')
494-
@patch('model_compression_toolkit.wrapper.mctwrapper.'
494+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
495495
'MCTWrapper._export_model')
496496
def test_quantize_and_export_GPTQ_MixP_flow(
497497
self, mock_export: Mock, mock_setting_gptq_mixp: Mock,
@@ -541,7 +541,7 @@ def test_quantize_and_export_GPTQ_MixP_flow(
541541
assert success is True
542542
assert result_model == mock_quantized_model
543543

544-
@patch('model_compression_toolkit.wrapper.mctwrapper.'
544+
@patch('model_compression_toolkit.wrapper.mct_wrapper.'
545545
'MCTWrapper._exec_lq_ptq')
546546
def test_quantize_and_export_LQPTQ_tensorflow(self, mock_exec_lq_ptq: Mock) -> None:
547547
"""

tests_pytest/wrapper_tests/test_wrap_keras_E2E.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def PTQ_Keras(float_model: keras.Model) -> Tuple[bool, keras.Model]:
207207
['save_model_path', './qmodel_PTQ_Keras.tflite', 'Path to save the model.']]
208208

209209
# Execute quantization using MCTWrapper
210-
wrapper = mct.wrapper.mctwrapper.MCTWrapper()
210+
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
211211
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_MixP, representative_dataset_gen, param_items)
212212
return flag, quantized_model
213213

@@ -237,7 +237,7 @@ def PTQ_Keras_MixP(float_model: keras.Model) -> Tuple[bool, keras.Model]:
237237
['save_model_path', './qmodel_PTQ_Keras_MixP.tflite', 'Path to save the model.']]
238238

239239
# Execute quantization with mixed precision using MCTWrapper
240-
wrapper = mct.wrapper.mctwrapper.MCTWrapper()
240+
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
241241
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_MixP, representative_dataset_gen, param_items)
242242
return flag, quantized_model
243243

@@ -266,7 +266,7 @@ def GPTQ_Keras(float_model: keras.Model) -> Tuple[bool, keras.Model]:
266266
['save_model_path', './qmodel_GPTQ_Keras.tflite', 'Path to save the model.']]
267267

268268
# Execute gradient-based quantization using MCTWrapper
269-
wrapper = mct.wrapper.mctwrapper.MCTWrapper()
269+
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
270270
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_MixP, representative_dataset_gen, param_items)
271271
return flag, quantized_model
272272

@@ -291,7 +291,7 @@ def GPTQ_Keras_MixP(float_model: keras.Model) -> Tuple[bool, keras.Model]:
291291

292292
['save_model_path', './qmodel_GPTQ_Keras_MixP.tflite', 'Path to save the model.']]
293293

294-
wrapper = mct.wrapper.mctwrapper.MCTWrapper()
294+
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
295295
flag, quantized_model = wrapper.quantize_and_export(float_model, method, framework, use_MCT_TPC, use_MixP, representative_dataset_gen, param_items)
296296
return flag, quantized_model
297297

tests_pytest/wrapper_tests/test_wrap_pytorch_E2E.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def PTQ_Pytorch(float_model):
282282
]
283283

284284
# Execute quantization using MCTWrapper and export to ONNX
285-
wrapper = mct.wrapper.mctwrapper.MCTWrapper()
285+
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
286286
flag, quantized_model = wrapper.quantize_and_export(
287287
float_model, method, framework, use_MCT_TPC, use_MixP,
288288
representative_dataset_gen, param_items)
@@ -324,7 +324,7 @@ def PTQ_Pytorch_MixP(float_model):
324324
]
325325

326326
# Execute mixed precision quantization and export to ONNX
327-
wrapper = mct.wrapper.mctwrapper.MCTWrapper()
327+
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
328328
flag, quantized_model = wrapper.quantize_and_export(
329329
float_model, method, framework, use_MCT_TPC, use_MixP,
330330
representative_dataset_gen, param_items)
@@ -366,7 +366,7 @@ def GPTQ_Pytorch(float_model):
366366
]
367367

368368
# Execute gradient-based quantization and export to ONNX
369-
wrapper = mct.wrapper.mctwrapper.MCTWrapper()
369+
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
370370
flag, quantized_model = wrapper.quantize_and_export(
371371
float_model, method, framework, use_MCT_TPC, use_MixP,
372372
representative_dataset_gen, param_items)
@@ -410,7 +410,7 @@ def GPTQ_Pytorch_MixP(float_model):
410410
]
411411

412412
# Execute advanced GPTQ with mixed precision and export to ONNX
413-
wrapper = mct.wrapper.mctwrapper.MCTWrapper()
413+
wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
414414
flag, quantized_model = wrapper.quantize_and_export(
415415
float_model, method, framework, use_MCT_TPC, use_MixP,
416416
representative_dataset_gen, param_items)

0 commit comments

Comments (0)