Skip to content

Commit 04b8870

Browse files
committed
fix: handle null progress bars and improve progress tracking in tqdm and rich progress bars
1 parent b6c1772 commit 04b8870

File tree

3 files changed

+21
-21
lines changed

3 files changed

+21
-21
lines changed

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def render(self, task: "Task") -> Text:
171171
return Text()
172172
if self._trainer.training and task.id not in self._tasks:
173173
self._tasks[task.id] = "None"
174-
if self._renderable_cache:
174+
if self._renderable_cache and self._current_task_id in self._renderable_cache:
175175
self._current_task_id = cast(TaskID, self._current_task_id)
176176
self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1]
177177
self._current_task_id = task.id
@@ -185,7 +185,10 @@ def render(self, task: "Task") -> Text:
185185
def _generate_metrics_texts(self) -> Generator[str, None, None]:
186186
for name, value in self._metrics.items():
187187
if not isinstance(value, str):
188-
value = f"{value:{self._metrics_format}}"
188+
try:
189+
value = f"{value:{self._metrics_format}}"
190+
except (TypeError, ValueError):
191+
value = str(value)
189192
yield f"{name}: {value}"
190193

191194

@@ -448,17 +451,12 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible:
448451
)
449452

450453
def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
451-
if self.progress is not None and self.is_enabled:
452-
assert progress_bar_id is not None
454+
if self.progress is not None and self.is_enabled and progress_bar_id is not None:
453455
total = self.progress.tasks[progress_bar_id].total
454456
assert total is not None
455457
if not self._should_update(current, total):
456458
return
457-
458-
leftover = current % self.refresh_rate
459-
advance = leftover if (current == total and leftover != 0) else self.refresh_rate
460-
self.progress.update(progress_bar_id, advance=advance, visible=visible)
461-
self.refresh()
459+
self.progress.update(progress_bar_id, completed=current, visible=visible)
462460

463461
def _should_update(self, current: int, total: Union[int, float]) -> bool:
464462
return current % self.refresh_rate == 0 or current == total
@@ -552,9 +550,13 @@ def on_validation_batch_end(
552550
if self.is_disabled:
553551
return
554552
if trainer.sanity_checking:
555-
self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
556-
elif self.val_progress_bar_id is not None:
557-
self._update(self.val_progress_bar_id, batch_idx + 1)
553+
if self.val_sanity_progress_bar_id is not None:
554+
self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
555+
return
556+
557+
if self.val_progress_bar_id is None:
558+
return
559+
self._update(self.val_progress_bar_id, batch_idx + 1)
558560
self.refresh()
559561

560562
@override
@@ -567,9 +569,8 @@ def on_test_batch_end(
567569
batch_idx: int,
568570
dataloader_idx: int = 0,
569571
) -> None:
570-
if self.is_disabled:
572+
if self.is_disabled or self.test_progress_bar_id is None:
571573
return
572-
assert self.test_progress_bar_id is not None
573574
self._update(self.test_progress_bar_id, batch_idx + 1)
574575
self.refresh()
575576

@@ -583,9 +584,8 @@ def on_predict_batch_end(
583584
batch_idx: int,
584585
dataloader_idx: int = 0,
585586
) -> None:
586-
if self.is_disabled:
587+
if self.is_disabled or self.predict_progress_bar_id is None:
587588
return
588-
assert self.predict_progress_bar_id is not None
589589
self._update(self.predict_progress_bar_id, batch_idx + 1)
590590
self.refresh()
591591

src/lightning/pytorch/callbacks/progress/tqdm_progress.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def on_train_batch_end(
274274
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
275275
) -> None:
276276
n = batch_idx + 1
277-
if self._should_update(n, self.train_progress_bar.total):
277+
if self.train_progress_bar is not None and self._should_update(n, self.train_progress_bar.total):
278278
_update_n(self.train_progress_bar, n)
279279
self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
280280

@@ -322,7 +322,7 @@ def on_validation_batch_end(
322322
dataloader_idx: int = 0,
323323
) -> None:
324324
n = batch_idx + 1
325-
if self._should_update(n, self.val_progress_bar.total):
325+
if self.val_progress_bar is not None and self._should_update(n, self.val_progress_bar.total):
326326
_update_n(self.val_progress_bar, n)
327327

328328
@override
@@ -363,7 +363,7 @@ def on_test_batch_end(
363363
dataloader_idx: int = 0,
364364
) -> None:
365365
n = batch_idx + 1
366-
if self._should_update(n, self.test_progress_bar.total):
366+
if self.test_progress_bar is not None and self._should_update(n, self.test_progress_bar.total):
367367
_update_n(self.test_progress_bar, n)
368368

369369
@override
@@ -402,7 +402,7 @@ def on_predict_batch_end(
402402
dataloader_idx: int = 0,
403403
) -> None:
404404
n = batch_idx + 1
405-
if self._should_update(n, self.predict_progress_bar.total):
405+
if self.predict_progress_bar is not None and self._should_update(n, self.predict_progress_bar.total):
406406
_update_n(self.predict_progress_bar, n)
407407

408408
@override

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ def get_metrics(self, trainer, pl_module):
707707
del items["v_num"]
708708
# this is equivalent to mocking `set_postfix` as this method gets called every time
709709
self.calls[trainer.state.fn].append((
710-
trainer.state.stage,
710+
trainer.state.stage.value,
711711
trainer.current_epoch,
712712
trainer.global_step,
713713
items,

0 commit comments

Comments
 (0)