Skip to content

Commit 1bc1a36

Browse files
Change to import only individual constants
1 parent 5b02db4 commit 1bc1a36

File tree

1 file changed

+54
-61
lines changed

1 file changed

+54
-61
lines changed

model_compression_toolkit/wrapper/mct_wrapper.py

Lines changed: 54 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,19 @@
1616
import os
1717
from typing import Dict, Any, List, Tuple, Optional, Union, Callable
1818
import model_compression_toolkit as mct
19-
from model_compression_toolkit.wrapper import constants as wrapper_const
2019
from model_compression_toolkit.logger import Logger
20+
from model_compression_toolkit.wrapper.constants import (
21+
REPRESENTATIVE_DATA_GEN, CORE_CONFIG, FW_NAME, TARGET_PLATFORM_VERSION,
22+
TARGET_PLATFORM_NAME, TPC_VERSION, DEVICE_TYPE, EXTENDED_VERSION,
23+
NUM_OF_IMAGES, USE_HESSIAN_BASED_SCORES, IN_MODEL, IN_MODULE, MODEL,
24+
TARGET_PLATFORM_CAPABILITIES, TARGET_RESOURCE_UTILIZATION,
25+
ACTIVATION_ERROR_METHOD, WEIGHTS_ERROR_METHOD, WEIGHTS_BIAS_CORRECTION,
26+
Z_THRESHOLD, LINEAR_COLLAPSING, RESIDUAL_COLLAPSING, GPTQ_CONFIG
27+
)
2128
#import low_bit_quantizer_ptq.ptq as lq_ptq
2229

