1616from collections .abc import Generator
1717from dataclasses import dataclass
1818from datetime import timedelta
19- from threading import Event , Thread
2019from typing import Any , Optional , Union , cast
2120
2221import torch
3231if _RICH_AVAILABLE :
3332 from rich import get_console , reconfigure
3433 from rich .console import Console , RenderableType
35- from rich .live import Live
34+ from rich .live import _RefreshThread as _RichRefreshThread
3635 from rich .progress import BarColumn , Progress , ProgressColumn , Task , TaskID , TextColumn
3736 from rich .progress_bar import ProgressBar as _RichProgressBar
3837 from rich .style import Style
@@ -70,15 +69,10 @@ class CustomInfiniteTask(Task):
7069 def time_remaining (self ) -> Optional [float ]:
7170 return None
7271
73- class _RefreshThread (Thread ):
74- def __init__ (
75- self ,
76- live : Live ,
77- ) -> None :
78- self .live = live
72+ class _RefreshThread (_RichRefreshThread ):
73+ def __init__ (self , * args , ** kwargs ) -> None :
7974 self .refresh_cond = False
80- self .done = Event ()
81- super ().__init__ (daemon = True )
75+ super ().__init__ (* args , ** kwargs )
8276
8377 def run (self ) -> None :
8478 while not self .done .is_set ():
@@ -88,15 +82,19 @@ def run(self) -> None:
8882 self .refresh_cond = False
8983 time .sleep (0.005 )
9084
91- def stop (self ) -> None :
92- self .done .set ()
93-
9485 class CustomProgress (Progress ):
9586 """Overrides ``Progress`` to support adding tasks that have an infinite total size."""
9687
9788 def start (self ) -> None :
89+ """Starts the progress display.
90+
91+ Notes
92+ -----
93+ This override is needed to support the custom refresh thread.
94+
95+ """
9896 if self .live .auto_refresh :
99- self .live ._refresh_thread = _RefreshThread (self .live )
97+ self .live ._refresh_thread = _RefreshThread (self .live , self . live . refresh_per_second )
10098 self .live .auto_refresh = False
10199 super ().start ()
102100 if self .live ._refresh_thread :
@@ -105,7 +103,6 @@ def start(self) -> None:
105103
106104 def stop (self ) -> None :
107105 refresh_thread = self .live ._refresh_thread
108- self .live .auto_refresh = refresh_thread is not None
109106 super ().stop ()
110107 if refresh_thread :
111108 refresh_thread .stop ()
0 commit comments