|
24 | 24 | TARGET_PLATFORM_CAPABILITIES, TARGET_RESOURCE_UTILIZATION, |
25 | 25 | ACTIVATION_ERROR_METHOD, WEIGHTS_ERROR_METHOD, WEIGHTS_BIAS_CORRECTION, |
26 | 26 | Z_THRESHOLD, LINEAR_COLLAPSING, RESIDUAL_COLLAPSING, GPTQ_CONFIG, |
27 | | - WEIGHTS_COMPRESSION_RATIO, N_EPOCHS, OPTIMIZER |
| 27 | + WEIGHTS_COMPRESSION_RATIO, N_EPOCHS, OPTIMIZER, LEARNING_RATE, |
| 28 | + CONVERTER_VER, CALLBACK, SAVE_MODEL_PATH |
28 | 29 | ) |
29 | 30 |
|
30 | 31 |
|
| 32 | + |
| 33 | + |
31 | 34 | class MCTWrapper: |
32 | 35 | """ |
33 | 36 | Wrapper class for Model Compression Toolkit (MCT) quantization and export. |
@@ -58,37 +61,37 @@ def __init__(self): |
58 | 61 | """ |
59 | 62 | self.params: Dict[str, Any] = { |
60 | 63 | # TPC |
61 | | - 'fw_name': 'pytorch', |
62 | | - 'target_platform_version': 'v1', |
63 | | - 'tpc_version': '5.0', |
| 64 | + FW_NAME: 'pytorch', |
| 65 | + TARGET_PLATFORM_VERSION: 'v1', |
| 66 | + TPC_VERSION: '5.0', |
64 | 67 |
|
65 | 68 | # QuantizationConfig |
66 | | - 'activation_error_method': mct.core.QuantizationErrorMethod.MSE, |
67 | | - 'weights_bias_correction': True, |
68 | | - 'z_threshold': float('inf'), |
69 | | - 'linear_collapsing': True, |
70 | | - 'residual_collapsing': True, |
| 69 | + ACTIVATION_ERROR_METHOD: mct.core.QuantizationErrorMethod.MSE, |
| 70 | + WEIGHTS_BIAS_CORRECTION: True, |
| 71 | + Z_THRESHOLD: float('inf'), |
| 72 | + LINEAR_COLLAPSING: True, |
| 73 | + RESIDUAL_COLLAPSING: True, |
71 | 74 |
|
72 | 75 | # GradientPTQConfig |
73 | | - 'n_epochs': 5, |
74 | | - 'optimizer': None, |
| 76 | + N_EPOCHS: 5, |
| 77 | + OPTIMIZER: None, |
75 | 78 |
|
76 | 79 | # MixedPrecisionQuantizationConfig |
77 | | - 'num_of_images': 5, |
78 | | - 'use_hessian_based_scores': False, |
| 80 | + NUM_OF_IMAGES: 5, |
| 81 | + USE_HESSIAN_BASED_SCORES: False, |
79 | 82 |
|
80 | 83 | # ResourceUtilization |
81 | | - 'weights_compression_ratio': None, |
| 84 | + WEIGHTS_COMPRESSION_RATIO: None, |
82 | 85 |
|
83 | 86 | # low_bit_quantizer_ptq |
84 | | - 'learning_rate': 0.001, |
85 | | - 'converter_ver': 'v3.14', |
| 87 | + LEARNING_RATE: 0.001, |
| 88 | + CONVERTER_VER: 'v3.14', |
86 | 89 |
|
87 | 90 | # Export |
88 | | - 'save_model_path': './qmodel.onnx', |
89 | | - |
| 91 | + SAVE_MODEL_PATH: './qmodel.onnx', |
| 92 | + |
90 | 93 | # Callback function |
91 | | - 'callback': None |
| 94 | + CALLBACK: None |
92 | 95 | } |
93 | 96 |
|
94 | 97 | def _initialize_and_validate(self, float_model: Any, method: str = 'PTQ', |
|
0 commit comments