diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index 0f11b19c23431..ab3a36f7dad3b 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -6,6 +6,6 @@ matplotlib>3.1, <3.10.0 omegaconf >=2.2.3, <2.4.0 hydra-core >=1.2.0, <1.4.0 jsonargparse[signatures,jsonnet] >=4.39.0, <4.41.0 -rich >=12.3.0, <14.1.0 +rich >=12.3.0, <14.2.0 tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin" diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 5b364ac1c7a3e..81c7bfc656885 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- fix progress bar console clearing for Rich `14.1+` ([#21016](https://github.com/Lightning-AI/pytorch-lightning/pull/21016)) --- diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 6aec230316d43..ff092fa99d825 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -331,7 +331,19 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: self._reset_progress_bar_ids() reconfigure(**self._console_kwargs) self._console = get_console() - self._console.clear_live() + + # Compatibility shim for Rich >= 14.1.0: + if hasattr(self._console, "_live_stack"): + # In recent Rich releases, the internal `_live` variable was replaced with `_live_stack` (a list) + # to support nested Live displays. This broke our original call to `clear_live()`, + # because it now only pops one Live instance instead of clearing them all. + # We check for `_live_stack` and clear it manually for compatibility across + # both old and new Rich versions. + if len(self._console._live_stack) > 0: + self._console.clear_live() + else: + self._console.clear_live() + self._metric_component = MetricsTextColumn( trainer, self.theme.metrics,