|
1 | 1 | """ |
2 | | -Test cases for MCTWrapper class from model_compression_toolkit.wrapper.mctwrapper |
| 2 | +Test cases for MCTWrapper class from model_compression_toolkit.wrapper.mct_wrapper |
3 | 3 | """ |
4 | 4 |
|
5 | 5 | import pytest |
|
13 | 13 | print(sys.path) |
14 | 14 |
|
15 | 15 | from model_compression_toolkit.core import QuantizationErrorMethod |
16 | | -from model_compression_toolkit.wrapper.mctwrapper import MCTWrapper |
| 16 | +from model_compression_toolkit.wrapper.mct_wrapper import MCTWrapper |
17 | 17 |
|
18 | 18 |
|
19 | 19 | class TestMCTWrapper: |
@@ -106,7 +106,7 @@ def test_modify_params_non_existing_keys(self) -> None: |
106 | 106 | assert 'non_existing_key' not in wrapper.params |
107 | 107 | assert 'another_fake_key' not in wrapper.params |
108 | 108 |
|
109 | | - @patch('model_compression_toolkit.wrapper.mctwrapper.mct.get_target_platform_capabilities') |
| 109 | + @patch('model_compression_toolkit.wrapper.mct_wrapper.mct.get_target_platform_capabilities') |
110 | 110 | def test_get_TPC_with_MCT_TPC(self, mock_mct_get_tpc: Mock) -> None: |
111 | 111 | """ |
112 | 112 | Test _get_TPC method when using MCT TPC. |
@@ -134,7 +134,7 @@ def test_get_TPC_with_MCT_TPC(self, mock_mct_get_tpc: Mock) -> None: |
134 | 134 | mock_mct_get_tpc.assert_called_once_with(**expected_params) |
135 | 135 | assert wrapper.tpc == mock_tpc |
136 | 136 |
|
137 | | - @patch('model_compression_toolkit.wrapper.mctwrapper.' |
| 137 | + @patch('model_compression_toolkit.wrapper.mct_wrapper.' |
138 | 138 | 'edgemdt_tpc.get_target_platform_capabilities') |
139 | 139 | def test_get_TPC_without_MCT_TPC(self, mock_edgemdt_get_tpc: Mock) -> None: |
140 | 140 | """ |
@@ -423,13 +423,13 @@ class TestMCTWrapperIntegration: |
423 | 423 | - LQ-PTQ TensorFlow: Low-bit quantization specific to TensorFlow |
424 | 424 | """ |
425 | 425 |
|
426 | | - @patch('model_compression_toolkit.wrapper.mctwrapper.' |
| 426 | + @patch('model_compression_toolkit.wrapper.mct_wrapper.' |
427 | 427 | 'MCTWrapper._get_TPC') |
428 | | - @patch('model_compression_toolkit.wrapper.mctwrapper.' |
| 428 | + @patch('model_compression_toolkit.wrapper.mct_wrapper.' |
429 | 429 | 'MCTWrapper._select_method') |
430 | | - @patch('model_compression_toolkit.wrapper.mctwrapper.' |
| 430 | + @patch('model_compression_toolkit.wrapper.mct_wrapper.' |
431 | 431 | 'MCTWrapper._setting_PTQ') |
432 | | - @patch('model_compression_toolkit.wrapper.mctwrapper.' |
| 432 | + @patch('model_compression_toolkit.wrapper.mct_wrapper.' |
433 | 433 | 'MCTWrapper._export_model') |
434 | 434 | def test_quantize_and_export_PTQ_flow( |
435 | 435 | self, mock_export: Mock, mock_setting_ptq: Mock, |
@@ -485,13 +485,13 @@ def test_quantize_and_export_PTQ_flow( |
485 | 485 | assert success is True |
486 | 486 | assert result_model == mock_quantized_model |
487 | 487 |
|
488 | | - @patch('model_compression_toolkit.wrapper.mctwrapper.' |
| 488 | + @patch('model_compression_toolkit.wrapper.mct_wrapper.' |
489 | 489 | 'MCTWrapper._get_TPC') |
490 | | - @patch('model_compression_toolkit.wrapper.mctwrapper.' |
| 490 | + @patch('model_compression_toolkit.wrapper.mct_wrapper.' |
491 | 491 | 'MCTWrapper._select_method') |
492 | | - @patch('model_compression_toolkit.wrapper.mctwrapper.' |
| 492 | + @patch('model_compression_toolkit.wrapper.mct_wrapper.' |
493 | 493 | 'MCTWrapper._setting_GPTQ_MixP') |
494 | | - @patch('model_compression_toolkit.wrapper.mctwrapper.' |
| 494 | + @patch('model_compression_toolkit.wrapper.mct_wrapper.' |
495 | 495 | 'MCTWrapper._export_model') |
496 | 496 | def test_quantize_and_export_GPTQ_MixP_flow( |
497 | 497 | self, mock_export: Mock, mock_setting_gptq_mixp: Mock, |
@@ -541,7 +541,7 @@ def test_quantize_and_export_GPTQ_MixP_flow( |
541 | 541 | assert success is True |
542 | 542 | assert result_model == mock_quantized_model |
543 | 543 |
|
544 | | - @patch('model_compression_toolkit.wrapper.mctwrapper.' |
| 544 | + @patch('model_compression_toolkit.wrapper.mct_wrapper.' |
545 | 545 | 'MCTWrapper._exec_lq_ptq') |
546 | 546 | def test_quantize_and_export_LQPTQ_tensorflow(self, mock_exec_lq_ptq: Mock) -> None: |
547 | 547 | """ |
|
0 commit comments