@@ -552,12 +552,12 @@ def on_train_batch_end(
552552 # can happen when resuming from a mid-epoch restart
553553 self ._initialize_train_progress_bar_id ()
554554 self ._update (self .train_progress_bar_id , batch_idx + 1 )
555- self ._update_metrics (trainer , pl_module )
555+ self ._update_metrics (trainer , pl_module , batch_idx + 1 )
556556 self .refresh ()
557557
558558 @override
559559 def on_train_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
560- self ._update_metrics (trainer , pl_module )
560+ self ._update_metrics (trainer , pl_module , total_batches = True )
561561
562562 @override
563563 def on_validation_batch_end (
@@ -632,7 +632,21 @@ def _reset_progress_bar_ids(self) -> None:
632632 self .test_progress_bar_id = None
633633 self .predict_progress_bar_id = None
634634
635- def _update_metrics (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
635+ def _update_metrics (
636+ self ,
637+ trainer : "pl.Trainer" ,
638+ pl_module : "pl.LightningModule" ,
639+ current : Optional [int ] = None ,
640+ total_batches : bool = False ,
641+ ) -> None :
642+ if not self .is_enabled or self ._metric_component is None :
643+ return
644+
645+ if current is not None and not total_batches :
646+ total = self .total_train_batches
647+ if not self ._should_update (current , total ):
648+ return
649+
636650 metrics = self .get_metrics (trainer , pl_module )
637651 if self ._metric_component :
638652 self ._metric_component .update (metrics )
0 commit comments