@@ -357,13 +357,6 @@ def refresh(self) -> None:
357357 def on_train_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
358358 self ._init_progress (trainer )
359359
360- # Initialize the training progress bar here because
361- # `on_train_epoch_start` is not called when resuming from a mid-epoch restart
362- total_batches = self .total_train_batches
363- train_description = self ._get_train_description (trainer .current_epoch )
364- assert self .progress is not None
365- self .train_progress_bar_id = self ._add_task (total_batches , train_description )
366-
367360 @override
368361 def on_predict_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
369362 self ._init_progress (trainer )
@@ -454,6 +447,14 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible:
454447 visible = visible ,
455448 )
456449
450+ def _initialize_progress_bar_id (self ) -> None :
451+ # Initialize the training progress bar here because
452+ # `on_train_epoch_start` is not called when resuming from a mid-epoch restart
453+ total_batches = self .total_train_batches
454+ train_description = self ._get_train_description (self .trainer .current_epoch )
455+ assert self .progress is not None
456+ self .train_progress_bar_id = self ._add_task (total_batches , train_description )
457+
457458 def _update (self , progress_bar_id : Optional ["TaskID" ], current : int , visible : bool = True ) -> None :
458459 if self .progress is not None and self .is_enabled :
459460 assert progress_bar_id is not None
@@ -538,6 +539,9 @@ def on_train_batch_end(
538539 batch : Any ,
539540 batch_idx : int ,
540541 ) -> None :
542+ if self .train_progress_bar_id is None :
543+ # can happen when resuming from a mid-epoch restart
544+ self ._initialize_progress_bar_id ()
541545 self ._update (self .train_progress_bar_id , batch_idx + 1 )
542546 self ._update_metrics (trainer , pl_module )
543547 self .refresh ()
0 commit comments