diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 96f85265cc6ac..10d90d68fcd45 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -39,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed misalignment column while using rich model summary in `DeepSpeedstrategy` ([#21100](https://github.com/Lightning-AI/pytorch-lightning/pull/21100)) + +- Fixed `RichProgressBar` crashing when sanity checking using val dataloader with 0 len ([#21108](https://github.com/Lightning-AI/pytorch-lightning/pull/21108)) + --- ## [2.5.3] - 2025-08-13 diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index c435c810d94b5..d4c3c916c7ed0 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -390,8 +390,7 @@ def on_sanity_check_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningM @override def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.progress is not None: - assert self.val_sanity_progress_bar_id is not None + if self.progress is not None and self.val_sanity_progress_bar_id is not None: self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False) self.refresh() diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 639414a797aa0..9d74871ce84e4 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -577,3 +577,31 @@ def test_rich_progress_bar_metrics_theme_update(*_): theme = RichProgressBar(theme=RichProgressBarTheme(metrics_format=".3e", metrics_text_delimiter="\n")).theme assert theme.metrics_format == ".3e" assert theme.metrics_text_delimiter == "\n" + + +@RunIf(rich=True) +def test_rich_progress_bar_empty_val_dataloader_model(tmp_path): + """Test that RichProgressBar doesn't crash with empty val_dataloader list from model.""" + + class EmptyListModel(BoringModel): + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=2) + + def val_dataloader(self): + return [] + + model = EmptyListModel() + progress_bar = RichProgressBar() + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=1, + num_sanity_val_steps=1, + callbacks=[progress_bar], + limit_train_batches=2, + enable_checkpointing=False, + logger=False, + ) + + # This should not raise an AssertionError + trainer.fit(model)