diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py
index 7bb98e8a9058c..6aec230316d43 100644
--- a/src/lightning/pytorch/callbacks/progress/rich_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -447,6 +447,11 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible:
             visible=visible,
         )
 
+    def _initialize_train_progress_bar_id(self) -> None:
+        total_batches = self.total_train_batches
+        train_description = self._get_train_description(self.trainer.current_epoch)
+        self.train_progress_bar_id = self._add_task(total_batches, train_description)
+
     def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
         if self.progress is not None and self.is_enabled:
             assert progress_bar_id is not None
@@ -531,6 +536,9 @@ def on_train_batch_end(
         batch: Any,
         batch_idx: int,
     ) -> None:
+        if not self.is_disabled and self.train_progress_bar_id is None:
+            # can happen when resuming from a mid-epoch restart
+            self._initialize_train_progress_bar_id()
         self._update(self.train_progress_bar_id, batch_idx + 1)
         self._update_metrics(trainer, pl_module)
         self.refresh()