Skip to content
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,11 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible:
visible=visible,
)

def _initialize_progress_bar_id(self) -> None:
total_batches = self.total_train_batches
train_description = self._get_train_description(self.trainer.current_epoch)
self.train_progress_bar_id = self._add_task(total_batches, train_description)

def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
if self.progress is not None and self.is_enabled:
assert progress_bar_id is not None
Expand Down Expand Up @@ -531,6 +536,9 @@ def on_train_batch_end(
batch: Any,
batch_idx: int,
) -> None:
if self.train_progress_bar_id is None and not self.is_disabled:
# can happen when resuming from a mid-epoch restart
self._initialize_progress_bar_id()
self._update(self.train_progress_bar_id, batch_idx + 1)
self._update_metrics(trainer, pl_module)
self.refresh()
Expand Down
Loading