30+
31+
2332
import importlib
2433
FOUND_TPC = importlib.util.find_spec("edgemdt_tpc") is not None
2534
FOUND_TPC = False
@@ -223,14 +232,14 @@ def select_argname(self) -> None:
223232
calling any _setting_* methods that use these parameter names.
224233
"""
225234
if self.framework == 'tensorflow':
226-
self.argname_in_module = wrapper_const.IN_MODEL
235+
self.argname_in_module = IN_MODEL
227236
elif self.framework == 'pytorch':
228-
self.argname_in_module = wrapper_const.IN_MODULE
237+
self.argname_in_module = IN_MODULE
229238

230239
if self.framework == 'tensorflow':
231-
self.argname_model = wrapper_const.IN_MODEL
240+
self.argname_model = IN_MODEL
232241
elif self.framework == 'pytorch':
233-
self.argname_model = wrapper_const.MODEL
242+
self.argname_model = MODEL
234243

235244
def _get_TPC(self) -> None:
236245
"""
@@ -245,19 +254,19 @@ def _get_TPC(self) -> None:
245254
if self.use_MCT_TPC:
246255
# Use MCT's built-in TPC configuration
247256
params_TPC = {
248-
wrapper_const.FW_NAME: self.params['fw_name'],
249-
wrapper_const.TARGET_PLATFORM_NAME: 'imx500',
250-
wrapper_const.TARGET_PLATFORM_VERSION: self.params['target_platform_version'],
257+
FW_NAME: self.params['fw_name'],
258+
TARGET_PLATFORM_NAME: 'imx500',
259+
TARGET_PLATFORM_VERSION: self.params['target_platform_version'],
251260
}
252261
# Get TPC from MCT framework
253262
self.tpc = mct.get_target_platform_capabilities(**params_TPC)
254263
else:
255264
if FOUND_TPC:
256265
# Use external EdgeMDT TPC configuration
257266
params_TPC = {
258-
wrapper_const.TPC_VERSION: self.params['tpc_version'],
259-
wrapper_const.DEVICE_TYPE: 'imx500',
260-
wrapper_const.EXTENDED_VERSION: None
267+
TPC_VERSION: self.params['tpc_version'],
268+
DEVICE_TYPE: 'imx500',
269+
EXTENDED_VERSION: None
261270
}
262271
# Get TPC from EdgeMDT framework
263272
self.tpc = edgemdt_tpc.get_target_platform_capabilities(**params_TPC)
@@ -272,16 +281,16 @@ def _setting_PTQ_MixP(self) -> Dict[str, Any]:
272281
dict: Parameter dictionary for PTQ.
273282
"""
274283
params_MPCfg = {
275-
wrapper_const.NUM_OF_IMAGES: self.params['num_of_images'],
276-
wrapper_const.USE_HESSIAN_BASED_SCORES: self.params['use_hessian_based_scores'],
284+
NUM_OF_IMAGES: self.params['num_of_images'],
285+
USE_HESSIAN_BASED_SCORES: self.params['use_hessian_based_scores'],
277286
}
278287
mixed_precision_config = mct.core.MixedPrecisionQuantizationConfig(**params_MPCfg)
279288
core_config = mct.core.CoreConfig(mixed_precision_config=mixed_precision_config)
280289
params_RUDCfg = {
281-
wrapper_const.IN_MODEL: self.float_model,
282-
wrapper_const.REPRESENTATIVE_DATA_GEN: self.representative_dataset,
283-
wrapper_const.CORE_CONFIG: core_config,
284-
wrapper_const.TARGET_PLATFORM_CAPABILITIES: self.tpc
290+
IN_MODEL: self.float_model,
291+
REPRESENTATIVE_DATA_GEN: self.representative_dataset,
292+
CORE_CONFIG: core_config,
293+
TARGET_PLATFORM_CAPABILITIES: self.tpc
285294
}
286295
ru_data = self.resource_utilization_data(**params_RUDCfg)
287296
weights_compression_ratio = (
@@ -292,10 +301,10 @@ def _setting_PTQ_MixP(self) -> Dict[str, Any]:
292301

293302
params_PTQ = {
294303
self.argname_in_module: self.float_model,
295-
wrapper_const.REPRESENTATIVE_DATA_GEN: self.representative_dataset,
296-
wrapper_const.TARGET_RESOURCE_UTILIZATION: resource_utilization,
297-
wrapper_const.CORE_CONFIG: core_config,
298-
wrapper_const.TARGET_PLATFORM_CAPABILITIES: self.tpc
304+
REPRESENTATIVE_DATA_GEN: self.representative_dataset,
305+
TARGET_RESOURCE_UTILIZATION: resource_utilization,
306+
CORE_CONFIG: core_config,
307+
TARGET_PLATFORM_CAPABILITIES: self.tpc
299308
}
300309
return params_PTQ
301310

@@ -307,23 +316,23 @@ def _setting_PTQ(self) -> Dict[str, Any]:
307316
dict: Parameter dictionary for PTQ.
308317
"""
309318
params_QCfg = {
310-
wrapper_const.ACTIVATION_ERROR_METHOD: self.params['activation_error_method'],
311-
wrapper_const.WEIGHTS_ERROR_METHOD: mct.core.QuantizationErrorMethod.MSE,
312-
wrapper_const.WEIGHTS_BIAS_CORRECTION: self.params['weights_bias_correction'],
313-
wrapper_const.Z_THRESHOLD: self.params['z_threshold'],
314-
wrapper_const.LINEAR_COLLAPSING: self.params['linear_collapsing'],
315-
wrapper_const.RESIDUAL_COLLAPSING: self.params['residual_collapsing']
319+
ACTIVATION_ERROR_METHOD: self.params['activation_error_method'],
320+
WEIGHTS_ERROR_METHOD: mct.core.QuantizationErrorMethod.MSE,
321+
WEIGHTS_BIAS_CORRECTION: self.params['weights_bias_correction'],
322+
Z_THRESHOLD: self.params['z_threshold'],
323+
LINEAR_COLLAPSING: self.params['linear_collapsing'],
324+
RESIDUAL_COLLAPSING: self.params['residual_collapsing']
316325
}
317326
q_config = mct.core.QuantizationConfig(**params_QCfg)
318327
core_config = mct.core.CoreConfig(quantization_config=q_config)
319328
resource_utilization = None
320329

321330
params_PTQ = {
322331
self.argname_in_module: self.float_model,
323-
wrapper_const.REPRESENTATIVE_DATA_GEN: self.representative_dataset,
324-
wrapper_const.TARGET_RESOURCE_UTILIZATION: resource_utilization,
325-
wrapper_const.CORE_CONFIG: core_config,
326-
wrapper_const.TARGET_PLATFORM_CAPABILITIES: self.tpc
332+
REPRESENTATIVE_DATA_GEN: self.representative_dataset,
333+
TARGET_RESOURCE_UTILIZATION: resource_utilization,
334+
CORE_CONFIG: core_config,
335+
TARGET_PLATFORM_CAPABILITIES: self.tpc
327336
}
328337
return params_PTQ
329338

