Skip to content

Commit 5a095d2

Browse files
authored
Fix refresh rate for metrics in RichProgressBar (#21032)
* fix refresh rate in update metrics method * add testing * changelog
1 parent 791753b commit 5a095d2

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3030
- fix `AdvancedProfiler` to handle nested profiling actions for Python 3.12+ ([#20809](https://github.com/Lightning-AI/pytorch-lightning/pull/20809))
3131

3232

33+
- Fixed metrics in `RichProgressBar` being updated according to user provided `refresh_rate` ([#21032](https://github.com/Lightning-AI/pytorch-lightning/pull/21032))
34+
3335
---
3436

3537
## [2.5.2] - 2025-06-20

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -552,12 +552,12 @@ def on_train_batch_end(
552552
# can happen when resuming from a mid-epoch restart
553553
self._initialize_train_progress_bar_id()
554554
self._update(self.train_progress_bar_id, batch_idx + 1)
555-
self._update_metrics(trainer, pl_module)
555+
self._update_metrics(trainer, pl_module, batch_idx + 1)
556556
self.refresh()
557557

558558
@override
559559
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
560-
self._update_metrics(trainer, pl_module)
560+
self._update_metrics(trainer, pl_module, total_batches=True)
561561

562562
@override
563563
def on_validation_batch_end(
@@ -632,7 +632,21 @@ def _reset_progress_bar_ids(self) -> None:
632632
self.test_progress_bar_id = None
633633
self.predict_progress_bar_id = None
634634

635-
def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
635+
def _update_metrics(
636+
self,
637+
trainer: "pl.Trainer",
638+
pl_module: "pl.LightningModule",
639+
current: Optional[int] = None,
640+
total_batches: bool = False,
641+
) -> None:
642+
if not self.is_enabled or self._metric_component is None:
643+
return
644+
645+
if current is not None and not total_batches:
646+
total = self.total_train_batches
647+
if not self._should_update(current, total):
648+
return
649+
636650
metrics = self.get_metrics(trainer, pl_module)
637651
if self._metric_component:
638652
self._metric_component.update(metrics)

tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,9 @@ def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batch
246246
with mock.patch.object(
247247
trainer.progress_bar_callback.progress, "update", wraps=trainer.progress_bar_callback.progress.update
248248
) as progress_update:
249+
metrics_update = mock.MagicMock()
250+
trainer.progress_bar_callback._update_metrics = metrics_update
251+
249252
trainer.fit(model)
250253
assert progress_update.call_count == expected_call_count
251254

@@ -260,6 +263,9 @@ def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batch
260263
assert fit_val_bar.total == val_batches
261264
assert not fit_val_bar.visible
262265

266+
# one call for each train batch + one at the end of training epoch + one for validation end
267+
assert metrics_update.call_count == train_batches + (1 if train_batches > 0 else 0) + (1 if val_batches > 0 else 0)
268+
263269

264270
@RunIf(rich=True)
265271
@pytest.mark.parametrize("limit_val_batches", [1, 5])

0 commit comments

Comments
 (0)