Skip to content

Commit a04156f

Browse files
Refactor MCTWrapper and related components for enhanced quantization configuration
1 parent 27a5e88 commit a04156f

File tree

9 files changed

+255
-105
lines changed

9 files changed

+255
-105
lines changed

model_compression_toolkit/verify_packages.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
1615
import importlib
1716
from packaging import version
1817

model_compression_toolkit/wrapper/constants.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,20 @@
1717
FW_NAME = 'fw_name'
1818
SDSP_VERSION = 'sdsp_version'
1919

20+
# QuantizationConfig parameters
21+
ACTIVATION_ERROR_METHOD = 'activation_error_method'
22+
WEIGHTS_BIAS_CORRECTION = 'weights_bias_correction'
23+
Z_THRESHOLD = 'z_threshold'
24+
LINEAR_COLLAPSING = 'linear_collapsing'
25+
RESIDUAL_COLLAPSING = 'residual_collapsing'
26+
WEIGHTS_ERROR_METHOD = 'weights_error_method'
27+
2028
# MixedPrecisionQuantizationConfig parameters
29+
DISTANCE_WEIGHTING_METHOD = 'distance_weighting_method'
2130
NUM_OF_IMAGES = 'num_of_images'
2231
USE_HESSIAN_BASED_SCORES = 'use_hessian_based_scores'
32+
33+
# ResourceUtilization parameters
2334
WEIGHTS_COMPRESSION_RATIO = 'weights_compression_ratio'
2435

2536
# Resource utilization data parameters
@@ -32,14 +43,6 @@
3243
TARGET_RESOURCE_UTILIZATION = 'target_resource_utilization'
3344
IN_MODULE = 'in_module'
3445

35-
# QuantizationConfig parameters
36-
ACTIVATION_ERROR_METHOD = 'activation_error_method'
37-
WEIGHTS_ERROR_METHOD = 'weights_error_method'
38-
WEIGHTS_BIAS_CORRECTION = 'weights_bias_correction'
39-
Z_THRESHOLD = 'z_threshold'
40-
LINEAR_COLLAPSING = 'linear_collapsing'
41-
RESIDUAL_COLLAPSING = 'residual_collapsing'
42-
4346
# GPTQ specific parameters
4447
GPTQ_CONFIG = 'gptq_config'
4548
MODEL = 'model'
@@ -48,7 +51,12 @@
4851
N_EPOCHS = 'n_epochs'
4952
OPTIMIZER = 'optimizer'
5053

51-
# Export parameters
54+
# low_bit_quantizer_ptq parameters
5255
CONVERTER_VER = 'converter_ver'
5356
LEARNING_RATE = 'learning_rate'
57+
58+
# Export parameters
5459
SAVE_MODEL_PATH = 'save_model_path'
60+
61+
# Default weights compression ratio, used when 'weights_compression_ratio' is None
62+
DEFAULT_COMPRESSION_RATIO = 0.75

model_compression_toolkit/wrapper/mct_wrapper.py

Lines changed: 93 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
import model_compression_toolkit as mct
1818
from model_compression_toolkit.logger import Logger
1919
from model_compression_toolkit.wrapper.constants import (
20-
REPRESENTATIVE_DATA_GEN, CORE_CONFIG, FW_NAME, SDSP_VERSION,
21-
NUM_OF_IMAGES, USE_HESSIAN_BASED_SCORES, IN_MODEL, IN_MODULE, MODEL,
22-
TARGET_PLATFORM_CAPABILITIES, TARGET_RESOURCE_UTILIZATION,
23-
ACTIVATION_ERROR_METHOD, WEIGHTS_ERROR_METHOD, WEIGHTS_BIAS_CORRECTION,
24-
Z_THRESHOLD, LINEAR_COLLAPSING, RESIDUAL_COLLAPSING, GPTQ_CONFIG,
25-
WEIGHTS_COMPRESSION_RATIO, N_EPOCHS, OPTIMIZER, LEARNING_RATE,
26-
CONVERTER_VER, SAVE_MODEL_PATH
20+
FW_NAME, SDSP_VERSION, ACTIVATION_ERROR_METHOD, WEIGHTS_BIAS_CORRECTION,
21+
Z_THRESHOLD, LINEAR_COLLAPSING, RESIDUAL_COLLAPSING, WEIGHTS_ERROR_METHOD,
22+
DISTANCE_WEIGHTING_METHOD, NUM_OF_IMAGES,
23+
USE_HESSIAN_BASED_SCORES, WEIGHTS_COMPRESSION_RATIO,
24+
IN_MODEL, REPRESENTATIVE_DATA_GEN, CORE_CONFIG, TARGET_PLATFORM_CAPABILITIES,
25+
TARGET_RESOURCE_UTILIZATION, IN_MODULE, GPTQ_CONFIG, MODEL,
26+
N_EPOCHS, OPTIMIZER, LEARNING_RATE, CONVERTER_VER, SAVE_MODEL_PATH, DEFAULT_COMPRESSION_RATIO
2727
)
2828

