Skip to content

Commit ccf63c3

Browse files
committed
fix refresh rate in update metrics method
1 parent 25b1343 commit ccf63c3

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -552,12 +552,12 @@ def on_train_batch_end(
552552
# can happen when resuming from a mid-epoch restart
553553
self._initialize_train_progress_bar_id()
554554
self._update(self.train_progress_bar_id, batch_idx + 1)
555-
self._update_metrics(trainer, pl_module)
555+
self._update_metrics(trainer, pl_module, batch_idx + 1)
556556
self.refresh()
557557

558558
@override
559559
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
560-
self._update_metrics(trainer, pl_module)
560+
self._update_metrics(trainer, pl_module, total_batches=True)
561561

562562
@override
563563
def on_validation_batch_end(
@@ -632,7 +632,21 @@ def _reset_progress_bar_ids(self) -> None:
632632
self.test_progress_bar_id = None
633633
self.predict_progress_bar_id = None
634634

635-
def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
635+
def _update_metrics(
636+
self,
637+
trainer: "pl.Trainer",
638+
pl_module: "pl.LightningModule",
639+
current: Optional[int] = None,
640+
total_batches: bool = False,
641+
) -> None:
642+
if not self.is_enabled or self._metric_component is None:
643+
return
644+
645+
if current is not None and not total_batches:
646+
total = self.total_train_batches
647+
if not self._should_update(current, total):
648+
return
649+
636650
metrics = self.get_metrics(trainer, pl_module)
637651
if self._metric_component:
638652
self._metric_component.update(metrics)

0 commit comments

Comments
 (0)