1616import os
1717from typing import Dict , Any , List , Tuple , Optional , Union , Callable
1818import model_compression_toolkit as mct
19- from model_compression_toolkit .wrapper import constants as wrapper_const
2019from model_compression_toolkit .logger import Logger
20+ from model_compression_toolkit .wrapper .constants import (
21+ REPRESENTATIVE_DATA_GEN , CORE_CONFIG , FW_NAME , TARGET_PLATFORM_VERSION ,
22+ TARGET_PLATFORM_NAME , TPC_VERSION , DEVICE_TYPE , EXTENDED_VERSION ,
23+ NUM_OF_IMAGES , USE_HESSIAN_BASED_SCORES , IN_MODEL , IN_MODULE , MODEL ,
24+ TARGET_PLATFORM_CAPABILITIES , TARGET_RESOURCE_UTILIZATION ,
25+ ACTIVATION_ERROR_METHOD , WEIGHTS_ERROR_METHOD , WEIGHTS_BIAS_CORRECTION ,
26+ Z_THRESHOLD , LINEAR_COLLAPSING , RESIDUAL_COLLAPSING , GPTQ_CONFIG
27+ )
2128#import low_bit_quantizer_ptq.ptq as lq_ptq
2229
30+
31+
import importlib

# Probe for the optional external EdgeMDT TPC package without importing it.
FOUND_TPC = importlib.util.find_spec("edgemdt_tpc") is not None
# NOTE(review): the probe result above is unconditionally discarded here,
# forcing the MCT built-in TPC path in _get_TPC even when edgemdt_tpc is
# installed. Looks like a debug/temporary disable — confirm before release.
FOUND_TPC = False
def select_argname(self) -> None:
    """Select framework-specific keyword-argument names.

    Sets ``self.argname_in_module`` and ``self.argname_model`` according to
    ``self.framework``; must be called before any ``_setting_*`` method that
    uses these parameter names.  For frameworks other than 'tensorflow' /
    'pytorch' the attributes are left unset (NOTE(review): confirm that
    silently skipping unknown frameworks is intended).
    """
    # Single dispatch on framework instead of two identical if/elif ladders.
    if self.framework == 'tensorflow':
        # TensorFlow APIs use the same keyword for both roles.
        self.argname_in_module = IN_MODEL
        self.argname_model = IN_MODEL
    elif self.framework == 'pytorch':
        self.argname_in_module = IN_MODULE
        self.argname_model = MODEL
