Skip to content

Commit eb692d2

Browse files
committed
adding logging_redirect_tqdm
1 parent f8a4b80 commit eb692d2

File tree

4 files changed

+316
-305
lines changed

4 files changed

+316
-305
lines changed

model_compression_toolkit/gptq/keras/quantization_facade.py

Lines changed: 79 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Callable, Tuple, Union, Optional
1818
from packaging import version
19+
from tqdm.contrib.logging import logging_redirect_tqdm
1920

2021
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
2122
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
@@ -232,82 +233,84 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da
232233
233234
"""
234235

235-
if core_config.debug_config.bypass:
236-
return in_model, None
237-
238-
KerasModelValidation(model=in_model,
239-
fw_info=DEFAULT_KERAS_INFO).validate()
240-
241-
if core_config.is_mixed_precision_enabled:
242-
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
243-
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
244-
"Ensure usage of the correct API for keras_post_training_quantization "
245-
"or provide a valid mixed-precision configuration.") # pragma: no cover
246-
247-
tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
248-
249-
fw_impl = GPTQKerasImplemantation()
250-
251-
target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
252-
# Attach tpc model to framework
253-
attach2keras = AttachTpcToKeras()
254-
framework_platform_capabilities = attach2keras.attach(
255-
target_platform_capabilities,
256-
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
257-
258-
progress_info_controller = ProgressInfoController(
259-
total_step=research_progress_total(core_config, target_resource_utilization, gptq_config),
260-
description="MCT Keras GPTQ Progress",
261-
progress_info_callback=core_config.debug_config.progress_info_callback
262-
)
263-
264-
tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
265-
representative_data_gen=representative_data_gen,
266-
core_config=core_config,
267-
fw_info=DEFAULT_KERAS_INFO,
268-
fw_impl=fw_impl,
269-
fqc=framework_platform_capabilities,
270-
target_resource_utilization=target_resource_utilization,
271-
tb_w=tb_w,
272-
running_gptq=True,
273-
progress_info_controller=progress_info_controller)
274-
275-
float_graph = copy.deepcopy(tg)
276-
277-
tg_gptq = gptq_runner(tg,
278-
core_config,
279-
gptq_config,
280-
representative_data_gen,
281-
gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
282-
DEFAULT_KERAS_INFO,
283-
fw_impl,
284-
tb_w,
285-
hessian_info_service=hessian_info_service,
286-
progress_info_controller=progress_info_controller)
287-
288-
del hessian_info_service
289-
290-
if progress_info_controller is not None:
291-
progress_info_controller.set_description("MCT Graph Finalization")
292-
293-
if core_config.debug_config.analyze_similarity:
294-
analyzer_model_quantization(representative_data_gen,
295-
tb_w,
296-
float_graph,
297-
tg_gptq,
298-
fw_impl,
299-
DEFAULT_KERAS_INFO)
300-
301-
exportable_model, user_info = get_exportable_keras_model(tg_gptq)
302-
if framework_platform_capabilities.tpc.add_metadata:
303-
exportable_model = add_metadata(exportable_model,
304-
create_model_metadata(fqc=framework_platform_capabilities,
305-
scheduling_info=scheduling_info))
306-
307-
if progress_info_controller is not None:
308-
progress_info_controller.close()
309-
310-
return exportable_model, user_info
236+
with logging_redirect_tqdm():
237+
238+
if core_config.debug_config.bypass:
239+
return in_model, None
240+
241+
KerasModelValidation(model=in_model,
242+
fw_info=DEFAULT_KERAS_INFO).validate()
243+
244+
if core_config.is_mixed_precision_enabled:
245+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
246+
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
247+
"Ensure usage of the correct API for keras_post_training_quantization "
248+
"or provide a valid mixed-precision configuration.") # pragma: no cover
249+
250+
tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
251+
252+
fw_impl = GPTQKerasImplemantation()
253+
254+
target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
255+
# Attach tpc model to framework
256+
attach2keras = AttachTpcToKeras()
257+
framework_platform_capabilities = attach2keras.attach(
258+
target_platform_capabilities,
259+
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
260+
261+
progress_info_controller = ProgressInfoController(
262+
total_step=research_progress_total(core_config, target_resource_utilization, gptq_config),
263+
description="MCT Keras GPTQ Progress",
264+
progress_info_callback=core_config.debug_config.progress_info_callback
265+
)
266+
267+
tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
268+
representative_data_gen=representative_data_gen,
269+
core_config=core_config,
270+
fw_info=DEFAULT_KERAS_INFO,
271+
fw_impl=fw_impl,
272+
fqc=framework_platform_capabilities,
273+
target_resource_utilization=target_resource_utilization,
274+
tb_w=tb_w,
275+
running_gptq=True,
276+
progress_info_controller=progress_info_controller)
277+
278+
float_graph = copy.deepcopy(tg)
279+
280+
tg_gptq = gptq_runner(tg,
281+
core_config,
282+
gptq_config,
283+
representative_data_gen,
284+
gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
285+
DEFAULT_KERAS_INFO,
286+
fw_impl,
287+
tb_w,
288+
hessian_info_service=hessian_info_service,
289+
progress_info_controller=progress_info_controller)
290+
291+
del hessian_info_service
292+
293+
if progress_info_controller is not None:
294+
progress_info_controller.set_description("MCT Graph Finalization")
295+
296+
if core_config.debug_config.analyze_similarity:
297+
analyzer_model_quantization(representative_data_gen,
298+
tb_w,
299+
float_graph,
300+
tg_gptq,
301+
fw_impl,
302+
DEFAULT_KERAS_INFO)
303+
304+
exportable_model, user_info = get_exportable_keras_model(tg_gptq)
305+
if framework_platform_capabilities.tpc.add_metadata:
306+
exportable_model = add_metadata(exportable_model,
307+
create_model_metadata(fqc=framework_platform_capabilities,
308+
scheduling_info=scheduling_info))
309+
310+
if progress_info_controller is not None:
311+
progress_info_controller.close()
312+
313+
return exportable_model, user_info
311314

312315
else:
313316
# If tensorflow is not installed,

model_compression_toolkit/gptq/pytorch/quantization_facade.py

Lines changed: 79 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
# ==============================================================================
1515
import copy
1616
from typing import Callable, Union, Optional, Tuple
17+
from tqdm.contrib.logging import logging_redirect_tqdm
1718

1819
from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
1920
from model_compression_toolkit.core import CoreConfig
@@ -39,7 +40,6 @@
3940
from model_compression_toolkit.verify_packages import FOUND_TORCH
4041

4142

42-
4343
if FOUND_TORCH:
4444
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
4545
from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
@@ -209,82 +209,84 @@ def pytorch_gradient_post_training_quantization(model: Module,
209209
210210
"""
211211

