|
14 | 14 | # ============================================================================== |
15 | 15 | import copy |
16 | 16 | from typing import Callable, Union, Optional, Tuple |
| 17 | +from tqdm.contrib.logging import logging_redirect_tqdm |
17 | 18 |
|
18 | 19 | from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES |
19 | 20 | from model_compression_toolkit.core import CoreConfig |
|
39 | 40 | from model_compression_toolkit.verify_packages import FOUND_TORCH |
40 | 41 |
|
41 | 42 |
|
42 | | - |
43 | 43 | if FOUND_TORCH: |
44 | 44 | from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO |
45 | 45 | from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation |
@@ -209,82 +209,84 @@ def pytorch_gradient_post_training_quantization(model: Module, |
209 | 209 |
|
210 | 210 | """ |
211 | 211 |
|
212 | | - if core_config.debug_config.bypass: |
213 | | - return model, None |
214 | | - |
215 | | - if core_config.is_mixed_precision_enabled: # pragma: no cover |
216 | | - if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig): |
217 | | - Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. " |
218 | | - "Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' " |
219 | | - "or provide a valid mixed-precision configuration.") |
220 | | - |
221 | | - tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO) |
222 | | - |
223 | | - fw_impl = GPTQPytorchImplemantation() |
224 | | - |
225 | | - target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) |
226 | | - # Attach tpc model to framework |
227 | | - attach2pytorch = AttachTpcToPytorch() |
228 | | - framework_quantization_capabilities = attach2pytorch.attach(target_platform_capabilities, |
229 | | - core_config.quantization_config.custom_tpc_opset_to_layer) |
230 | | - |
231 | | - progress_info_controller = ProgressInfoController( |
232 | | - total_step=research_progress_total(core_config, target_resource_utilization, gptq_config), |
233 | | - description="MCT PyTorch GPTQ Progress", |
234 | | - progress_info_callback=core_config.debug_config.progress_info_callback |
235 | | - ) |
236 | | - |
237 | | - # ---------------------- # |
238 | | - # Core Runner |
239 | | - # ---------------------- # |
240 | | - graph, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=model, |
241 | | - representative_data_gen=representative_data_gen, |
242 | | - core_config=core_config, |
243 | | - fw_info=DEFAULT_PYTORCH_INFO, |
244 | | - fw_impl=fw_impl, |
245 | | - fqc=framework_quantization_capabilities, |
246 | | - target_resource_utilization=target_resource_utilization, |
247 | | - tb_w=tb_w, |
248 | | - running_gptq=True, |
249 | | - progress_info_controller=progress_info_controller) |
250 | | - |
251 | | - float_graph = copy.deepcopy(graph) |
252 | | - |
253 | | - # ---------------------- # |
254 | | - # GPTQ Runner |
255 | | - # ---------------------- # |
256 | | - graph_gptq = gptq_runner(graph, |
257 | | - core_config, |
258 | | - gptq_config, |
259 | | - representative_data_gen, |
260 | | - gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen, |
261 | | - DEFAULT_PYTORCH_INFO, |
262 | | - fw_impl, |
263 | | - tb_w, |
264 | | - hessian_info_service=hessian_info_service, |
265 | | - progress_info_controller=progress_info_controller) |
266 | | - |
267 | | - if progress_info_controller is not None: |
268 | | - progress_info_controller.set_description("MCT Graph Finalization") |
269 | | - |
270 | | - if core_config.debug_config.analyze_similarity: |
271 | | - analyzer_model_quantization(representative_data_gen, |
272 | | - tb_w, |
273 | | - float_graph, |
274 | | - graph_gptq, |
275 | | - fw_impl, |
276 | | - DEFAULT_PYTORCH_INFO) |
277 | | - |
278 | | - exportable_model, user_info = get_exportable_pytorch_model(graph_gptq) |
279 | | - if framework_quantization_capabilities.tpc.add_metadata: |
280 | | - exportable_model = add_metadata(exportable_model, |
281 | | - create_model_metadata(fqc=framework_quantization_capabilities, |
282 | | - scheduling_info=scheduling_info)) |
283 | | - |
284 | | - if progress_info_controller is not None: |
285 | | - progress_info_controller.close() |
286 | | - |
287 | | - return exportable_model, user_info |
| 212 | + with logging_redirect_tqdm(): |
| 213 | + |
| 214 | + if core_config.debug_config.bypass: |
| 215 | + return model, None |
| 216 | + |
| 217 | + if core_config.is_mixed_precision_enabled: # pragma: no cover |
| 218 | + if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig): |
| 219 | + Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. " |
| 220 | + "Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' " |
| 221 | + "or provide a valid mixed-precision configuration.") |
| 222 | + |
| 223 | + tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO) |
| 224 | + |
| 225 | + fw_impl = GPTQPytorchImplemantation() |
| 226 | + |
| 227 | + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) |
| 228 | + # Attach tpc model to framework |
| 229 | + attach2pytorch = AttachTpcToPytorch() |
| 230 | + framework_quantization_capabilities = attach2pytorch.attach(target_platform_capabilities, |
| 231 | + core_config.quantization_config.custom_tpc_opset_to_layer) |
| 232 | + |
| 233 | + progress_info_controller = ProgressInfoController( |
| 234 | + total_step=research_progress_total(core_config, target_resource_utilization, gptq_config), |
| 235 | + description="MCT PyTorch GPTQ Progress", |
| 236 | + progress_info_callback=core_config.debug_config.progress_info_callback |
| 237 | + ) |
| 238 | + |
| 239 | + # ---------------------- # |
| 240 | + # Core Runner |
| 241 | + # ---------------------- # |
| 242 | + graph, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=model, |
| 243 | + representative_data_gen=representative_data_gen, |
| 244 | + core_config=core_config, |
| 245 | + fw_info=DEFAULT_PYTORCH_INFO, |
| 246 | + fw_impl=fw_impl, |
| 247 | + fqc=framework_quantization_capabilities, |
| 248 | + target_resource_utilization=target_resource_utilization, |
| 249 | + tb_w=tb_w, |
| 250 | + running_gptq=True, |
| 251 | + progress_info_controller=progress_info_controller) |
| 252 | + |
| 253 | + float_graph = copy.deepcopy(graph) |
| 254 | + |
| 255 | + # ---------------------- # |
| 256 | + # GPTQ Runner |
| 257 | + # ---------------------- # |
| 258 | + graph_gptq = gptq_runner(graph, |
| 259 | + core_config, |
| 260 | + gptq_config, |
| 261 | + representative_data_gen, |
| 262 | + gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen, |
| 263 | + DEFAULT_PYTORCH_INFO, |
| 264 | + fw_impl, |
| 265 | + tb_w, |
| 266 | + hessian_info_service=hessian_info_service, |
| 267 | + progress_info_controller=progress_info_controller) |
| 268 | + |
| 269 | + if progress_info_controller is not None: |
| 270 | + progress_info_controller.set_description("MCT Graph Finalization") |
| 271 | + |
| 272 | + if core_config.debug_config.analyze_similarity: |
| 273 | + analyzer_model_quantization(representative_data_gen, |
| 274 | + tb_w, |
| 275 | + float_graph, |
| 276 | + graph_gptq, |
| 277 | + fw_impl, |
| 278 | + DEFAULT_PYTORCH_INFO) |
| 279 | + |
| 280 | + exportable_model, user_info = get_exportable_pytorch_model(graph_gptq) |
| 281 | + if framework_quantization_capabilities.tpc.add_metadata: |
| 282 | + exportable_model = add_metadata(exportable_model, |
| 283 | + create_model_metadata(fqc=framework_quantization_capabilities, |
| 284 | + scheduling_info=scheduling_info)) |
| 285 | + |
| 286 | + if progress_info_controller is not None: |
| 287 | + progress_info_controller.close() |
| 288 | + |
| 289 | + return exportable_model, user_info |
288 | 290 |
|
289 | 291 |
|
290 | 292 | else: |
|
0 commit comments