Skip to content

Commit 86d823a

Browse files
committed
refactor _RefreshThread
1 parent f11e58f commit 86d823a

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

src/lightning/pytorch/callbacks/progress/rich_progress.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from collections.abc import Generator
1717
from dataclasses import dataclass
1818
from datetime import timedelta
19-
from threading import Event, Thread
2019
from typing import Any, Optional, Union, cast
2120

2221
import torch
@@ -32,7 +31,7 @@
3231
if _RICH_AVAILABLE:
3332
from rich import get_console, reconfigure
3433
from rich.console import Console, RenderableType
35-
from rich.live import Live
34+
from rich.live import _RefreshThread as _RichRefreshThread
3635
from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID, TextColumn
3736
from rich.progress_bar import ProgressBar as _RichProgressBar
3837
from rich.style import Style
@@ -70,15 +69,10 @@ class CustomInfiniteTask(Task):
7069
def time_remaining(self) -> Optional[float]:
7170
return None
7271

73-
class _RefreshThread(Thread):
74-
def __init__(
75-
self,
76-
live: Live,
77-
) -> None:
78-
self.live = live
72+
class _RefreshThread(_RichRefreshThread):
73+
def __init__(self, *args, **kwargs) -> None:
7974
self.refresh_cond = False
80-
self.done = Event()
81-
super().__init__(daemon=True)
75+
super().__init__(*args, **kwargs)
8276

8377
def run(self) -> None:
8478
while not self.done.is_set():
@@ -88,15 +82,19 @@ def run(self) -> None:
8882
self.refresh_cond = False
8983
time.sleep(0.005)
9084

91-
def stop(self) -> None:
92-
self.done.set()
93-
9485
class CustomProgress(Progress):
9586
"""Overrides ``Progress`` to support adding tasks that have an infinite total size."""
9687

9788
def start(self) -> None:
89+
"""Starts the progress display.
90+
91+
Notes
92+
-----
93+
This override is needed to support the custom refresh thread.
94+
95+
"""
9896
if self.live.auto_refresh:
99-
self.live._refresh_thread = _RefreshThread(self.live)
97+
self.live._refresh_thread = _RefreshThread(self.live, self.live.refresh_per_second)
10098
self.live.auto_refresh = False
10199
super().start()
102100
if self.live._refresh_thread:
@@ -105,7 +103,6 @@ def start(self) -> None:
105103

106104
def stop(self) -> None:
107105
refresh_thread = self.live._refresh_thread
108-
self.live.auto_refresh = refresh_thread is not None
109106
super().stop()
110107
if refresh_thread:
111108
refresh_thread.stop()

0 commit comments

Comments
 (0)