Skip to content

Commit 1dcfbbf

Browse files
committed
update
1 parent 242fb06 commit 1dcfbbf

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,13 +357,6 @@ def refresh(self) -> None:
357357
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
358358
self._init_progress(trainer)
359359

360-
# Initialize the training progress bar here because
361-
# `on_train_epoch_start` is not called when resuming from a mid-epoch restart
362-
total_batches = self.total_train_batches
363-
train_description = self._get_train_description(trainer.current_epoch)
364-
assert self.progress is not None
365-
self.train_progress_bar_id = self._add_task(total_batches, train_description)
366-
367360
@override
368361
def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
369362
self._init_progress(trainer)
@@ -454,6 +447,14 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible:
454447
visible=visible,
455448
)
456449

450+
def _initialize_progress_bar_id(self) -> None:
451+
# Initialize the training progress bar here because
452+
# `on_train_epoch_start` is not called when resuming from a mid-epoch restart
453+
total_batches = self.total_train_batches
454+
train_description = self._get_train_description(self.trainer.current_epoch)
455+
assert self.progress is not None
456+
self.train_progress_bar_id = self._add_task(total_batches, train_description)
457+
457458
def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
458459
if self.progress is not None and self.is_enabled:
459460
assert progress_bar_id is not None
@@ -538,6 +539,9 @@ def on_train_batch_end(
538539
batch: Any,
539540
batch_idx: int,
540541
) -> None:
542+
if self.train_progress_bar_id is None:
543+
# can happen when resuming from a mid-epoch restart
544+
self._initialize_progress_bar_id()
541545
self._update(self.train_progress_bar_id, batch_idx + 1)
542546
self._update_metrics(trainer, pl_module)
543547
self.refresh()

0 commit comments

Comments
 (0)