@@ -498,6 +498,8 @@ def on_validation_batch_start(
498498 visible = False ,
499499 )
500500
501+ self .refresh ()
502+
501503 def _add_task (self , total_batches : Union [int , float ], description : str , visible : bool = True ) -> "TaskID" :
502504 assert self .progress is not None
503505 return self .progress .add_task (
@@ -512,22 +514,27 @@ def _initialize_train_progress_bar_id(self) -> None:
512514 self .train_progress_bar_id = self ._add_task (total_batches , train_description )
513515
514516 def _update (
515- self , progress_bar_id : Optional ["TaskID" ], current : int , visible : bool = True , refresh : bool = True
517+ self ,
518+ progress_bar_id : Optional ["TaskID" ],
519+ current : int ,
520+ visible : bool = True ,
521+ hard : bool = False ,
516522 ) -> None :
517523 if self .progress is not None and self .is_enabled and progress_bar_id is not None :
518- self .progress .update (progress_bar_id , completed = current , visible = visible , refresh = refresh )
524+ self .progress .update (progress_bar_id , completed = current , visible = visible )
525+ self .refresh (hard = hard )
519526
520527 @override
521528 def on_validation_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
522529 if self .is_enabled and self .val_progress_bar_id is not None and trainer .state .fn == "fit" :
523530 assert self .progress is not None
524- self .progress .update (self .val_progress_bar_id , advance = 0 , visible = False , refresh = True )
531+ self .progress .update (self .val_progress_bar_id , advance = 0 , visible = False )
532+ self .refresh ()
525533
526534 @override
527535 def on_validation_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
528536 if trainer .state .fn == "fit" :
529537 self ._update_metrics (trainer , pl_module )
530- self .refresh ()
531538 self .reset_dataloader_idx_tracker ()
532539
533540 @override
@@ -554,6 +561,7 @@ def on_test_batch_start(
554561 assert self .progress is not None
555562 self .progress .update (self .test_progress_bar_id , advance = 0 , visible = False )
556563 self .test_progress_bar_id = self ._add_task (self .total_test_batches_current_dataloader , self .test_description )
564+ self .refresh ()
557565
558566 @override
559567 def on_predict_batch_start (
@@ -573,6 +581,7 @@ def on_predict_batch_start(
573581 self .predict_progress_bar_id = self ._add_task (
574582 self .total_predict_batches_current_dataloader , self .predict_description
575583 )
584+ self .refresh ()
576585
577586 @override
578587 def on_train_batch_end (
@@ -586,7 +595,7 @@ def on_train_batch_end(
586595 if not self .is_disabled and self .train_progress_bar_id is None :
587596 # can happen when resuming from a mid-epoch restart
588597 self ._initialize_train_progress_bar_id ()
589- self ._update (self .train_progress_bar_id , batch_idx + 1 , refresh = False )
598+ self ._update (self .train_progress_bar_id , batch_idx + 1 )
590599 self ._update_metrics (trainer , pl_module )
591600 self .refresh ()
592601
@@ -609,12 +618,12 @@ def on_validation_batch_end(
609618 return
610619 if trainer .sanity_checking :
611620 if self .val_sanity_progress_bar_id is not None :
612- self ._update (self .val_sanity_progress_bar_id , batch_idx + 1 , refresh = True )
621+ self ._update (self .val_sanity_progress_bar_id , batch_idx + 1 )
613622 return
614623
615624 if self .val_progress_bar_id is None :
616625 return
617- self ._update (self .val_progress_bar_id , batch_idx + 1 , refresh = True )
626+ self ._update (self .val_progress_bar_id , batch_idx + 1 )
618627
619628 @override
620629 def on_test_batch_end (
@@ -628,7 +637,7 @@ def on_test_batch_end(
628637 ) -> None :
629638 if self .is_disabled or self .test_progress_bar_id is None :
630639 return
631- self ._update (self .test_progress_bar_id , batch_idx + 1 , refresh = True )
640+ self ._update (self .test_progress_bar_id , batch_idx + 1 )
632641
633642 @override
634643 def on_predict_batch_end (
@@ -642,7 +651,7 @@ def on_predict_batch_end(
642651 ) -> None :
643652 if self .is_disabled or self .predict_progress_bar_id is None :
644653 return
645- self ._update (self .predict_progress_bar_id , batch_idx + 1 , refresh = True )
654+ self ._update (self .predict_progress_bar_id , batch_idx + 1 )
646655
647656 def _get_train_description (self , current_epoch : int ) -> str :
648657 train_description = f"Epoch { current_epoch } "
0 commit comments