
Commit f576ed3

Fix resuming the tqdm progress bar (#13962)
1 parent ff2e329 commit f576ed3

File tree: 2 files changed (+16 −11 lines)


src/pytorch_lightning/CHANGELOG.md
Lines changed: 4 additions & 1 deletion

@@ -408,6 +408,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed Python 3.10 compatibility for truncated back-propagation through time (TBPTT) ([#13973](https://github.com/Lightning-AI/lightning/pull/13973))
 
 
+- Fixed `TQDMProgressBar` reset and update to show correct time estimation (2/2) ([#13962](https://github.com/Lightning-AI/lightning/pull/13962))
+
+
 
 ## [1.6.5] - 2022-07-13
 
@@ -463,7 +466,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/Lightning-AI/lightning/pull/12891))
 - Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/Lightning-AI/lightning/pull/12653))
 - Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/Lightning-AI/lightning/pull/12965))
-- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/Lightning-AI/lightning/pull/12889))
+- Fixed `TQDMProgressBar` reset and update to show correct time estimation (1/2) ([#12889](https://github.com/Lightning-AI/lightning/pull/12889))
 - Fixed fit loop restart logic to enable resume using the checkpoint ([#12821](https://github.com/Lightning-AI/lightning/pull/12821))
 
 

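For context, the changelog entries above describe the user-facing symptom: when a fit is resumed from a checkpoint, the `TQDMProgressBar` could show a skewed time estimate. A minimal sketch of that scenario follows; the toy module, dataset, refresh rate, and checkpoint path are illustrative placeholders, not part of this commit.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar


class TinyModule(pl.LightningModule):
    """Hypothetical stand-in model; it exists only to drive the progress bar."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


train_data = DataLoader(
    TensorDataset(torch.randn(256, 32), torch.randn(256, 1)), batch_size=8
)

# Passing a TQDMProgressBar callback explicitly replaces the default bar.
trainer = pl.Trainer(max_epochs=2, callbacks=[TQDMProgressBar(refresh_rate=10)])
trainer.fit(TinyModule(), train_data)

# Resuming is the case this commit targets; before the fix the restored bar's
# rate/ETA could be off. With a fresh Trainer and a placeholder checkpoint path:
# trainer.fit(TinyModule(), train_data, ckpt_path="path/to/checkpoint.ckpt")
```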
src/pytorch_lightning/callbacks/progress/tqdm_progress.py
Lines changed: 12 additions & 10 deletions

@@ -254,12 +254,13 @@ def on_train_start(self, *_: Any) -> None:
     def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
         total_batches = self.total_batches_current_epoch
         self.main_progress_bar.reset(convert_inf(total_batches))
+        self.main_progress_bar.initial = 0
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
 
     def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
         current = self.train_batch_idx + self._val_processed
         if self._should_update(current, self.main_progress_bar.total):
-            _update_n(self.main_progress_bar, current, self.refresh_rate)
+            _update_n(self.main_progress_bar, current)
             self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -280,16 +281,17 @@ def on_validation_batch_start(
             return
 
         self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
+        self.val_progress_bar.initial = 0
         desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
         self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
 
     def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
         if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
-            _update_n(self.val_progress_bar, self.val_batch_idx, self.refresh_rate)
+            _update_n(self.val_progress_bar, self.val_batch_idx)
 
         current = self.train_batch_idx + self._val_processed
         if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
-            _update_n(self.main_progress_bar, current, self.refresh_rate)
+            _update_n(self.main_progress_bar, current)
 
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self._main_progress_bar is not None and trainer.state.fn == "fit":
@@ -307,11 +309,12 @@ def on_test_batch_start(
             return
 
         self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
+        self.test_progress_bar.initial = 0
         self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")
 
     def on_test_batch_end(self, *_: Any) -> None:
         if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
-            _update_n(self.test_progress_bar, self.test_batch_idx, self.refresh_rate)
+            _update_n(self.test_progress_bar, self.test_batch_idx)
 
     def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.test_progress_bar.close()
@@ -327,11 +330,12 @@ def on_predict_batch_start(
             return
 
         self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
+        self.predict_progress_bar.initial = 0
         self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
 
     def on_predict_batch_end(self, *_: Any) -> None:
         if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
-            _update_n(self.predict_progress_bar, self.predict_batch_idx, self.refresh_rate)
+            _update_n(self.predict_progress_bar, self.predict_batch_idx)
 
     def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self.predict_progress_bar.close()
@@ -375,9 +379,7 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
     return x
 
 
-def _update_n(bar: _tqdm, current: int, refresh_rate: int) -> None:
+def _update_n(bar: _tqdm, value: int) -> None:
     if not bar.disable:
-        total = bar.total
-        leftover = current % refresh_rate
-        advance = leftover if (current == total and leftover != 0) else refresh_rate
-        bar.update(advance)
+        bar.n = value
+        bar.refresh()

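A note on why the two changes work together, as a reading of the diff rather than text from the PR: tqdm's `reset()` zeroes the counter but does not clear the `initial` the bar was constructed with, and recent tqdm versions derive the displayed rate roughly as `(n - initial) / elapsed`, so a bar restored mid-run with a non-zero `initial` produced a skewed rate and ETA after each reset. Likewise, the old `_update_n` advanced the bar by `refresh_rate`, which loses sync with the true batch index when training resumes partway through an epoch; assigning `bar.n` directly and calling `refresh()` keeps the displayed count exact. Below is a standalone sketch of the pattern using plain `tqdm`; the total, resume point, and refresh rate are made-up values for illustration.

```python
# Standalone sketch of the pattern the patched helper adopts, with plain tqdm.
import time

from tqdm import tqdm

RESUMED_AT = 37    # batch index restored from a checkpoint (hypothetical)
TOTAL = 100
REFRESH_RATE = 10

# The Lightning bar starts with `initial` set to the restored batch index.
bar = tqdm(total=TOTAL, initial=RESUMED_AT)

# Epoch start: reset() zeroes n, but `initial` survives unless cleared
# explicitly, which is what the patch adds right after each reset().
bar.reset(total=TOTAL)
bar.initial = 0

for batch_idx in range(RESUMED_AT + 1, TOTAL + 1):
    time.sleep(0.01)  # stand-in for a training step
    if batch_idx % REFRESH_RATE == 0 or batch_idx == TOTAL:
        # New behaviour: jump straight to the true count instead of advancing
        # by REFRESH_RATE, which would drift after resuming mid-epoch.
        bar.n = batch_idx
        bar.refresh()

bar.close()
```

Setting `n` directly also removes the need for the old modular-arithmetic special case that topped the bar up on the final batch.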