2929

@@ -55,11 +55,11 @@ def __init__(self):
5555
:widths: 30, 30, 40
5656
5757
"sdsp_version", "'3.14'", "SDSP version for TPC"
58-
"activation_error_method", "mct.core.QuantizationErrorMethod.MSE", "Activation quantization error method"
59-
"weights_bias_correction", "True", "Enable weights bias correction"
60-
"z_threshold", "float('inf')", "Z-threshold for quantization"
61-
"linear_collapsing", "True", "Enable linear layer collapsing"
62-
"residual_collapsing", "True", "Enable residual connection collapsing"
58+
"activation_error_method", "mct.core.QuantizationErrorMethod.MSE", "Activation quantization error method (low priority)"
59+
"weights_bias_correction", "True", "Enable weights bias correction (low priority)"
60+
"z_threshold", "float('inf')", "Z-threshold for quantization (low priority)"
61+
"linear_collapsing", "True", "Enable linear layer collapsing (low priority)"
62+
"residual_collapsing", "True", "Enable residual connection collapsing (low priority)"
6363
"save_model_path", "'./qmodel.keras' / './qmodel.onnx'", "Path to save quantized model (Keras/Pytorch)"
6464
6565
**PTQ, mixed_precision**
@@ -69,8 +69,14 @@ def __init__(self):
6969
:widths: 30, 30, 40
7070
7171
"sdsp_version", "'3.14'", "SDSP version for TPC"
72+
"activation_error_method", "mct.core.QuantizationErrorMethod.MSE", "Activation quantization error method (low priority)"
73+
"weights_bias_correction", "True", "Enable weights bias correction (low priority)"
74+
"z_threshold", "float('inf')", "Z-threshold for quantization (low priority)"
75+
"linear_collapsing", "True", "Enable linear layer collapsing (low priority)"
76+
"residual_collapsing", "True", "Enable residual connection collapsing (low priority)"
77+
"distance_weighting_method", "None", "Distance weighting method for mixed precision (low priority)"
7278
"num_of_images", "5", "Number of images for mixed precision"
73-
"use_hessian_based_scores", "False", "Use Hessian-based scores for mixed precision"
79+
"use_hessian_based_scores", "False", "Use Hessian-based scores for mixed precision (low priority)"
7480
"weights_compression_ratio", "None", "Weights compression ratio for resource util"
7581
"save_model_path", "'./qmodel.keras' / './qmodel.onnx'", "Path to save quantized model (Keras/Pytorch)"
7682
@@ -81,8 +87,13 @@ def __init__(self):
8187
:widths: 30, 30, 40
8288
8389
"sdsp_version", "'3.14'", "SDSP version for TPC"
90+
"activation_error_method", "mct.core.QuantizationErrorMethod.MSE", "Activation quantization error method (low priority)"
91+
"weights_bias_correction", "True", "Enable weights bias correction (low priority)"
92+
"z_threshold", "float('inf')", "Z-threshold for quantization (low priority)"
93+
"linear_collapsing", "True", "Enable linear layer collapsing (low priority)"
94+
"residual_collapsing", "True", "Enable residual connection collapsing (low priority)"
8495
"n_epochs", "5", "Number of training epochs for GPTQ"
85-
"optimizer", "None", "Optimizer for GPTQ training"
96+
"optimizer", "None", "Optimizer for GPTQ training (low priority)"
8697
"save_model_path", "'./qmodel.keras' / './qmodel.onnx'", "Path to save quantized model (Keras/Pytorch)"
8798
8899
**GPTQ, mixed_precision**
@@ -92,11 +103,17 @@ def __init__(self):
92103
:widths: 30, 30, 40
93104
94105
"sdsp_version", "'3.14'", "SDSP version for TPC"
106+
"activation_error_method", "mct.core.QuantizationErrorMethod.MSE", "Activation quantization error method (low priority)"
107+
"weights_bias_correction", "True", "Enable weights bias correction (low priority)"
108+
"z_threshold", "float('inf')", "Z-threshold for quantization (low priority)"
109+
"linear_collapsing", "True", "Enable linear layer collapsing (low priority)"
110+
"residual_collapsing", "True", "Enable residual connection collapsing (low priority)"
111+
"weights_compression_ratio", "None", "Weights compression ratio for resource util"
95112
"n_epochs", "5", "Number of training epochs for GPTQ"
96-
"optimizer", "None", "Optimizer for GPTQ training"
113+
"optimizer", "None", "Optimizer for GPTQ training (low priority)"
114+
"distance_weighting_method", "None", "Distance weighting method for mixed precision (low priority)"
97115
"num_of_images", "5", "Number of images for mixed precision"
98116
"use_hessian_based_scores", "False", "Use Hessian-based scores for mixed precision (low priority)"
99-
"weights_compression_ratio", "None", "Weights compression ratio for resource util"
100117
"save_model_path", "'./qmodel.keras' / './qmodel.onnx'", "Path to save quantized model (Keras/Pytorch)"
101118
102119
"""
@@ -112,17 +129,18 @@ def __init__(self):
112129
LINEAR_COLLAPSING: True,
113130
RESIDUAL_COLLAPSING: True,
114131

115-
# GradientPTQConfig
116-
N_EPOCHS: 5,
117-
OPTIMIZER: None,
118-
119132
# MixedPrecisionQuantizationConfig
133+
DISTANCE_WEIGHTING_METHOD: None,
120134
NUM_OF_IMAGES: 5,
121135
USE_HESSIAN_BASED_SCORES: False,
122136

123137
# ResourceUtilization
124138
WEIGHTS_COMPRESSION_RATIO: None,
125139

140+
# GradientPTQConfig
141+
N_EPOCHS: 5,
142+
OPTIMIZER: None,
143+
126144
# low_bit_quantizer_ptq
127145
LEARNING_RATE: 0.001,
128146
CONVERTER_VER: 'v3.14',
@@ -172,16 +190,21 @@ def _initialize_and_validate(self, float_model: Any,
172190
Z_THRESHOLD, LINEAR_COLLAPSING, RESIDUAL_COLLAPSING,
173191
SAVE_MODEL_PATH]
174192
else:
175-
allowed_keys = [FW_NAME, SDSP_VERSION, NUM_OF_IMAGES, USE_HESSIAN_BASED_SCORES,
193+
allowed_keys = [FW_NAME, SDSP_VERSION, ACTIVATION_ERROR_METHOD, WEIGHTS_BIAS_CORRECTION,
194+
Z_THRESHOLD, LINEAR_COLLAPSING, RESIDUAL_COLLAPSING,
195+
DISTANCE_WEIGHTING_METHOD, NUM_OF_IMAGES, USE_HESSIAN_BASED_SCORES,
176196
WEIGHTS_COMPRESSION_RATIO, SAVE_MODEL_PATH]
177197
else:
178198
if not use_mixed_precision:
179-
allowed_keys = [FW_NAME, SDSP_VERSION, N_EPOCHS, OPTIMIZER,
180-
SAVE_MODEL_PATH]
199+
allowed_keys = [FW_NAME, SDSP_VERSION, ACTIVATION_ERROR_METHOD, WEIGHTS_BIAS_CORRECTION,
200+
Z_THRESHOLD, LINEAR_COLLAPSING, RESIDUAL_COLLAPSING,
201+
N_EPOCHS, OPTIMIZER, SAVE_MODEL_PATH]
181202
else:
182-
allowed_keys = [FW_NAME, SDSP_VERSION, N_EPOCHS, OPTIMIZER,
203+
allowed_keys = [FW_NAME, SDSP_VERSION, ACTIVATION_ERROR_METHOD, WEIGHTS_BIAS_CORRECTION,
204+
Z_THRESHOLD, LINEAR_COLLAPSING, RESIDUAL_COLLAPSING,
205+
WEIGHTS_COMPRESSION_RATIO, N_EPOCHS, OPTIMIZER, DISTANCE_WEIGHTING_METHOD,
183206
NUM_OF_IMAGES, USE_HESSIAN_BASED_SCORES,
184-
WEIGHTS_COMPRESSION_RATIO, SAVE_MODEL_PATH]
207+
SAVE_MODEL_PATH]
185208

186209
self.params = { k: v for k, v in self.params.items() if k in allowed_keys }
187210

@@ -320,12 +343,26 @@ def _setting_PTQ_mixed_precision(self) -> Dict[str, Any]:
320343
Returns:
321344
dict: Parameter dictionary for PTQ.
322345
"""
346+
params_QCfg = {
347+
ACTIVATION_ERROR_METHOD: self.params[ACTIVATION_ERROR_METHOD],
348+
WEIGHTS_ERROR_METHOD: mct.core.QuantizationErrorMethod.MSE,
349+
WEIGHTS_BIAS_CORRECTION: self.params[WEIGHTS_BIAS_CORRECTION],
350+
Z_THRESHOLD: self.params[Z_THRESHOLD],
351+
LINEAR_COLLAPSING: self.params[LINEAR_COLLAPSING],
352+
RESIDUAL_COLLAPSING: self.params[RESIDUAL_COLLAPSING]
353+
}
354+
q_config = mct.core.QuantizationConfig(**params_QCfg)
355+
323356
params_MPCfg = {
357+
DISTANCE_WEIGHTING_METHOD: self.params[DISTANCE_WEIGHTING_METHOD],
324358
NUM_OF_IMAGES: self.params[NUM_OF_IMAGES],
325359
USE_HESSIAN_BASED_SCORES: self.params[USE_HESSIAN_BASED_SCORES]
326360
}
327-
mixed_precision_config = mct.core.MixedPrecisionQuantizationConfig(**params_MPCfg)
328-
core_config = mct.core.CoreConfig(mixed_precision_config=mixed_precision_config)
361+
mixed_precision_config = mct.core.MixedPrecisionQuantizationConfig(**params_MPCfg)
362+
363+
core_config = mct.core.CoreConfig(quantization_config=q_config,
364+
mixed_precision_config=mixed_precision_config)
365+
329366
params_RUDCfg = {
330367
IN_MODEL: self.float_model,
331368
REPRESENTATIVE_DATA_GEN: self.representative_dataset,
@@ -334,7 +371,7 @@ def _setting_PTQ_mixed_precision(self) -> Dict[str, Any]:
334371
}
335372
ru_data = self.resource_utilization_data(**params_RUDCfg)
336373
weights_compression_ratio = (
337-
0.75 if self.params[WEIGHTS_COMPRESSION_RATIO] is None
374+
DEFAULT_COMPRESSION_RATIO if self.params[WEIGHTS_COMPRESSION_RATIO] is None
338375
else self.params[WEIGHTS_COMPRESSION_RATIO])
339376
resource_utilization = mct.core.ResourceUtilization(
340377
ru_data.weights_memory * weights_compression_ratio)
@@ -383,18 +420,32 @@ def _setting_GPTQ_mixed_precision(self) -> Dict[str, Any]:
383420
Returns:
384421
dict: Parameter dictionary for GPTQ.
385422
"""
423+
params_QCfg = {
424+
ACTIVATION_ERROR_METHOD: self.params[ACTIVATION_ERROR_METHOD],
425+
WEIGHTS_ERROR_METHOD: mct.core.QuantizationErrorMethod.MSE,
426+
WEIGHTS_BIAS_CORRECTION: self.params[WEIGHTS_BIAS_CORRECTION],
427+
Z_THRESHOLD: self.params[Z_THRESHOLD],
428+
LINEAR_COLLAPSING: self.params[LINEAR_COLLAPSING],
429+
RESIDUAL_COLLAPSING: self.params[RESIDUAL_COLLAPSING]
430+
}
431+
q_config = mct.core.QuantizationConfig(**params_QCfg)
432+
386433
params_GPTQCfg = {
387434
N_EPOCHS: self.params[N_EPOCHS],
388435
OPTIMIZER: self.params[OPTIMIZER]
389436
}
390437
gptq_config = self.get_gptq_config(**params_GPTQCfg)
391438

392439
params_MPCfg = {
440+
DISTANCE_WEIGHTING_METHOD: self.params[DISTANCE_WEIGHTING_METHOD],
393441
NUM_OF_IMAGES: self.params[NUM_OF_IMAGES],
394442
USE_HESSIAN_BASED_SCORES: self.params[USE_HESSIAN_BASED_SCORES],
395443
}
396444
mixed_precision_config = mct.core.MixedPrecisionQuantizationConfig(**params_MPCfg)
397-
core_config = mct.core.CoreConfig(mixed_precision_config=mixed_precision_config)
445+
446+
core_config = mct.core.CoreConfig(quantization_config=q_config,
447+
mixed_precision_config=mixed_precision_config)
448+
398449
params_RUDCfg = {
399450
IN_MODEL: self.float_model,
400451
REPRESENTATIVE_DATA_GEN: self.representative_dataset,
@@ -403,16 +454,11 @@ def _setting_GPTQ_mixed_precision(self) -> Dict[str, Any]:
403454
}
404455
ru_data = self.resource_utilization_data(**params_RUDCfg)
405456
weights_compression_ratio = (
406-
0.75 if self.params[WEIGHTS_COMPRESSION_RATIO] is None
457+
DEFAULT_COMPRESSION_RATIO if self.params[WEIGHTS_COMPRESSION_RATIO] is None
407458
else self.params[WEIGHTS_COMPRESSION_RATIO])
408459
resource_utilization = mct.core.ResourceUtilization(
409460
ru_data.weights_memory * weights_compression_ratio)
410461

411-
core_config = mct.core.CoreConfig(
412-
mixed_precision_config = mixed_precision_config,
413-
quantization_config = mct.core.QuantizationConfig()
414-
)
415-
416462
params_GPTQ = {
417463
self.argname_model: self.float_model,
418464
REPRESENTATIVE_DATA_GEN: self.representative_dataset,
@@ -430,6 +476,17 @@ def _setting_GPTQ(self) -> Dict[str, Any]:
430476
Returns:
431477
dict: Parameter dictionary for GPTQ.
432478
"""
479+
params_QCfg = {
480+
ACTIVATION_ERROR_METHOD: self.params[ACTIVATION_ERROR_METHOD],
481+
WEIGHTS_ERROR_METHOD: mct.core.QuantizationErrorMethod.MSE,
482+
WEIGHTS_BIAS_CORRECTION: self.params[WEIGHTS_BIAS_CORRECTION],
483+
Z_THRESHOLD: self.params[Z_THRESHOLD],
484+
LINEAR_COLLAPSING: self.params[LINEAR_COLLAPSING],
485+
RESIDUAL_COLLAPSING: self.params[RESIDUAL_COLLAPSING]
486+
}
487+
q_config = mct.core.QuantizationConfig(**params_QCfg)
488+
core_config = mct.core.CoreConfig(quantization_config=q_config)
489+
433490
params_GPTQCfg = {
434491
N_EPOCHS: self.params[N_EPOCHS],
435492
OPTIMIZER: self.params[OPTIMIZER]
@@ -440,6 +497,7 @@ def _setting_GPTQ(self) -> Dict[str, Any]:
440497
self.argname_model: self.float_model,
441498
REPRESENTATIVE_DATA_GEN: self.representative_dataset,
442499
GPTQ_CONFIG: gptq_config,
500+
CORE_CONFIG: core_config,
443501
TARGET_PLATFORM_CAPABILITIES: self.tpc
444502
}
445503
return params_GPTQ

0 commit comments

Comments
 (0)