Commit 446a1b5

kuynzereb authored and williamFalcon committed
Split progress bar (#449)
* Split progress bars
* Iterable dataset total batches fix
* Use dynamic ncols and use batch as units
* Count epochs from 1 in progress bar
* Fix for disabled progress bar
* Code simplifications
1 parent 4e9fd95 commit 446a1b5

File tree

3 files changed: +61 −35 lines


pytorch_lightning/trainer/evaluation_loop_mixin.py

Lines changed: 25 additions & 6 deletions
```diff
@@ -1,4 +1,5 @@
 import torch
+import tqdm
 
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 
@@ -52,8 +53,11 @@ def evaluate(self, model, dataloaders, max_batches, test=False):
                 dl_outputs.append(output)
 
                 # batch done
-                if self.show_progress_bar:
-                    self.progress_bar.update(1)
+                if test:
+                    self.test_progress_bar.update(1)
+                else:
+                    self.val_progress_bar.update(1)
+                    self.main_progress_bar.update(1)
             outputs.append(dl_outputs)
 
         eval_results = {}
@@ -110,6 +114,15 @@ def run_evaluation(self, test=False):
         if self.fast_dev_run:
             max_batches = 1
 
+        # init validation or test progress bar
+        # main progress bar will already be closed when testing so initial position is free
+        position = 2 * self.process_position + (not test)
+        desc = 'Testing' if test else 'Validating'
+        pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True,
+                         unit='batch')
+        setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
+
         # run evaluation
         eval_results = self.evaluate(self.model,
                                      dataloaders,
@@ -130,10 +143,16 @@ def run_evaluation(self, test=False):
         # hook
         model.on_post_performance_check()
 
-        if self.show_progress_bar:
-            # add model specific metrics
-            tqdm_metrics = self.training_tqdm_dict
-            self.progress_bar.set_postfix(**tqdm_metrics)
+        # add model specific metrics
+        tqdm_metrics = self.training_tqdm_dict
+        if not test:
+            self.main_progress_bar.set_postfix(**tqdm_metrics)
+
+        # close progress bar
+        if test:
+            self.test_progress_bar.close()
+        else:
+            self.val_progress_bar.close()
 
         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
```
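The pattern above is plain tqdm: each bar owns a terminal row selected by `position`, validation batches advance both the validation bar and the main bar (whose total already includes them), and `leave=test` keeps only the test bar on screen after it closes. Here is a minimal, self-contained sketch of that layout, with assumed batch counts and `time.sleep` standing in for the real training and validation steps:

```python
import time
import tqdm

process_position = 0  # assumed single trainer process; Lightning offsets rows by 2 * process_position
nb_train_batches, nb_val_batches = 8, 4

# the main bar counts train + val batches, mirroring the trainer's total_batches
main_bar = tqdm.tqdm(desc='Epoch 1', total=nb_train_batches + nb_val_batches,
                     position=2 * process_position, leave=True,
                     dynamic_ncols=True, unit='batch')

for _ in range(nb_train_batches):
    time.sleep(0.05)     # stand-in for a training step
    main_bar.update(1)

# the validation bar renders one row below: position = 2 * process_position + (not test)
val_bar = tqdm.tqdm(desc='Validating', total=nb_val_batches,
                    position=2 * process_position + 1, leave=False,
                    dynamic_ncols=True, unit='batch')
for _ in range(nb_val_batches):
    time.sleep(0.05)     # stand-in for a validation step
    val_bar.update(1)
    main_bar.update(1)   # validation batches also advance the main bar
val_bar.close()          # leave=False clears this row once validation ends
main_bar.close()
```

The `(not test)` term exploits the fact that Python booleans are integers: the test bar reuses the main bar's row (the main bar is already closed when testing), while the validation bar takes the row beneath it.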

pytorch_lightning/trainer/train_loop_mixin.py

Lines changed: 20 additions & 18 deletions
```diff
@@ -1,4 +1,5 @@
 import numpy as np
+import tqdm
 
 try:
     from apex import amp
@@ -34,18 +35,21 @@ def train(self):
                                   self.nb_val_batches * val_checks_per_epoch)
             self.batch_loss_value = 0  # accumulated grads
 
-            # limit the number of batches to 1 in fast_dev_run
             if self.fast_dev_run:
-                self.total_batches = 1
-
-            # init progress_bar when requested
-            if self.show_progress_bar:
+                # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
+                nb_iterations = 2
+            elif self.is_iterable_train_dataloader:
+                # for iterable train loader, the progress bar never ends
+                nb_iterations = None
+            else:
                 nb_iterations = self.total_batches
 
-                # for iterable train loader, the progress bar never ends
-                if self.is_iterable_train_dataloader:
-                    nb_iterations = float('inf')
-                self.progress_bar.reset(nb_iterations)
+            # reset progress bar
+            # .reset() doesn't work on disabled progress bar so we should check
+            if not self.main_progress_bar.disable:
+                self.main_progress_bar.reset(nb_iterations)
+            desc = f'Epoch {epoch_nb + 1}' if not self.is_iterable_train_dataloader else ''
+            self.main_progress_bar.set_description(desc)
 
             # changing gradient according accumulation_scheduler
             self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)
@@ -68,8 +72,11 @@ def train(self):
             # stop training
             stop = should_stop and met_min_epochs
             if stop:
+                self.main_progress_bar.close()
                 return
 
+        self.main_progress_bar.close()
+
         if self.logger is not None:
             self.logger.finalize("success")
@@ -158,9 +165,6 @@ def run_training_batch(self, batch, batch_nb):
             if response == -1:
                 return -1, grad_norm_dic
 
-        if self.show_progress_bar:
-            self.progress_bar.update(1)
-
         splits = [batch]
         if self.truncated_bptt_steps is not None:
             model_ref = self.get_model()
@@ -241,17 +245,15 @@ def optimizer_closure():
                 self.batch_loss_value = 0
                 self.avg_loss = np.mean(self.running_loss[-100:])
 
-                # update progress bar
-                if self.show_progress_bar:
-                    # add model specific metrics
-                    tqdm_metrics = self.training_tqdm_dict
-                    self.progress_bar.set_postfix(**tqdm_metrics)
-
         # activate batch end hook
         if self.is_function_implemented('on_batch_end'):
             model = self.get_model()
             model.on_batch_end()
 
+        # update progress bar
+        self.main_progress_bar.update(1)
+        self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
+
         # collapse all metrics into one dict
         all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
```
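Two tqdm details carry this hunk: `total=None` gives an open-ended bar for iterable loaders, and a disabled bar silently ignores `update`, `set_postfix`, `set_description`, and `close` but, per the comment in the diff, not `reset`, hence the guard on the public `disable` attribute. A small sketch of that behaviour (the flag name `show_progress_bar` mirrors the trainer option; the rest is assumed):

```python
import tqdm

show_progress_bar = False  # progress bars are typically hidden, e.g. on non-zero ranks

bar = tqdm.tqdm(total=None, disable=not show_progress_bar,
                dynamic_ncols=True, unit='batch')

bar.update(1)                    # safe no-op on a disabled bar
bar.set_postfix(loss='0.123')    # likewise a no-op

# per the commit, .reset() is not safe on a disabled bar, so guard it
if not bar.disable:
    bar.reset(total=100)         # start the next epoch with a fresh count
bar.set_description('Epoch 2')   # unguarded in the diff, safe when disabled

bar.close()
```

Swapping `float('inf')` for `None` also matches tqdm's convention for an unknown length: with no total the bar just counts iterations and shows a rate instead of trying to render a percentage against infinity.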

pytorch_lightning/trainer/trainer.py

Lines changed: 16 additions & 11 deletions
```diff
@@ -296,7 +296,6 @@ def training_tqdm_dict(self):
         """
         tqdm_dict = {
             'loss': '{0:.3f}'.format(self.avg_loss),
-            'epoch': '{}'.format(self.current_epoch),
             'batch_nb': '{}'.format(self.batch_nb),
         }
 
@@ -432,28 +431,34 @@ def run_pretrain_routine(self, model):
         # restore training and model before hpc call
         self.restore_weights(model)
 
-        # progress bar init
-        if self.show_progress_bar:
-            self.progress_bar = tqdm.tqdm(0, position=self.process_position)
-
         # when testing requested only run test and return
         if self.testing:
-            if self.show_progress_bar:
-                self.progress_bar.reset(self.nb_test_batches)
-
             self.run_evaluation(test=True)
             return
 
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
         ref_model.on_sanity_check_start()
         if self.get_val_dataloaders() is not None and self.nb_sanity_val_steps > 0:
-            # reset progress_bar limit for sanity check
-            if self.show_progress_bar:
-                self.progress_bar.reset(self.nb_sanity_val_steps)
+            # init progress bars for validation sanity check
+            pbar = tqdm.tqdm(desc='Validation sanity check', total=self.nb_sanity_val_steps,
+                             leave=False, position=2 * self.process_position,
+                             disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
+            self.main_progress_bar = pbar
+            # dummy validation progress bar
+            self.val_progress_bar = tqdm.tqdm(disable=True)
 
             self.evaluate(model, self.get_val_dataloaders(), self.nb_sanity_val_steps, self.testing)
 
+            # close progress bars
+            self.main_progress_bar.close()
+            self.val_progress_bar.close()
+
+        # init progress bar
+        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
+        self.main_progress_bar = pbar
+
         # clear cache before training
         if self.on_gpu:
             torch.cuda.empty_cache()
```
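The lifecycle in `run_pretrain_routine` is: a throwaway main bar (`leave=False`, so it disappears) plus a dummy disabled validation bar for the sanity check, both closed afterwards, then the long-lived main bar that `train()` later resets with each epoch's batch count. A condensed sketch of that sequence under assumed names (`run_sanity_check` and its loop body are stand-ins, not Lightning code):

```python
import tqdm

def run_sanity_check(nb_sanity_val_steps, process_position=0, show_progress_bar=True):
    # temporary bar for the sanity check; leave=False erases it on close
    main_bar = tqdm.tqdm(desc='Validation sanity check', total=nb_sanity_val_steps,
                         leave=False, position=2 * process_position,
                         disable=not show_progress_bar, dynamic_ncols=True, unit='batch')
    # dummy bar so shared evaluation code can call val_progress_bar.update() safely
    val_bar = tqdm.tqdm(disable=True)

    for _ in range(nb_sanity_val_steps):  # stand-in for self.evaluate(...)
        val_bar.update(1)
        main_bar.update(1)

    main_bar.close()
    val_bar.close()

    # the persistent main bar: no total yet, train() resets it each epoch
    return tqdm.tqdm(leave=True, position=2 * process_position,
                     disable=not show_progress_bar, dynamic_ncols=True, unit='batch')

main_progress_bar = run_sanity_check(nb_sanity_val_steps=5)
main_progress_bar.close()
```

The dummy disabled bar is what keeps the shared `evaluate()` path branch-free: there is always a `val_progress_bar` to update, whether or not a real one is on screen.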
