Skip to content

Commit f77e675

Browse files
committed
fix: handle null progress bars and improve progress tracking in tqdm and rich progress bars
1 parent f7ec950 commit f77e675

File tree

3 files changed

+21
-21
lines changed

3 files changed

+21
-21
lines changed

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def render(self, task: "Task") -> Text:
171171
return Text()
172172
if self._trainer.training and task.id not in self._tasks:
173173
self._tasks[task.id] = "None"
174-
if self._renderable_cache:
174+
if self._renderable_cache and self._current_task_id in self._renderable_cache:
175175
self._current_task_id = cast(TaskID, self._current_task_id)
176176
self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1]
177177
self._current_task_id = task.id
@@ -185,7 +185,10 @@ def render(self, task: "Task") -> Text:
185185
def _generate_metrics_texts(self) -> Generator[str, None, None]:
186186
for name, value in self._metrics.items():
187187
if not isinstance(value, str):
188-
value = f"{value:{self._metrics_format}}"
188+
try:
189+
value = f"{value:{self._metrics_format}}"
190+
except (TypeError, ValueError):
191+
value = str(value)
189192
yield f"{name}: {value}"
190193

191194

@@ -465,17 +468,12 @@ def _initialize_train_progress_bar_id(self) -> None:
465468
self.train_progress_bar_id = self._add_task(total_batches, train_description)
466469

467470
def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
468-
if self.progress is not None and self.is_enabled:
469-
assert progress_bar_id is not None
471+
if self.progress is not None and self.is_enabled and progress_bar_id is not None:
470472
total = self.progress.tasks[progress_bar_id].total
471473
assert total is not None
472474
if not self._should_update(current, total):
473475
return
474-
475-
leftover = current % self.refresh_rate
476-
advance = leftover if (current == total and leftover != 0) else self.refresh_rate
477-
self.progress.update(progress_bar_id, advance=advance, visible=visible)
478-
self.refresh()
476+
self.progress.update(progress_bar_id, completed=current, visible=visible)
479477

480478
def _should_update(self, current: int, total: Union[int, float]) -> bool:
481479
return current % self.refresh_rate == 0 or current == total
@@ -572,9 +570,13 @@ def on_validation_batch_end(
572570
if self.is_disabled:
573571
return
574572
if trainer.sanity_checking:
575-
self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
576-
elif self.val_progress_bar_id is not None:
577-
self._update(self.val_progress_bar_id, batch_idx + 1)
573+
if self.val_sanity_progress_bar_id is not None:
574+
self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
575+
return
576+
577+
if self.val_progress_bar_id is None:
578+
return
579+
self._update(self.val_progress_bar_id, batch_idx + 1)
578580
self.refresh()
579581

580582
@override
@@ -587,9 +589,8 @@ def on_test_batch_end(
587589
batch_idx: int,
588590
dataloader_idx: int = 0,
589591
) -> None:
590-
if self.is_disabled:
592+
if self.is_disabled or self.test_progress_bar_id is None:
591593
return
592-
assert self.test_progress_bar_id is not None
593594
self._update(self.test_progress_bar_id, batch_idx + 1)
594595
self.refresh()
595596

@@ -603,9 +604,8 @@ def on_predict_batch_end(
603604
batch_idx: int,
604605
dataloader_idx: int = 0,
605606
) -> None:
606-
if self.is_disabled:
607+
if self.is_disabled or self.predict_progress_bar_id is None:
607608
return
608-
assert self.predict_progress_bar_id is not None
609609
self._update(self.predict_progress_bar_id, batch_idx + 1)
610610
self.refresh()
611611

src/lightning/pytorch/callbacks/progress/tqdm_progress.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def on_train_batch_end(
274274
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
275275
) -> None:
276276
n = batch_idx + 1
277-
if self._should_update(n, self.train_progress_bar.total):
277+
if self.train_progress_bar is not None and self._should_update(n, self.train_progress_bar.total):
278278
_update_n(self.train_progress_bar, n)
279279
self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
280280

@@ -322,7 +322,7 @@ def on_validation_batch_end(
322322
dataloader_idx: int = 0,
323323
) -> None:
324324
n = batch_idx + 1
325-
if self._should_update(n, self.val_progress_bar.total):
325+
if self.val_progress_bar is not None and self._should_update(n, self.val_progress_bar.total):
326326
_update_n(self.val_progress_bar, n)
327327

328328
@override
@@ -363,7 +363,7 @@ def on_test_batch_end(
363363
dataloader_idx: int = 0,
364364
) -> None:
365365
n = batch_idx + 1
366-
if self._should_update(n, self.test_progress_bar.total):
366+
if self.test_progress_bar is not None and self._should_update(n, self.test_progress_bar.total):
367367
_update_n(self.test_progress_bar, n)
368368

369369
@override
@@ -402,7 +402,7 @@ def on_predict_batch_end(
402402
dataloader_idx: int = 0,
403403
) -> None:
404404
n = batch_idx + 1
405-
if self._should_update(n, self.predict_progress_bar.total):
405+
if self.predict_progress_bar is not None and self._should_update(n, self.predict_progress_bar.total):
406406
_update_n(self.predict_progress_bar, n)
407407

408408
@override

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ def get_metrics(self, trainer, pl_module):
707707
del items["v_num"]
708708
# this is equivalent to mocking `set_postfix` as this method gets called every time
709709
self.calls[trainer.state.fn].append((
710-
trainer.state.stage,
710+
trainer.state.stage.value,
711711
trainer.current_epoch,
712712
trainer.global_step,
713713
items,

0 commit comments

Comments
 (0)