@@ -341,16 +350,16 @@ def _setting_GPTQ_MixP(self) -> Dict[str, Any]:
341350
gptq_config = self.get_gptq_config(**params_GPTQCfg)
342351

343352
params_MPCfg = {
344-
wrapper_const.NUM_OF_IMAGES: self.params['num_of_images'],
345-
wrapper_const.USE_HESSIAN_BASED_SCORES: self.params['use_hessian_based_scores'],
353+
NUM_OF_IMAGES: self.params['num_of_images'],
354+
USE_HESSIAN_BASED_SCORES: self.params['use_hessian_based_scores'],
346355
}
347356
mixed_precision_config = mct.core.MixedPrecisionQuantizationConfig(**params_MPCfg)
348357
core_config = mct.core.CoreConfig(mixed_precision_config=mixed_precision_config)
349358
params_RUDCfg = {
350-
wrapper_const.IN_MODEL: self.float_model,
351-
wrapper_const.REPRESENTATIVE_DATA_GEN: self.representative_dataset,
352-
wrapper_const.CORE_CONFIG: core_config,
353-
wrapper_const.TARGET_PLATFORM_CAPABILITIES: self.tpc
359+
IN_MODEL: self.float_model,
360+
REPRESENTATIVE_DATA_GEN: self.representative_dataset,
361+
CORE_CONFIG: core_config,
362+
TARGET_PLATFORM_CAPABILITIES: self.tpc
354363
}
355364
ru_data = self.resource_utilization_data(**params_RUDCfg)
356365
weights_compression_ratio = (
@@ -366,11 +375,11 @@ def _setting_GPTQ_MixP(self) -> Dict[str, Any]:
366375

367376
params_GPTQ = {
368377
self.argname_model: self.float_model,
369-
wrapper_const.REPRESENTATIVE_DATA_GEN: self.representative_dataset,
370-
wrapper_const.TARGET_RESOURCE_UTILIZATION: resource_utilization,
371-
wrapper_const.GPTQ_CONFIG: gptq_config,
372-
wrapper_const.CORE_CONFIG: core_config,
373-
wrapper_const.TARGET_PLATFORM_CAPABILITIES: self.tpc
378+
REPRESENTATIVE_DATA_GEN: self.representative_dataset,
379+
TARGET_RESOURCE_UTILIZATION: resource_utilization,
380+
GPTQ_CONFIG: gptq_config,
381+
CORE_CONFIG: core_config,
382+
TARGET_PLATFORM_CAPABILITIES: self.tpc
374383
}
375384
return params_GPTQ
376385

@@ -389,9 +398,9 @@ def _setting_GPTQ(self) -> Dict[str, Any]:
389398

390399
params_GPTQ = {
391400
self.argname_model: self.float_model,
392-
wrapper_const.REPRESENTATIVE_DATA_GEN: self.representative_dataset,
393-
wrapper_const.GPTQ_CONFIG: gptq_config,
394-
wrapper_const.TARGET_PLATFORM_CAPABILITIES: self.tpc
401+
REPRESENTATIVE_DATA_GEN: self.representative_dataset,
402+
GPTQ_CONFIG: gptq_config,
403+
TARGET_PLATFORM_CAPABILITIES: self.tpc
395404
}
396405
return params_GPTQ
397406

@@ -408,22 +417,6 @@ def _exec_lq_ptq(self) -> Any:
408417
Note:
409418
This method requires the lq_ptq module to be imported.
410419
"""
411-
model_save_dir, output_file_name = os.path.split(
412-
self.params['save_model_path'])
413-
414-
# Note: lq_ptq module should be imported when using this method
415-
# q_model = lq_ptq.low_bit_quantizer_ptq(
416-
# fp_model=self.float_model,
417-
# representative_dataset=self.representative_dataset,
418-
# model_save_dir=model_save_dir,
419-
# output_file_name=output_file_name,
420-
# learning_rate=self.params['learning_rate'],
421-
# converter_ver=self.params['converter_ver'],
422-
# debug_level='INFO',
423-
# debug_detail=False,
424-
# overwrite_output_file=True)
425-
# return q_model
426-
427420
# Placeholder implementation - replace with actual lq_ptq call
428421
raise NotImplementedError(
429422
"LQ-PTQ functionality requires lq_ptq module to be imported")

0 commit comments

Comments
 (0)