Skip to content

Commit 09dfd0e

Browse files
authored
[RedSuns Review] Modify progress bar visualization
1 parent 95cc888 commit 09dfd0e

File tree

14 files changed

+524
-311
lines changed

14 files changed

+524
-311
lines changed

docs/api/api_docs/modules/debug_config.html

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ <h2>DebugConfig<a class="headerlink" href="#debugconfig" title="Link to this hea
134134
<span class="gp">&gt;&gt;&gt; </span><span class="n">core_config</span> <span class="o">=</span> <span class="n">mct</span><span class="o">.</span><span class="n">core</span><span class="o">.</span><span class="n">CoreConfig</span><span class="p">(</span><span class="n">debug_config</span><span class="o">=</span><span class="n">debug_config</span><span class="p">)</span>
135135
</pre></div>
136136
</div>
137+
<div class="admonition important">
138+
<p class="admonition-title">Important</p>
139+
<p>If a callback function is configured, the GPTQ data iteration progress bar is disabled and not displayed.</p>
140+
</div>
137141
</dd></dl>
138142

139143
</section>

docs/searchindex.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

model_compression_toolkit/core/common/progress_config/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,5 @@
2020
PROGRESS_INFO_CALLBACK = 'progress_info_callback'
2121
TOTAL_STEP = 'total_step'
2222

23-
PROGRESS_BAR_POSITION = 2
23+
PROGRESS_BAR_POSITION = 1
2424
DEFAULT_TOTAL_STEP = 4

model_compression_toolkit/core/common/progress_config/progress_info_controller.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ def set_description(self, description: str):
9595
self.close()
9696
raise
9797

98-
self.pbar.set_description(formatted_description, refresh=False)
99-
self.pbar.update()
98+
self.pbar.n += 1
99+
self.pbar.set_description(formatted_description, refresh=True)
100100

