From 0e9a179a9ad28f5a70a903c59538692b8cf3753c Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 29 Oct 2025 00:19:44 +0800 Subject: [PATCH 1/9] init fix --- .../callbacks/progress/rich_progress.py | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index d4c3c916c7ed0..16efa8d42ac63 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import time from collections.abc import Generator from dataclasses import dataclass from datetime import timedelta +from threading import Event, Thread from typing import Any, Optional, Union, cast import torch @@ -22,6 +24,7 @@ from typing_extensions import override import lightning.pytorch as pl +from lightning.fabric.utilities.imports import _IS_INTERACTIVE from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar from lightning.pytorch.utilities.imports import _RICH_AVAILABLE from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -29,6 +32,7 @@ if _RICH_AVAILABLE: from rich import get_console, reconfigure from rich.console import Console, RenderableType + from rich.live import Live from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn from rich.progress_bar import ProgressBar as _RichProgressBar from rich.style import Style @@ -66,9 +70,46 @@ class CustomInfiniteTask(Task): def time_remaining(self) -> Optional[float]: return None + class _RefreshThread(Thread): + def __init__( + self, + live: Live, + ) -> None: + self.live = live + self.refresh_cond = False + self.done = Event() + super().__init__(daemon=True) + + def run(self) -> None: + while not self.done.is_set(): + if self.refresh_cond: + with self.live._lock: + self.live.refresh() + self.refresh_cond = False + time.sleep(0.001) + + def stop(self) -> None: + self.done.set() + class CustomProgress(Progress): """Overrides ``Progress`` to support adding tasks that have an infinite total size.""" + def start(self) -> None: + if self.live.auto_refresh: + self.live._refresh_thread = _RefreshThread(self.live) + self.live.auto_refresh = False + super().start() + if self.live._refresh_thread: + self.live.auto_refresh = True + self.live._refresh_thread.start() + + def refresh(self) -> None: + if self.live.auto_refresh: + self.live._refresh_thread.refresh_cond = True + if _IS_INTERACTIVE: + return super().refresh() + return None + def add_task( self, description: str, @@ -356,7 +397,7 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: self.progress = CustomProgress( *self.configure_columns(trainer), self._metric_component, - auto_refresh=False, + auto_refresh=True, disable=self.is_disabled, console=self._console, ) From c83a8f491876c9be4c65a45680174b7a674f6ccb Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 29 Oct 2025 19:11:04 +0800 Subject: [PATCH 2/9] temp fix unittests --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 8 ++++++++ .../callbacks/progress/test_rich_progress_bar.py | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 16efa8d42ac63..8ec9be8742f92 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -103,6 +103,14 @@ def start(self) -> None: self.live.auto_refresh = True self.live._refresh_thread.start() + def stop(self) -> None: + refresh_thread = self.live._refresh_thread + self.live.auto_refresh = refresh_thread is not None + super().stop() + if refresh_thread: + refresh_thread.stop() + refresh_thread.join() + def refresh(self) -> None: if self.live.auto_refresh: self.live._refresh_thread.refresh_cond = True diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 9d74871ce84e4..7291daf5df53b 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -131,6 +131,8 @@ def test_rich_progress_bar_custom_theme(): _, kwargs = mocks["ProcessingSpeedColumn"].call_args assert kwargs["style"] == theme.processing_speed + progress_bar.progress.live._refresh_thread.stop() + @RunIf(rich=True) def test_rich_progress_bar_keyboard_interrupt(tmp_path): @@ -176,6 +178,8 @@ def configure_columns(self, trainer): assert progress_bar.progress.columns[0] == custom_column assert len(progress_bar.progress.columns) == 2 + progress_bar.progress.stop() + @RunIf(rich=True) @pytest.mark.parametrize(("leave", "reset_call_count"), ([(True, 0), (False, 3)])) From 24c1729ea96fae388ad0f8352858c26d26eb60a4 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 29 Oct 2025 20:55:58 +0800 Subject: [PATCH 3/9] release time sleep --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 8ec9be8742f92..35b581cc26893 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -86,7 +86,7 @@ def run(self) -> None: with self.live._lock: self.live.refresh() self.refresh_cond = False - time.sleep(0.001) + time.sleep(0.005) def stop(self) -> None: self.done.set() From 47c3bd9beda77631a970b7667f0587868def9289 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 29 Oct 2025 20:56:23 +0800 Subject: [PATCH 4/9] fix unittest --- tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 7291daf5df53b..a44d116d76d46 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -132,6 +132,7 @@ def test_rich_progress_bar_custom_theme(): assert kwargs["style"] == theme.processing_speed progress_bar.progress.live._refresh_thread.stop() + progress_bar.progress.live._refresh_thread.join() @RunIf(rich=True) From e3b3100bc7af834086d697a204e4edc0add64667 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 30 Oct 2025 05:48:19 +0800 Subject: [PATCH 5/9] ref soft_refresh --- .../pytorch/callbacks/progress/rich_progress.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 35b581cc26893..6d8d8f2d025b1 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -111,12 +111,9 @@ def stop(self) -> None: refresh_thread.stop() refresh_thread.join() - def refresh(self) -> None: + def soft_refresh(self) -> None: if self.live.auto_refresh: self.live._refresh_thread.refresh_cond = True - if _IS_INTERACTIVE: - return super().refresh() - return None def add_task( self, @@ -413,9 +410,12 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: # progress has started self._progress_stopped = False - def refresh(self) -> None: + def refresh(self, hard=False) -> None: if self.progress: - self.progress.refresh() + if hard or _IS_INTERACTIVE: + self.progress.refresh() + else: + self.progress.soft_refresh() @override def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: From f11e58fd9cf64b53fac3d878c51c8d884af924ff Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 31 Oct 2025 01:19:24 +0800 Subject: [PATCH 6/9] fix test --- .../callbacks/progress/test_rich_progress_bar.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index a44d116d76d46..567552459cc28 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -350,7 +350,8 @@ def training_step(self, *args, **kwargs): for key in ("loss", "v_num", "train_loss"): assert key in rendered[train_progress_bar_id][1] - assert key not in rendered[val_progress_bar_id][1] + if val_progress_bar_id in rendered: + assert key not in rendered[val_progress_bar_id][1] def test_rich_progress_bar_metrics_fast_dev_run(tmp_path): @@ -364,7 +365,8 @@ def test_rich_progress_bar_metrics_fast_dev_run(tmp_path): val_progress_bar_id = progress_bar.val_progress_bar_id rendered = progress_bar.progress.columns[-1]._renderable_cache assert "v_num" not in rendered[train_progress_bar_id][1] - assert "v_num" not in rendered[val_progress_bar_id][1] + if val_progress_bar_id in rendered: + assert "v_num" not in rendered[val_progress_bar_id][1] @RunIf(rich=True) From 86d823ade38f8c8db9a32f5c183c235a77e5f58a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 31 Oct 2025 01:30:57 +0800 Subject: [PATCH 7/9] refactor _RefreshThread --- .../callbacks/progress/rich_progress.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 6d8d8f2d025b1..85187b1ca1598 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -16,7 +16,6 @@ from collections.abc import Generator from dataclasses import dataclass from datetime import timedelta -from threading import Event, Thread from typing import Any, Optional, Union, cast import torch @@ -32,7 +31,7 @@ if _RICH_AVAILABLE: from rich import get_console, reconfigure from rich.console import Console, RenderableType - from rich.live import Live + from rich.live import _RefreshThread as _RichRefreshThread from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn from rich.progress_bar import ProgressBar as _RichProgressBar from rich.style import Style @@ -70,15 +69,10 @@ class CustomInfiniteTask(Task): def time_remaining(self) -> Optional[float]: return None - class _RefreshThread(Thread): - def __init__( - self, - live: Live, - ) -> None: - self.live = live + class _RefreshThread(_RichRefreshThread): + def __init__(self, *args, **kwargs) -> None: self.refresh_cond = False - self.done = Event() - super().__init__(daemon=True) + super().__init__(*args, **kwargs) def run(self) -> None: while not self.done.is_set(): @@ -88,15 +82,19 @@ def run(self) -> None: self.refresh_cond = False time.sleep(0.005) - def stop(self) -> None: - self.done.set() - class CustomProgress(Progress): """Overrides ``Progress`` to support adding tasks that have an infinite total size.""" def start(self) -> None: + """Starts the progress display. + + Notes + ----- + This override is needed to support the custom refresh thread. + + """ if self.live.auto_refresh: - self.live._refresh_thread = _RefreshThread(self.live) + self.live._refresh_thread = _RefreshThread(self.live, self.live.refresh_per_second) self.live.auto_refresh = False super().start() if self.live._refresh_thread: @@ -105,7 +103,6 @@ def start(self) -> None: def stop(self) -> None: refresh_thread = self.live._refresh_thread - self.live.auto_refresh = refresh_thread is not None super().stop() if refresh_thread: refresh_thread.stop() From 2f24cb7c1ade87d128b262e2b281054211b762d9 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 31 Oct 2025 01:40:29 +0800 Subject: [PATCH 8/9] add type annotation --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 85187b1ca1598..f849164327e64 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -70,7 +70,7 @@ def time_remaining(self) -> Optional[float]: return None class _RefreshThread(_RichRefreshThread): - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self.refresh_cond = False super().__init__(*args, **kwargs) @@ -407,7 +407,7 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: # progress has started self._progress_stopped = False - def refresh(self, hard=False) -> None: + def refresh(self, hard: bool = False) -> None: if self.progress: if hard or _IS_INTERACTIVE: self.progress.refresh() From daaacf9588a3d8e4e3e76949cec20d6737ac4c39 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 31 Oct 2025 20:30:36 +0800 Subject: [PATCH 9/9] add isinstance check --- src/lightning/pytorch/callbacks/progress/rich_progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index f849164327e64..15cec555a93d1 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -109,7 +109,7 @@ def stop(self) -> None: refresh_thread.join() def soft_refresh(self) -> None: - if self.live.auto_refresh: + if self.live.auto_refresh and isinstance(self.live._refresh_thread, _RefreshThread): self.live._refresh_thread.refresh_cond = True def add_task(