Skip to content

Commit 1fc08d7

Browse files
Convert a string to a constant
1 parent 26b4435 commit 1fc08d7

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

model_compression_toolkit/wrapper/constants.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,10 @@
5050

5151
# GPTQ parameters
5252
N_EPOCHS = 'n_epochs'
53-
OPTIMIZER = 'optimizer'
53+
OPTIMIZER = 'optimizer'
54+
55+
# Export parameters
56+
CONVERTER_VER = 'converter_ver'
57+
LEARNING_RATE = 'learning_rate'
58+
CALLBACK = 'callback'
59+
SAVE_MODEL_PATH = 'save_model_path'

model_compression_toolkit/wrapper/mct_wrapper.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424
TARGET_PLATFORM_CAPABILITIES, TARGET_RESOURCE_UTILIZATION,
2525
ACTIVATION_ERROR_METHOD, WEIGHTS_ERROR_METHOD, WEIGHTS_BIAS_CORRECTION,
2626
Z_THRESHOLD, LINEAR_COLLAPSING, RESIDUAL_COLLAPSING, GPTQ_CONFIG,
27-
WEIGHTS_COMPRESSION_RATIO, N_EPOCHS, OPTIMIZER
27+
WEIGHTS_COMPRESSION_RATIO, N_EPOCHS, OPTIMIZER, LEARNING_RATE,
28+
CONVERTER_VER, CALLBACK, SAVE_MODEL_PATH
2829
)
2930

3031

32+
33+
3134
class MCTWrapper:
3235
"""
3336
Wrapper class for Model Compression Toolkit (MCT) quantization and export.
@@ -58,37 +61,37 @@ def __init__(self):
5861
"""
5962
self.params: Dict[str, Any] = {
6063
# TPC
61-
'fw_name': 'pytorch',
62-
'target_platform_version': 'v1',
63-
'tpc_version': '5.0',
64+
FW_NAME: 'pytorch',
65+
TARGET_PLATFORM_VERSION: 'v1',
66+
TPC_VERSION: '5.0',
6467

6568
# QuantizationConfig
66-
'activation_error_method': mct.core.QuantizationErrorMethod.MSE,
67-
'weights_bias_correction': True,
68-
'z_threshold': float('inf'),
69-
'linear_collapsing': True,
70-
'residual_collapsing': True,
69+
ACTIVATION_ERROR_METHOD: mct.core.QuantizationErrorMethod.MSE,
70+
WEIGHTS_BIAS_CORRECTION: True,
71+
Z_THRESHOLD: float('inf'),
72+
LINEAR_COLLAPSING: True,
73+
RESIDUAL_COLLAPSING: True,
7174

7275
# GradientPTQConfig
73-
'n_epochs': 5,
74-
'optimizer': None,
76+
N_EPOCHS: 5,
77+
OPTIMIZER: None,
7578

7679
# MixedPrecisionQuantizationConfig
77-
'num_of_images': 5,
78-
'use_hessian_based_scores': False,
80+
NUM_OF_IMAGES: 5,
81+
USE_HESSIAN_BASED_SCORES: False,
7982

8083
# ResourceUtilization
81-
'weights_compression_ratio': None,
84+
WEIGHTS_COMPRESSION_RATIO: None,
8285

8386
# low_bit_quantizer_ptq
84-
'learning_rate': 0.001,
85-
'converter_ver': 'v3.14',
87+
LEARNING_RATE: 0.001,
88+
CONVERTER_VER: 'v3.14',
8689

8790
# Export
88-
'save_model_path': './qmodel.onnx',
89-
91+
SAVE_MODEL_PATH: './qmodel.onnx',
92+
9093
# Callback function
91-
'callback': None
94+
CALLBACK: None
9295
}
9396

9497
def _initialize_and_validate(self, float_model: Any, method: str = 'PTQ',

0 commit comments

Comments
 (0)