212-
if core_config.debug_config.bypass:
213-
return model, None
214-
215-
if core_config.is_mixed_precision_enabled: # pragma: no cover
216-
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
217-
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
218-
"Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' "
219-
"or provide a valid mixed-precision configuration.")
220-
221-
tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
222-
223-
fw_impl = GPTQPytorchImplemantation()
224-
225-
target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
226-
# Attach tpc model to framework
227-
attach2pytorch = AttachTpcToPytorch()
228-
framework_quantization_capabilities = attach2pytorch.attach(target_platform_capabilities,
229-
core_config.quantization_config.custom_tpc_opset_to_layer)
230-
231-
progress_info_controller = ProgressInfoController(
232-
total_step=research_progress_total(core_config, target_resource_utilization, gptq_config),
233-
description="MCT PyTorch GPTQ Progress",
234-
progress_info_callback=core_config.debug_config.progress_info_callback
235-
)
236-
237-
# ---------------------- #
238-
# Core Runner
239-
# ---------------------- #
240-
graph, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=model,
241-
representative_data_gen=representative_data_gen,
242-
core_config=core_config,
243-
fw_info=DEFAULT_PYTORCH_INFO,
244-
fw_impl=fw_impl,
245-
fqc=framework_quantization_capabilities,
246-
target_resource_utilization=target_resource_utilization,
247-
tb_w=tb_w,
248-
running_gptq=True,
249-
progress_info_controller=progress_info_controller)
250-
251-
float_graph = copy.deepcopy(graph)
252-
253-
# ---------------------- #
254-
# GPTQ Runner
255-
# ---------------------- #
256-
graph_gptq = gptq_runner(graph,
257-
core_config,
258-
gptq_config,
259-
representative_data_gen,
260-
gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
261-
DEFAULT_PYTORCH_INFO,
262-
fw_impl,
263-
tb_w,
264-
hessian_info_service=hessian_info_service,
265-
progress_info_controller=progress_info_controller)
266-
267-
if progress_info_controller is not None:
268-
progress_info_controller.set_description("MCT Graph Finalization")
269-
270-
if core_config.debug_config.analyze_similarity:
271-
analyzer_model_quantization(representative_data_gen,
272-
tb_w,
273-
float_graph,
274-
graph_gptq,
275-
fw_impl,
276-
DEFAULT_PYTORCH_INFO)
277-
278-
exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
279-
if framework_quantization_capabilities.tpc.add_metadata:
280-
exportable_model = add_metadata(exportable_model,
281-
create_model_metadata(fqc=framework_quantization_capabilities,
282-
scheduling_info=scheduling_info))
283-
284-
if progress_info_controller is not None:
285-
progress_info_controller.close()
286-
287-
return exportable_model, user_info
212+
with logging_redirect_tqdm():
213+
214+
if core_config.debug_config.bypass:
215+
return model, None
216+
217+
if core_config.is_mixed_precision_enabled: # pragma: no cover
218+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
219+
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
220+
"Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' "
221+
"or provide a valid mixed-precision configuration.")
222+
223+
tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
224+
225+
fw_impl = GPTQPytorchImplemantation()
226+
227+
target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
228+
# Attach tpc model to framework
229+
attach2pytorch = AttachTpcToPytorch()
230+
framework_quantization_capabilities = attach2pytorch.attach(target_platform_capabilities,
231+
core_config.quantization_config.custom_tpc_opset_to_layer)
232+
233+
progress_info_controller = ProgressInfoController(
234+
total_step=research_progress_total(core_config, target_resource_utilization, gptq_config),
235+
description="MCT PyTorch GPTQ Progress",
236+
progress_info_callback=core_config.debug_config.progress_info_callback
237+
)
238+
239+
# ---------------------- #
240+
# Core Runner
241+
# ---------------------- #
242+
graph, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=model,
243+
representative_data_gen=representative_data_gen,
244+
core_config=core_config,
245+
fw_info=DEFAULT_PYTORCH_INFO,
246+
fw_impl=fw_impl,
247+
fqc=framework_quantization_capabilities,
248+
target_resource_utilization=target_resource_utilization,
249+
tb_w=tb_w,
250+
running_gptq=True,
251+
progress_info_controller=progress_info_controller)
252+
253+
float_graph = copy.deepcopy(graph)
254+
255+
# ---------------------- #
256+
# GPTQ Runner
257+
# ---------------------- #
258+
graph_gptq = gptq_runner(graph,
259+
core_config,
260+
gptq_config,
261+
representative_data_gen,
262+
gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
263+
DEFAULT_PYTORCH_INFO,
264+
fw_impl,
265+
tb_w,
266+
hessian_info_service=hessian_info_service,
267+
progress_info_controller=progress_info_controller)
268+
269+
if progress_info_controller is not None:
270+
progress_info_controller.set_description("MCT Graph Finalization")
271+
272+
if core_config.debug_config.analyze_similarity:
273+
analyzer_model_quantization(representative_data_gen,
274+
tb_w,
275+
float_graph,
276+
graph_gptq,
277+
fw_impl,
278+
DEFAULT_PYTORCH_INFO)
279+
280+
exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
281+
if framework_quantization_capabilities.tpc.add_metadata:
282+
exportable_model = add_metadata(exportable_model,
283+
create_model_metadata(fqc=framework_quantization_capabilities,
284+
scheduling_info=scheduling_info))
285+
286+
if progress_info_controller is not None:
287+
progress_info_controller.close()
288+
289+
return exportable_model, user_info
288290

289291

290292
else:

0 commit comments

Comments
 (0)