234243
235244 def _get_TPC (self ) -> None :
236245 """
@@ -245,19 +254,19 @@ def _get_TPC(self) -> None:
245254 if self .use_MCT_TPC :
246255 # Use MCT's built-in TPC configuration
247256 params_TPC = {
248- wrapper_const . FW_NAME : self .params ['fw_name' ],
249- wrapper_const . TARGET_PLATFORM_NAME : 'imx500' ,
250- wrapper_const . TARGET_PLATFORM_VERSION : self .params ['target_platform_version' ],
257+ FW_NAME : self .params ['fw_name' ],
258+ TARGET_PLATFORM_NAME : 'imx500' ,
259+ TARGET_PLATFORM_VERSION : self .params ['target_platform_version' ],
251260 }
252261 # Get TPC from MCT framework
253262 self .tpc = mct .get_target_platform_capabilities (** params_TPC )
254263 else :
255264 if FOUND_TPC :
256265 # Use external EdgeMDT TPC configuration
257266 params_TPC = {
258- wrapper_const . TPC_VERSION : self .params ['tpc_version' ],
259- wrapper_const . DEVICE_TYPE : 'imx500' ,
260- wrapper_const . EXTENDED_VERSION : None
267+ TPC_VERSION : self .params ['tpc_version' ],
268+ DEVICE_TYPE : 'imx500' ,
269+ EXTENDED_VERSION : None
261270 }
262271 # Get TPC from EdgeMDT framework
263272 self .tpc = edgemdt_tpc .get_target_platform_capabilities (** params_TPC )
@@ -272,16 +281,16 @@ def _setting_PTQ_MixP(self) -> Dict[str, Any]:
272281 dict: Parameter dictionary for PTQ.
273282 """
274283 params_MPCfg = {
275- wrapper_const . NUM_OF_IMAGES : self .params ['num_of_images' ],
276- wrapper_const . USE_HESSIAN_BASED_SCORES : self .params ['use_hessian_based_scores' ],
284+ NUM_OF_IMAGES : self .params ['num_of_images' ],
285+ USE_HESSIAN_BASED_SCORES : self .params ['use_hessian_based_scores' ],
277286 }
278287 mixed_precision_config = mct .core .MixedPrecisionQuantizationConfig (** params_MPCfg )
279288 core_config = mct .core .CoreConfig (mixed_precision_config = mixed_precision_config )
280289 params_RUDCfg = {
281- wrapper_const . IN_MODEL : self .float_model ,
282- wrapper_const . REPRESENTATIVE_DATA_GEN : self .representative_dataset ,
283- wrapper_const . CORE_CONFIG : core_config ,
284- wrapper_const . TARGET_PLATFORM_CAPABILITIES : self .tpc
290+ IN_MODEL : self .float_model ,
291+ REPRESENTATIVE_DATA_GEN : self .representative_dataset ,
292+ CORE_CONFIG : core_config ,
293+ TARGET_PLATFORM_CAPABILITIES : self .tpc
285294 }
286295 ru_data = self .resource_utilization_data (** params_RUDCfg )
287296 weights_compression_ratio = (
@@ -292,10 +301,10 @@ def _setting_PTQ_MixP(self) -> Dict[str, Any]:
292301
293302 params_PTQ = {
294303 self .argname_in_module : self .float_model ,
295- wrapper_const . REPRESENTATIVE_DATA_GEN : self .representative_dataset ,
296- wrapper_const . TARGET_RESOURCE_UTILIZATION : resource_utilization ,
297- wrapper_const . CORE_CONFIG : core_config ,
298- wrapper_const . TARGET_PLATFORM_CAPABILITIES : self .tpc
304+ REPRESENTATIVE_DATA_GEN : self .representative_dataset ,
305+ TARGET_RESOURCE_UTILIZATION : resource_utilization ,
306+ CORE_CONFIG : core_config ,
307+ TARGET_PLATFORM_CAPABILITIES : self .tpc
299308 }
300309 return params_PTQ
301310
def _setting_PTQ(self) -> Dict[str, Any]:
    """Build the keyword-argument dictionary for plain (single-precision) PTQ.

    Assembles an MCT ``QuantizationConfig`` from ``self.params`` (the weights
    error method is fixed to MSE), wraps it in a ``CoreConfig``, and returns
    the call parameters for the PTQ entry point.

    Returns:
        dict: Parameter dictionary for PTQ.
    """
    quant_cfg = mct.core.QuantizationConfig(**{
        ACTIVATION_ERROR_METHOD: self.params['activation_error_method'],
        # Weights error method is hard-coded to MSE (not read from params).
        WEIGHTS_ERROR_METHOD: mct.core.QuantizationErrorMethod.MSE,
        WEIGHTS_BIAS_CORRECTION: self.params['weights_bias_correction'],
        Z_THRESHOLD: self.params['z_threshold'],
        LINEAR_COLLAPSING: self.params['linear_collapsing'],
        RESIDUAL_COLLAPSING: self.params['residual_collapsing'],
    })
    core_cfg = mct.core.CoreConfig(quantization_config=quant_cfg)
    # No resource-utilization target in the single-precision flow.
    return {
        self.argname_in_module: self.float_model,
        REPRESENTATIVE_DATA_GEN: self.representative_dataset,
        TARGET_RESOURCE_UTILIZATION: None,
        CORE_CONFIG: core_cfg,
        TARGET_PLATFORM_CAPABILITIES: self.tpc,
    }
329338
@@ -341,16 +350,16 @@ def _setting_GPTQ_MixP(self) -> Dict[str, Any]:
341350 gptq_config = self .get_gptq_config (** params_GPTQCfg )
342351
343352 params_MPCfg = {
344- wrapper_const . NUM_OF_IMAGES : self .params ['num_of_images' ],
345- wrapper_const . USE_HESSIAN_BASED_SCORES : self .params ['use_hessian_based_scores' ],
353+ NUM_OF_IMAGES : self .params ['num_of_images' ],
354+ USE_HESSIAN_BASED_SCORES : self .params ['use_hessian_based_scores' ],
346355 }
347356 mixed_precision_config = mct .core .MixedPrecisionQuantizationConfig (** params_MPCfg )
348357 core_config = mct .core .CoreConfig (mixed_precision_config = mixed_precision_config )
349358 params_RUDCfg = {
350- wrapper_const . IN_MODEL : self .float_model ,
351- wrapper_const . REPRESENTATIVE_DATA_GEN : self .representative_dataset ,
352- wrapper_const . CORE_CONFIG : core_config ,
353- wrapper_const . TARGET_PLATFORM_CAPABILITIES : self .tpc
359+ IN_MODEL : self .float_model ,
360+ REPRESENTATIVE_DATA_GEN : self .representative_dataset ,
361+ CORE_CONFIG : core_config ,
362+ TARGET_PLATFORM_CAPABILITIES : self .tpc
354363 }
355364 ru_data = self .resource_utilization_data (** params_RUDCfg )
356365 weights_compression_ratio = (
@@ -366,11 +375,11 @@ def _setting_GPTQ_MixP(self) -> Dict[str, Any]:
366375
367376 params_GPTQ = {
368377 self .argname_model : self .float_model ,
369- wrapper_const . REPRESENTATIVE_DATA_GEN : self .representative_dataset ,
370- wrapper_const . TARGET_RESOURCE_UTILIZATION : resource_utilization ,
371- wrapper_const . GPTQ_CONFIG : gptq_config ,
372- wrapper_const . CORE_CONFIG : core_config ,
373- wrapper_const . TARGET_PLATFORM_CAPABILITIES : self .tpc
378+ REPRESENTATIVE_DATA_GEN : self .representative_dataset ,
379+ TARGET_RESOURCE_UTILIZATION : resource_utilization ,
380+ GPTQ_CONFIG : gptq_config ,
381+ CORE_CONFIG : core_config ,
382+ TARGET_PLATFORM_CAPABILITIES : self .tpc
374383 }
375384 return params_GPTQ
376385
@@ -389,9 +398,9 @@ def _setting_GPTQ(self) -> Dict[str, Any]:
389398
390399 params_GPTQ = {
391400 self .argname_model : self .float_model ,
392- wrapper_const . REPRESENTATIVE_DATA_GEN : self .representative_dataset ,
393- wrapper_const . GPTQ_CONFIG : gptq_config ,
394- wrapper_const . TARGET_PLATFORM_CAPABILITIES : self .tpc
401+ REPRESENTATIVE_DATA_GEN : self .representative_dataset ,
402+ GPTQ_CONFIG : gptq_config ,
403+ TARGET_PLATFORM_CAPABILITIES : self .tpc
395404 }
396405 return params_GPTQ
397406
@@ -408,22 +417,6 @@ def _exec_lq_ptq(self) -> Any:
408417 Note:
409418 This method requires the lq_ptq module to be imported.
410419 """
411- model_save_dir , output_file_name = os .path .split (
412- self .params ['save_model_path' ])
413-
414- # Note: lq_ptq module should be imported when using this method
415- # q_model = lq_ptq.low_bit_quantizer_ptq(
416- # fp_model=self.float_model,
417- # representative_dataset=self.representative_dataset,
418- # model_save_dir=model_save_dir,
419- # output_file_name=output_file_name,
420- # learning_rate=self.params['learning_rate'],
421- # converter_ver=self.params['converter_ver'],
422- # debug_level='INFO',
423- # debug_detail=False,
424- # overwrite_output_file=True)
425- # return q_model
426-
427420 # Placeholder implementation - replace with actual lq_ptq call
428421 raise NotImplementedError (
429422 "LQ-PTQ functionality requires lq_ptq module to be imported" )
0 commit comments