22
33import asyncio
44import contextlib
5- from collections .abc import Generator
65from contextvars import ContextVar
7- from typing import TYPE_CHECKING , Literal , TypeAlias
8-
9- from rich .progress import BarColumn , DownloadColumn , Progress , SpinnerColumn , TimeRemainingColumn , TransferSpeedColumn
6+ from typing import TYPE_CHECKING , Any , Literal , Protocol , TypeAlias
107
118if TYPE_CHECKING :
129 from collections .abc import Callable , Generator
10+ from types import TracebackType
1311
14- ProgressHook : TypeAlias = Callable [[ float ], None ]
12+ from rich . progress import Progress
1513
14+ ProgressHook : TypeAlias = Callable [[float ], None ]
1615
17- _SHOW_PROGRESS = ContextVar [ bool ]( "_SHOW_PROGRESS" , default = False )
18- _PROGRESS = ContextVar [ Progress | None ]( "_PROGRESS" , default = None )
16+ class ProgressHookContext ( Protocol ):
17+ def __enter__ ( self ) -> ProgressHook : ...
1918
19+ def __exit__ (
20+ self , typ : type [BaseException ] | None , value : BaseException | None , traceback : TracebackType | None , /
21+ ) -> Any : ...
2022
21- def do_nothing (_ : float ) -> None : ...
23+ class ProgressHookFactory (Protocol ):
24+ def __call__ (self , description : str , total : float , kind : Literal ["UP" , "DOWN" ]) -> ProgressHookContext : ...
2225
2326
24- current_hook : ContextVar [ProgressHook ] = ContextVar ("current_hook" , default = do_nothing )
27+ _PROGRESS_HOOK_FACTORY : ContextVar [ProgressHookFactory | None ] = ContextVar ("_PROGRESS_HOOK_FACTORY" , default = None )
28+ current_hook : ContextVar [ProgressHook ] = ContextVar ("current_hook" , default = lambda _ : None )
2529
2630
2731@contextlib .contextmanager
2832def new_task (description : str , total : float , kind : Literal ["UP" , "DOWN" ]) -> Generator [None ]:
29- progress = _PROGRESS .get ()
30- if progress is None :
33+ factory = _PROGRESS_HOOK_FACTORY .get ()
34+ if factory is None :
3135 yield
3236 return
3337
34- task_id = progress .add_task (description , total = total , kind = kind )
38+ with factory (description , total , kind ) as progress_hook :
39+ token = current_hook .set (progress_hook )
40+ try :
41+ yield
42+ finally :
43+ current_hook .reset (token )
3544
36- def progress_hook (advance : float ) -> None :
37- progress .advance (task_id , advance )
3845
39- token = current_hook .set (progress_hook )
40- try :
46+ @contextlib .contextmanager
47+ def new_progress () -> Generator [None ]:
48+ progress = _new_rich_progress ()
49+ if progress is None :
4150 yield
42- finally :
43- progress .remove_task (task_id = task_id )
44- current_hook .reset (token )
51+ return
4552
53+ def hook_factory (* args , ** kwargs ):
54+ return _new_rich_task (progress , * args , ** kwargs )
55+
56+ token = _PROGRESS_HOOK_FACTORY .set (hook_factory )
4657
47- @contextlib .contextmanager
48- def new_progress () -> Generator [None ]:
49- progress = Progress (
50- "[{task.fields[kind]}]" ,
51- SpinnerColumn (),
52- "{task.description}" ,
53- BarColumn (bar_width = None ),
54- "[progress.percentage]{task.percentage:>6.2f}%" ,
55- "-" ,
56- DownloadColumn (),
57- "-" ,
58- TransferSpeedColumn (),
59- "-" ,
60- TimeRemainingColumn (compact = True , elapsed_when_finished = True ),
61- transient = True ,
62- )
63- token = _PROGRESS .set (progress )
6458 try :
6559 with progress :
6660 yield
6761 finally :
68- _PROGRESS .reset (token )
62+ _PROGRESS_HOOK_FACTORY .reset (token )
6963
7064
7165async def test () -> None :
@@ -90,5 +84,50 @@ async def task(name: str) -> None:
9084 tg .create_task (task (f"file{ idx } " ))
9185
9286
87+ def _new_rich_progress () -> Progress | None :
88+ try :
89+ from rich .progress import (
90+ BarColumn ,
91+ DownloadColumn ,
92+ Progress ,
93+ SpinnerColumn ,
94+ TimeRemainingColumn ,
95+ TransferSpeedColumn ,
96+ )
97+ except ImportError :
98+ return None
99+
100+ else :
101+ return Progress (
102+ "[{task.fields[kind]}]" ,
103+ SpinnerColumn (),
104+ "{task.description}" ,
105+ BarColumn (bar_width = None ),
106+ "[progress.percentage]{task.percentage:>6.2f}%" ,
107+ "-" ,
108+ DownloadColumn (),
109+ "-" ,
110+ TransferSpeedColumn (),
111+ "-" ,
112+ TimeRemainingColumn (compact = True , elapsed_when_finished = True ),
113+ transient = True ,
114+ )
115+
116+
117+ @contextlib .contextmanager
118+ def _new_rich_task (
119+ progress : Progress , description : str , total : float , kind : Literal ["UP" , "DOWN" ]
120+ ) -> Generator [ProgressHook ]:
121+ task_id = progress .add_task (description , total = total , kind = kind )
122+
123+ def progress_hook (advance : float ) -> None :
124+ progress .advance (task_id , advance )
125+
126+ try :
127+ yield progress_hook
128+ finally :
129+ progress .remove_task (task_id = task_id )
130+
131+
93132if __name__ == "__main__" : # pragma: no coverage
94133 asyncio .run (test ())
0 commit comments