Skip to content

Commit ab9fc78

Browse files
authored
Merge branch 'master' into fix/ci-rich-progress-bar
2 parents be89680 + 6a5c946 commit ab9fc78

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,11 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible:
459459
visible=visible,
460460
)
461461

462+
def _initialize_train_progress_bar_id(self) -> None:
463+
total_batches = self.total_train_batches
464+
train_description = self._get_train_description(self.trainer.current_epoch)
465+
self.train_progress_bar_id = self._add_task(total_batches, train_description)
466+
462467
def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
463468
if self.progress is not None and self.is_enabled:
464469
assert progress_bar_id is not None
@@ -543,6 +548,9 @@ def on_train_batch_end(
543548
batch: Any,
544549
batch_idx: int,
545550
) -> None:
551+
if not self.is_disabled and self.train_progress_bar_id is None:
552+
# can happen when resuming from a mid-epoch restart
553+
self._initialize_train_progress_bar_id()
546554
self._update(self.train_progress_bar_id, batch_idx + 1)
547555
self._update_metrics(trainer, pl_module)
548556
self.refresh()

0 commit comments

Comments
 (0)