|
16 | 16 |
|
17 | 17 | from typing import Callable, Tuple, Union, Optional |
18 | 18 | from packaging import version |
| 19 | +from tqdm.contrib.logging import logging_redirect_tqdm |
19 | 20 |
|
20 | 21 | from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer |
21 | 22 | from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \ |
@@ -232,82 +233,84 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da |
232 | 233 |
|
233 | 234 | """ |
234 | 235 |
|
235 | | - if core_config.debug_config.bypass: |
236 | | - return in_model, None |
237 | | - |
238 | | - KerasModelValidation(model=in_model, |
239 | | - fw_info=DEFAULT_KERAS_INFO).validate() |
240 | | - |
241 | | - if core_config.is_mixed_precision_enabled: |
242 | | - if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig): |
243 | | - Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. " |
244 | | - "Ensure usage of the correct API for keras_post_training_quantization " |
245 | | - "or provide a valid mixed-precision configuration.") # pragma: no cover |
246 | | - |
247 | | - tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO) |
248 | | - |
249 | | - fw_impl = GPTQKerasImplemantation() |
250 | | - |
251 | | - target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) |
252 | | - # Attach tpc model to framework |
253 | | - attach2keras = AttachTpcToKeras() |
254 | | - framework_platform_capabilities = attach2keras.attach( |
255 | | - target_platform_capabilities, |
256 | | - custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer) |
257 | | - |
258 | | - progress_info_controller = ProgressInfoController( |
259 | | - total_step=research_progress_total(core_config, target_resource_utilization, gptq_config), |
260 | | - description="MCT Keras GPTQ Progress", |
261 | | - progress_info_callback=core_config.debug_config.progress_info_callback |
262 | | - ) |
263 | | - |
264 | | - tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model, |
265 | | - representative_data_gen=representative_data_gen, |
266 | | - core_config=core_config, |
267 | | - fw_info=DEFAULT_KERAS_INFO, |
268 | | - fw_impl=fw_impl, |
269 | | - fqc=framework_platform_capabilities, |
270 | | - target_resource_utilization=target_resource_utilization, |
271 | | - tb_w=tb_w, |
272 | | - running_gptq=True, |
273 | | - progress_info_controller=progress_info_controller) |
274 | | - |
275 | | - float_graph = copy.deepcopy(tg) |
276 | | - |
277 | | - tg_gptq = gptq_runner(tg, |
278 | | - core_config, |
279 | | - gptq_config, |
280 | | - representative_data_gen, |
281 | | - gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen, |
282 | | - DEFAULT_KERAS_INFO, |
283 | | - fw_impl, |
284 | | - tb_w, |
285 | | - hessian_info_service=hessian_info_service, |
286 | | - progress_info_controller=progress_info_controller) |
287 | | - |
288 | | - del hessian_info_service |
289 | | - |
290 | | - if progress_info_controller is not None: |
291 | | - progress_info_controller.set_description("MCT Graph Finalization") |
292 | | - |
293 | | - if core_config.debug_config.analyze_similarity: |
294 | | - analyzer_model_quantization(representative_data_gen, |
295 | | - tb_w, |
296 | | - float_graph, |
297 | | - tg_gptq, |
298 | | - fw_impl, |
299 | | - DEFAULT_KERAS_INFO) |
300 | | - |
301 | | - exportable_model, user_info = get_exportable_keras_model(tg_gptq) |
302 | | - if framework_platform_capabilities.tpc.add_metadata: |
303 | | - exportable_model = add_metadata(exportable_model, |
304 | | - create_model_metadata(fqc=framework_platform_capabilities, |
305 | | - scheduling_info=scheduling_info)) |
306 | | - |
307 | | - if progress_info_controller is not None: |
308 | | - progress_info_controller.close() |
309 | | - |
310 | | - return exportable_model, user_info |
| 236 | + with logging_redirect_tqdm(): |
| 237 | + |
| 238 | + if core_config.debug_config.bypass: |
| 239 | + return in_model, None |
| 240 | + |
| 241 | + KerasModelValidation(model=in_model, |
| 242 | + fw_info=DEFAULT_KERAS_INFO).validate() |
| 243 | + |
| 244 | + if core_config.is_mixed_precision_enabled: |
| 245 | + if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig): |
| 246 | + Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. " |
| 247 | + "Ensure usage of the correct API for keras_post_training_quantization " |
| 248 | + "or provide a valid mixed-precision configuration.") # pragma: no cover |
| 249 | + |
| 250 | + tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO) |
| 251 | + |
| 252 | + fw_impl = GPTQKerasImplemantation() |
| 253 | + |
| 254 | + target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities) |
| 255 | + # Attach tpc model to framework |
| 256 | + attach2keras = AttachTpcToKeras() |
| 257 | + framework_platform_capabilities = attach2keras.attach( |
| 258 | + target_platform_capabilities, |
| 259 | + custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer) |
| 260 | + |
| 261 | + progress_info_controller = ProgressInfoController( |
| 262 | + total_step=research_progress_total(core_config, target_resource_utilization, gptq_config), |
| 263 | + description="MCT Keras GPTQ Progress", |
| 264 | + progress_info_callback=core_config.debug_config.progress_info_callback |
| 265 | + ) |
| 266 | + |
| 267 | + tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model, |
| 268 | + representative_data_gen=representative_data_gen, |
| 269 | + core_config=core_config, |
| 270 | + fw_info=DEFAULT_KERAS_INFO, |
| 271 | + fw_impl=fw_impl, |
| 272 | + fqc=framework_platform_capabilities, |
| 273 | + target_resource_utilization=target_resource_utilization, |
| 274 | + tb_w=tb_w, |
| 275 | + running_gptq=True, |
| 276 | + progress_info_controller=progress_info_controller) |
| 277 | + |
| 278 | + float_graph = copy.deepcopy(tg) |
| 279 | + |
| 280 | + tg_gptq = gptq_runner(tg, |
| 281 | + core_config, |
| 282 | + gptq_config, |
| 283 | + representative_data_gen, |
| 284 | + gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen, |
| 285 | + DEFAULT_KERAS_INFO, |
| 286 | + fw_impl, |
| 287 | + tb_w, |
| 288 | + hessian_info_service=hessian_info_service, |
| 289 | + progress_info_controller=progress_info_controller) |
| 290 | + |
| 291 | + del hessian_info_service |
| 292 | + |
| 293 | + if progress_info_controller is not None: |
| 294 | + progress_info_controller.set_description("MCT Graph Finalization") |
| 295 | + |
| 296 | + if core_config.debug_config.analyze_similarity: |
| 297 | + analyzer_model_quantization(representative_data_gen, |
| 298 | + tb_w, |
| 299 | + float_graph, |
| 300 | + tg_gptq, |
| 301 | + fw_impl, |
| 302 | + DEFAULT_KERAS_INFO) |
| 303 | + |
| 304 | + exportable_model, user_info = get_exportable_keras_model(tg_gptq) |
| 305 | + if framework_platform_capabilities.tpc.add_metadata: |
| 306 | + exportable_model = add_metadata(exportable_model, |
| 307 | + create_model_metadata(fqc=framework_platform_capabilities, |
| 308 | + scheduling_info=scheduling_info)) |
| 309 | + |
| 310 | + if progress_info_controller is not None: |
| 311 | + progress_info_controller.close() |
| 312 | + |
| 313 | + return exportable_model, user_info |
311 | 314 |
|
312 | 315 | else: |
313 | 316 | # If tensorflow is not installed, |
|
0 commit comments