101101
progress_info = {
102102
COMPLETED_COMPONENTS: description,

model_compression_toolkit/core/common/quantization/debug_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ class DebugConfig:
103103
>>> import model_compression_toolkit as mct
104104
>>> debug_config = mct.core.DebugConfig(progress_info_callback=progress_info_callback)
105105
>>> core_config = mct.core.CoreConfig(debug_config=debug_config)
106+
107+
.. important::
108+
If a callback function is configured, the GPTQ data iteration progress bar is disabled and not displayed.
109+
106110
"""
107111

108112
analyze_similarity: bool = False

model_compression_toolkit/gptq/common/gptq_training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(self,
7373
self.fw_info = fw_info
7474
self.representative_data_gen_fn = representative_data_gen_fn
7575
self.progress_info_controller = progress_info_controller
76+
self.disable_data_pbar = progress_info_controller is not None
7677

7778
def _get_total_grad_steps():
7879
return get_total_grad_steps(representative_data_gen_fn) * gptq_config.n_epochs

model_compression_toolkit/gptq/keras/gptq_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def micro_training_loop(self,
396396
"""
397397
with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
398398
for _ in epochs_pbar:
399-
with tqdm(self.train_dataloader, position=1, leave=False) as data_pbar:
399+
with tqdm(self.train_dataloader, position=1, leave=False, disable=self.disable_data_pbar) as data_pbar:
400400
for data in data_pbar:
401401

402402
input_data, distill_loss_weights, reg_weight = data

model_compression_toolkit/gptq/keras/quantization_facade.py

Lines changed: 79 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import Callable, Tuple, Union, Optional
1818
from packaging import version
19+
from tqdm.contrib.logging import logging_redirect_tqdm
1920

2021
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
2122
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
@@ -232,82 +233,84 @@ def keras_gradient_post_training_quantization(in_model: Model, representative_da
232233
233234
"""
234235

235-
if core_config.debug_config.bypass:
236-
return in_model, None
237-
238-
KerasModelValidation(model=in_model,
239-
fw_info=DEFAULT_KERAS_INFO).validate()
240-
241-
if core_config.is_mixed_precision_enabled:
242-
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
243-
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
244-
"Ensure usage of the correct API for keras_post_training_quantization "
245-
"or provide a valid mixed-precision configuration.") # pragma: no cover
246-
247-
tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
248-
249-
fw_impl = GPTQKerasImplemantation()
250-
251-
target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
252-
# Attach tpc model to framework
253-
attach2keras = AttachTpcToKeras()
254-
framework_platform_capabilities = attach2keras.attach(
255-
target_platform_capabilities,
256-
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
257-
258-
progress_info_controller = ProgressInfoController(
259-
total_step=research_progress_total(core_config, target_resource_utilization, gptq_config),
260-
description="MCT Keras GPTQ Progress",
261-
progress_info_callback=core_config.debug_config.progress_info_callback
262-
)
263-
264-
tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
265-
representative_data_gen=representative_data_gen,
266-
core_config=core_config,
267-
fw_info=DEFAULT_KERAS_INFO,
268-
fw_impl=fw_impl,
269-
fqc=framework_platform_capabilities,
270-
target_resource_utilization=target_resource_utilization,
271-
tb_w=tb_w,
272-
running_gptq=True,
273-
progress_info_controller=progress_info_controller)
274-
275-
float_graph = copy.deepcopy(tg)
276-
277-
tg_gptq = gptq_runner(tg,
278-
core_config,
279-
gptq_config,
280-
representative_data_gen,
281-
gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
282-
DEFAULT_KERAS_INFO,
283-
fw_impl,
284-
tb_w,
285-
hessian_info_service=hessian_info_service,
286-
progress_info_controller=progress_info_controller)
287-
288-
del hessian_info_service
289-
290-
if progress_info_controller is not None:
291-
progress_info_controller.set_description("MCT Graph Finalization")
292-
293-
if core_config.debug_config.analyze_similarity:
294-
analyzer_model_quantization(representative_data_gen,
295-
tb_w,
296-
float_graph,
297-
tg_gptq,
298-
fw_impl,
299-
DEFAULT_KERAS_INFO)
300-
301-
exportable_model, user_info = get_exportable_keras_model(tg_gptq)
302-
if framework_platform_capabilities.tpc.add_metadata:
303-
exportable_model = add_metadata(exportable_model,
304-
create_model_metadata(fqc=framework_platform_capabilities,
305-
scheduling_info=scheduling_info))
306-
307-
if progress_info_controller is not None:
308-
progress_info_controller.close()
309-
310-
return exportable_model, user_info
236+
with logging_redirect_tqdm():
237+
238+
if core_config.debug_config.bypass:
239+
return in_model, None
240+
241+
KerasModelValidation(model=in_model,
242+
fw_info=DEFAULT_KERAS_INFO).validate()
243+
244+
if core_config.is_mixed_precision_enabled:
245+
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
246+
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
247+
"Ensure usage of the correct API for keras_post_training_quantization "
248+
"or provide a valid mixed-precision configuration.") # pragma: no cover
249+
250+
tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
251+
252+
fw_impl = GPTQKerasImplemantation()
253+
254+
target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
255+
# Attach tpc model to framework
256+
attach2keras = AttachTpcToKeras()
257+
framework_platform_capabilities = attach2keras.attach(
258+
target_platform_capabilities,
259+
custom_opset2layer=core_config.quantization_config.custom_tpc_opset_to_layer)
260+
261+
progress_info_controller = ProgressInfoController(
262+
total_step=research_progress_total(core_config, target_resource_utilization, gptq_config),
263+
description="MCT Keras GPTQ Progress",
264+
progress_info_callback=core_config.debug_config.progress_info_callback
265+
)
266+
267+
tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
268+
representative_data_gen=representative_data_gen,
269+
core_config=core_config,
270+
fw_info=DEFAULT_KERAS_INFO,
271+
fw_impl=fw_impl,
272+
fqc=framework_platform_capabilities,
273+
target_resource_utilization=target_resource_utilization,
274+
tb_w=tb_w,
275+
running_gptq=True,
276+
progress_info_controller=progress_info_controller)
277+
278+
float_graph = copy.deepcopy(tg)
279+
280+
tg_gptq = gptq_runner(tg,
281+
core_config,
282+
gptq_config,
283+
representative_data_gen,
284+
gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
285+
DEFAULT_KERAS_INFO,
286+
fw_impl,
287+
tb_w,
288+
hessian_info_service=hessian_info_service,
289+
progress_info_controller=progress_info_controller)
290+
291+
del hessian_info_service
292+
293+
if progress_info_controller is not None:
294+
progress_info_controller.set_description("MCT Graph Finalization")
295+
296+
if core_config.debug_config.analyze_similarity:
297+
analyzer_model_quantization(representative_data_gen,
298+
tb_w,
299+
float_graph,
300+
tg_gptq,
301+
fw_impl,
302+
DEFAULT_KERAS_INFO)
303+
304+
exportable_model, user_info = get_exportable_keras_model(tg_gptq)
305+
if framework_platform_capabilities.tpc.add_metadata:
306+
exportable_model = add_metadata(exportable_model,
307+
create_model_metadata(fqc=framework_platform_capabilities,
308+
scheduling_info=scheduling_info))
309+
310+
if progress_info_controller is not None:
311+
progress_info_controller.close()
312+
313+
return exportable_model, user_info
311314

312315
else:
313316
# If tensorflow is not installed,

model_compression_toolkit/gptq/pytorch/gptq_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def micro_training_loop(self,
310310
"""
311311
with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
312312
for _ in epochs_pbar:
313-
with tqdm(self.train_dataloader, position=1, leave=False) as data_pbar:
313+
with tqdm(self.train_dataloader, position=1, leave=False, disable=self.disable_data_pbar) as data_pbar:
314314
for sample in data_pbar:
315315
data, loss_weight, reg_weight = to_torch_tensor(sample)
316316
input_data = [d * self.input_scale for d in data]

0 commit comments

Comments
 (0)