diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index d4c3c916c7ed0..15cec555a93d1 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import time from collections.abc import Generator from dataclasses import dataclass from datetime import timedelta @@ -22,6 +23,7 @@ from typing_extensions import override import lightning.pytorch as pl +from lightning.fabric.utilities.imports import _IS_INTERACTIVE from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar from lightning.pytorch.utilities.imports import _RICH_AVAILABLE from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -29,6 +31,7 @@ if _RICH_AVAILABLE: from rich import get_console, reconfigure from rich.console import Console, RenderableType + from rich.live import _RefreshThread as _RichRefreshThread from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn from rich.progress_bar import ProgressBar as _RichProgressBar from rich.style import Style @@ -66,9 +69,49 @@ class CustomInfiniteTask(Task): def time_remaining(self) -> Optional[float]: return None + class _RefreshThread(_RichRefreshThread): + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.refresh_cond = False + super().__init__(*args, **kwargs) + + def run(self) -> None: + while not self.done.is_set(): + if self.refresh_cond: + with self.live._lock: + self.live.refresh() + self.refresh_cond = False + time.sleep(0.005) + class CustomProgress(Progress): """Overrides ``Progress`` to support adding tasks that have an infinite total size.""" + def start(self) -> None: + """Starts the progress display. + + Notes + ----- + This override is needed to support the custom refresh thread. + + """ + if self.live.auto_refresh: + self.live._refresh_thread = _RefreshThread(self.live, self.live.refresh_per_second) + self.live.auto_refresh = False + super().start() + if self.live._refresh_thread: + self.live.auto_refresh = True + self.live._refresh_thread.start() + + def stop(self) -> None: + refresh_thread = self.live._refresh_thread + super().stop() + if refresh_thread: + refresh_thread.stop() + refresh_thread.join() + + def soft_refresh(self) -> None: + if self.live.auto_refresh and isinstance(self.live._refresh_thread, _RefreshThread): + self.live._refresh_thread.refresh_cond = True + def add_task( self, description: str, @@ -356,7 +399,7 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: self.progress = CustomProgress( *self.configure_columns(trainer), self._metric_component, - auto_refresh=False, + auto_refresh=True, disable=self.is_disabled, console=self._console, ) @@ -364,9 +407,12 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: # progress has started self._progress_stopped = False - def refresh(self) -> None: + def refresh(self, hard: bool = False) -> None: if self.progress: - self.progress.refresh() + if hard or _IS_INTERACTIVE: + self.progress.refresh() + else: + self.progress.soft_refresh() @override def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 9d74871ce84e4..567552459cc28 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -131,6 +131,9 @@ def test_rich_progress_bar_custom_theme(): _, kwargs = mocks["ProcessingSpeedColumn"].call_args assert kwargs["style"] == theme.processing_speed + progress_bar.progress.live._refresh_thread.stop() + progress_bar.progress.live._refresh_thread.join() + @RunIf(rich=True) def test_rich_progress_bar_keyboard_interrupt(tmp_path): @@ -176,6 +179,8 @@ def configure_columns(self, trainer): assert progress_bar.progress.columns[0] == custom_column assert len(progress_bar.progress.columns) == 2 + progress_bar.progress.stop() + @RunIf(rich=True) @pytest.mark.parametrize(("leave", "reset_call_count"), ([(True, 0), (False, 3)])) @@ -345,7 +350,8 @@ def training_step(self, *args, **kwargs): for key in ("loss", "v_num", "train_loss"): assert key in rendered[train_progress_bar_id][1] - assert key not in rendered[val_progress_bar_id][1] + if val_progress_bar_id in rendered: + assert key not in rendered[val_progress_bar_id][1] def test_rich_progress_bar_metrics_fast_dev_run(tmp_path): @@ -359,7 +365,8 @@ def test_rich_progress_bar_metrics_fast_dev_run(tmp_path): val_progress_bar_id = progress_bar.val_progress_bar_id rendered = progress_bar.progress.columns[-1]._renderable_cache assert "v_num" not in rendered[train_progress_bar_id][1] - assert "v_num" not in rendered[val_progress_bar_id][1] + if val_progress_bar_id in rendered: + assert "v_num" not in rendered[val_progress_bar_id][1] @RunIf(rich=True)