1616from typing import Dict , Any , List , Optional , Tuple
1717import model_compression_toolkit as mct
1818from model_compression_toolkit .logger import Logger
19+ from model_compression_toolkit .verify_packages import FOUND_TPC
1920from model_compression_toolkit .wrapper .constants import (
20- REPRESENTATIVE_DATA_GEN , CORE_CONFIG , FW_NAME , SDSP_VERSION ,
21+ REPRESENTATIVE_DATA_GEN , CORE_CONFIG , FW_NAME , TARGET_PLATFORM_VERSION ,
22+ TARGET_PLATFORM_NAME , TPC_VERSION , DEVICE_TYPE , EXTENDED_VERSION ,
2123 NUM_OF_IMAGES , USE_HESSIAN_BASED_SCORES , IN_MODEL , IN_MODULE , MODEL ,
2224 TARGET_PLATFORM_CAPABILITIES , TARGET_RESOURCE_UTILIZATION ,
2325 ACTIVATION_ERROR_METHOD , WEIGHTS_ERROR_METHOD , WEIGHTS_BIAS_CORRECTION ,
@@ -54,6 +56,8 @@ def __init__(self):
5456 :header: "Parameter Key", "Default Value", "Description"
5557 :widths: 30, 30, 40
5658
59+ "target_platform_version", "'v1'", "Target platform version (use_internal_tpc=True)"
60+ "tpc_version", "'5.0'", "TPC version (use_internal_tpc=False)"
5761 "activation_error_method", "mct.core.QuantizationErrorMethod.MSE", "Activation quantization error method"
5862 "weights_bias_correction", "True", "Enable weights bias correction"
5963 "z_threshold", "float('inf')", "Z-threshold for quantization"
@@ -67,6 +71,8 @@ def __init__(self):
6771 :header: "Parameter Key", "Default Value", "Description"
6872 :widths: 30, 30, 40
6973
74+ "target_platform_version", "'v1'", "Target platform version (use_internal_tpc=True)"
75+ "tpc_version", "'5.0'", "TPC version (use_internal_tpc=False)"
7076 "num_of_images", "5", "Number of images for mixed precision"
7177 "use_hessian_based_scores", "False", "Use Hessian-based scores for mixed precision"
7278 "weights_compression_ratio", "None", "Weights compression ratio for resource util"
@@ -78,6 +84,8 @@ def __init__(self):
7884 :header: "Parameter Key", "Default Value", "Description"
7985 :widths: 30, 30, 40
8086
87+ "target_platform_version", "'v1'", "Target platform version (use_internal_tpc=True)"
88+ "tpc_version", "'5.0'", "TPC version (use_internal_tpc=False)"
8189 "n_epochs", "5", "Number of training epochs for GPTQ"
8290 "optimizer", "None", "Optimizer for GPTQ training"
8391 "save_model_path", "'./qmodel.keras' / './qmodel.onnx'", "Path to save quantized model (Keras/Pytorch)"
@@ -88,6 +96,8 @@ def __init__(self):
8896 :header: "Parameter Key", "Default Value", "Description"
8997 :widths: 30, 30, 40
9098
99+ "target_platform_version", "'v1'", "Target platform version (use_internal_tpc=True)"
100+ "tpc_version", "'5.0'", "TPC version (use_internal_tpc=False)"
91101 "n_epochs", "5", "Number of training epochs for GPTQ"
92102 "optimizer", "None", "Optimizer for GPTQ training"
93103 "num_of_images", "5", "Number of images for mixed precision"
@@ -99,7 +109,8 @@ def __init__(self):
99109 self .params : Dict [str , Any ] = {
100110 # TPC
101111 FW_NAME : 'pytorch' ,
102- SDSP_VERSION : '3.14' ,
112+ TARGET_PLATFORM_VERSION : 'v1' ,
113+ TPC_VERSION : '5.0' ,
103114
104115 # QuantizationConfig
105116 ACTIVATION_ERROR_METHOD : mct .core .QuantizationErrorMethod .MSE ,
@@ -129,8 +140,9 @@ def __init__(self):
129140
130141 def _initialize_and_validate (self , float_model : Any ,
131142 representative_dataset : Optional [Any ],
132- framework : str ,
133143 method : str ,
144+ framework : str ,
145+ use_internal_tpc : bool ,
134146 use_mixed_precision : bool
135147 ) -> None :
136148 """
@@ -139,8 +151,9 @@ def _initialize_and_validate(self, float_model: Any,
139151 Args:
140152 float_model: The float model to be quantized.
141153 representative_dataset (Callable, np.array, tf.Tensor): Representative dataset for calibration.
142- framework (str): Target framework ('tensorflow', 'pytorch').
143154 method (str): Quantization method ('PTQ', 'GPTQ', 'LQPTQ').
155+ framework (str): Target framework ('tensorflow', 'pytorch').
156+ use_internal_tpc (bool): Whether to use MCT's built-in TPC.
144157 use_mixed_precision (bool): Whether to use mixed-precision quantization.
145158
146159 Raises:
@@ -159,25 +172,29 @@ def _initialize_and_validate(self, float_model: Any,
159172 self .representative_dataset = representative_dataset
160173 self .method = method
161174 self .framework = framework
175+ self .use_internal_tpc = use_internal_tpc
162176 self .use_mixed_precision = use_mixed_precision
163177
164178 # Keep only the parameters you need for the quantization mode
165179 if method == 'PTQ' :
166180 if not use_mixed_precision :
167- allowed_keys = [ FW_NAME , SDSP_VERSION , ACTIVATION_ERROR_METHOD , WEIGHTS_BIAS_CORRECTION ,
181+ allowed_keys = [ FW_NAME , TARGET_PLATFORM_VERSION , TPC_VERSION ,
182+ ACTIVATION_ERROR_METHOD , WEIGHTS_BIAS_CORRECTION ,
168183 Z_THRESHOLD , LINEAR_COLLAPSING , RESIDUAL_COLLAPSING ,
169184 SAVE_MODEL_PATH ]
170185 else :
171- allowed_keys = [ FW_NAME , SDSP_VERSION , NUM_OF_IMAGES , USE_HESSIAN_BASED_SCORES ,
186+ allowed_keys = [ FW_NAME , TARGET_PLATFORM_VERSION , TPC_VERSION ,
187+ NUM_OF_IMAGES , USE_HESSIAN_BASED_SCORES ,
172188 WEIGHTS_COMPRESSION_RATIO , SAVE_MODEL_PATH ]
173189 else :
174190 if not use_mixed_precision :
175- allowed_keys = [FW_NAME , SDSP_VERSION , N_EPOCHS , OPTIMIZER ,
176- SAVE_MODEL_PATH ]
191+ allowed_keys = [ FW_NAME , TARGET_PLATFORM_VERSION , TPC_VERSION ,
192+ N_EPOCHS , OPTIMIZER , SAVE_MODEL_PATH ]
177193 else :
178- allowed_keys = [FW_NAME , SDSP_VERSION , N_EPOCHS , OPTIMIZER ,
179- NUM_OF_IMAGES , USE_HESSIAN_BASED_SCORES ,
180- WEIGHTS_COMPRESSION_RATIO , SAVE_MODEL_PATH ]
194+ allowed_keys = [ FW_NAME , TARGET_PLATFORM_VERSION , TPC_VERSION ,
195+ N_EPOCHS , OPTIMIZER , NUM_OF_IMAGES ,
196+ USE_HESSIAN_BASED_SCORES , WEIGHTS_COMPRESSION_RATIO ,
197+ SAVE_MODEL_PATH ]
181198
182199 self .params = { k : v for k , v in self .params .items () if k in allowed_keys }
183200
@@ -296,18 +313,36 @@ def _select_argname(self) -> None:
296313
297314 def _get_tpc (self ) -> None :
298315 """
299- Configure Target Platform Capabilities (TPC).
316+ Configure Target Platform Capabilities (TPC) based on selected option .
300317
301- Sets up TPC configuration for the target platform.
318+ Sets up either MCT's built-in TPC or external EdgeMDT TPC configuration
319+ for the IMX500 target platform.
302320
303321 Note:
304322 This method sets self.tpc attribute with the configured TPC object.
305323 """
306- # Get default TPC for the framework
307- params_TPC = {
308- SDSP_VERSION : self .params [SDSP_VERSION ]
309- }
310- self .tpc = mct .get_target_platform_capabilities_sdsp (** params_TPC )
324+ if self .use_internal_tpc :
325+ # Use MCT's built-in TPC configuration
326+ params_TPC = {
327+ FW_NAME : self .params [FW_NAME ],
328+ TARGET_PLATFORM_NAME : 'imx500' ,
329+ TARGET_PLATFORM_VERSION : self .params [TARGET_PLATFORM_VERSION ],
330+ }
331+ # Get TPC from MCT framework
332+ self .tpc = mct .get_target_platform_capabilities (** params_TPC )
333+ else :
334+ if FOUND_TPC :
335+ import edgemdt_tpc
336+ # Use external EdgeMDT TPC configuration
337+ params_TPC = {
338+ TPC_VERSION : self .params [TPC_VERSION ],
339+ DEVICE_TYPE : 'imx500' ,
340+ EXTENDED_VERSION : None
341+ }
342+ # Get TPC from EdgeMDT framework
343+ self .tpc = edgemdt_tpc .get_target_platform_capabilities (** params_TPC )
344+ else :
345+ raise Exception ("EdgeMDT TPC module is not available." )
311346
312347 def _setting_PTQ_mixed_precision (self ) -> Dict [str , Any ]:
313348 """
@@ -475,6 +510,8 @@ def _export_model(self, quantized_model: Any) -> None:
475510 params_export = {
476511 'model' : quantized_model ,
477512 'save_model_path' : self .params ['save_model_path' ],
513+ 'serialization_format' : (mct .exporter .KerasExportSerializationFormat .KERAS ),
514+ 'quantization_format' : (mct .exporter .QuantizationFormat .FAKELY_QUANT )
478515 }
479516 elif self .framework == 'pytorch' :
480517 params_export = {
@@ -486,8 +523,9 @@ def _export_model(self, quantized_model: Any) -> None:
486523
487524 def quantize_and_export (self , float_model : Any ,
488525 representative_dataset : Any ,
489- framework : str = 'pytorch' ,
490526 method : str = 'PTQ' ,
527+ framework : str = 'pytorch' ,
528+ use_internal_tpc : bool = True ,
491529 use_mixed_precision : bool = False ,
492530 param_items : Optional [List [List [Any ]]] = None
493531 ) -> Tuple [bool , Any ]:
@@ -498,10 +536,12 @@ def quantize_and_export(self, float_model: Any,
498536 float_model: The float model to be quantized.
499537 representative_dataset (Callable, np.array, tf.Tensor):
500538 Representative dataset for calibration.
501- framework (str): 'tensorflow' or 'pytorch'.
502- Default: 'pytorch'
503539 method (str): Quantization method, e.g., 'PTQ' or 'GPTQ'.
504540 Default: 'PTQ'
541+ framework (str): 'tensorflow' or 'pytorch'.
542+ Default: 'pytorch'
543+ use_internal_tpc (bool): Whether to use internal_tpc.
544+ Default: True
505545 use_mixed_precision (bool): Whether to use mixed-precision
506546 quantization. Default: False
507547 param_items (list): List of parameter settings.
@@ -525,10 +565,11 @@ def quantize_and_export(self, float_model: Any,
525565
526566 >>> wrapper = mct.MCTWrapper()
527567
528- set framework, method , and other parameters
568+ set method, framework , and other parameters
529569
530- >>> framework = 'tensorflow'
531570 >>> method = 'PTQ'
571+ >>> framework = 'tensorflow'
572+ >>> use_internal_tpc = True
532573 >>> use_mixed_precision = False
533574
534575 set parameters if needed
@@ -540,17 +581,19 @@ def quantize_and_export(self, float_model: Any,
540581 >>> flag, quantized_model = wrapper.quantize_and_export(
541582 ... float_model=float_model,
542583 ... representative_dataset=representative_dataset,
543- ... framework=framework,
544584 ... method=method,
585+ ... framework=framework,
586+ ... use_internal_tpc=use_internal_tpc,
545587 ... use_mixed_precision=use_mixed_precision,
546588 ... param_items=param_items
547589 ... )
548590
549591 """
550592 try :
551593 # Step 1: Initialize and validate all input parameters
552- self ._initialize_and_validate ( float_model , representative_dataset ,
553- framework , method , use_mixed_precision )
594+ self ._initialize_and_validate (
595+ float_model , representative_dataset , method , framework ,
596+ use_internal_tpc , use_mixed_precision )
554597
555598 # Step 2: Apply custom parameter modifications
556599 self ._modify_params (param_items )
0 commit comments