Skip to content

Commit afd2c34

Browse files
committed
refactor: use a general progress factory
1 parent 65c851e commit afd2c34

File tree

2 files changed

+78
-39
lines changed

2 files changed

+78
-39
lines changed

src/mega/progress.py

Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,70 +2,64 @@
22

33
import asyncio
44
import contextlib
5-
from collections.abc import Generator
65
from contextvars import ContextVar
7-
from typing import TYPE_CHECKING, Literal, TypeAlias
8-
9-
from rich.progress import BarColumn, DownloadColumn, Progress, SpinnerColumn, TimeRemainingColumn, TransferSpeedColumn
6+
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias
107

118
if TYPE_CHECKING:
129
from collections.abc import Callable, Generator
10+
from types import TracebackType
1311

14-
ProgressHook: TypeAlias = Callable[[float], None]
12+
from rich.progress import Progress
1513

14+
ProgressHook: TypeAlias = Callable[[float], None]
1615

17-
_SHOW_PROGRESS = ContextVar[bool]("_SHOW_PROGRESS", default=False)
18-
_PROGRESS = ContextVar[Progress | None]("_PROGRESS", default=None)
16+
class ProgressHookContext(Protocol):
17+
def __enter__(self) -> ProgressHook: ...
1918

19+
def __exit__(
20+
self, typ: type[BaseException] | None, value: BaseException | None, traceback: TracebackType | None, /
21+
) -> Any: ...
2022

21-
def do_nothing(_: float) -> None: ...
23+
class ProgressHookFactory(Protocol):
24+
def __call__(self, description: str, total: float, kind: Literal["UP", "DOWN"]) -> ProgressHookContext: ...
2225

2326

24-
current_hook: ContextVar[ProgressHook] = ContextVar("current_hook", default=do_nothing)
27+
_PROGRESS_HOOK_FACTORY: ContextVar[ProgressHookFactory | None] = ContextVar("_PROGRESS_HOOK_FACTORY", default=None)
28+
current_hook: ContextVar[ProgressHook] = ContextVar("current_hook", default=lambda _: None)
2529

2630

2731
@contextlib.contextmanager
2832
def new_task(description: str, total: float, kind: Literal["UP", "DOWN"]) -> Generator[None]:
29-
progress = _PROGRESS.get()
30-
if progress is None:
33+
factory = _PROGRESS_HOOK_FACTORY.get()
34+
if factory is None:
3135
yield
3236
return
3337

34-
task_id = progress.add_task(description, total=total, kind=kind)
38+
with factory(description, total, kind) as progress_hook:
39+
token = current_hook.set(progress_hook)
40+
try:
41+
yield
42+
finally:
43+
current_hook.reset(token)
3544

36-
def progress_hook(advance: float) -> None:
37-
progress.advance(task_id, advance)
3845

39-
token = current_hook.set(progress_hook)
40-
try:
46+
@contextlib.contextmanager
47+
def new_progress() -> Generator[None]:
48+
progress = _new_rich_progress()
49+
if progress is None:
4150
yield
42-
finally:
43-
progress.remove_task(task_id=task_id)
44-
current_hook.reset(token)
51+
return
4552

53+
def hook_factory(*args, **kwargs):
54+
return _new_rich_task(progress, *args, **kwargs)
55+
56+
token = _PROGRESS_HOOK_FACTORY.set(hook_factory)
4657

47-
@contextlib.contextmanager
48-
def new_progress() -> Generator[None]:
49-
progress = Progress(
50-
"[{task.fields[kind]}]",
51-
SpinnerColumn(),
52-
"{task.description}",
53-
BarColumn(bar_width=None),
54-
"[progress.percentage]{task.percentage:>6.2f}%",
55-
"-",
56-
DownloadColumn(),
57-
"-",
58-
TransferSpeedColumn(),
59-
"-",
60-
TimeRemainingColumn(compact=True, elapsed_when_finished=True),
61-
transient=True,
62-
)
63-
token = _PROGRESS.set(progress)
6458
try:
6559
with progress:
6660
yield
6761
finally:
68-
_PROGRESS.reset(token)
62+
_PROGRESS_HOOK_FACTORY.reset(token)
6963

7064

7165
async def test() -> None:
@@ -90,5 +84,50 @@ async def task(name: str) -> None:
9084
tg.create_task(task(f"file{idx}"))
9185

9286

87+
def _new_rich_progress() -> Progress | None:
88+
try:
89+
from rich.progress import (
90+
BarColumn,
91+
DownloadColumn,
92+
Progress,
93+
SpinnerColumn,
94+
TimeRemainingColumn,
95+
TransferSpeedColumn,
96+
)
97+
except ImportError:
98+
return None
99+
100+
else:
101+
return Progress(
102+
"[{task.fields[kind]}]",
103+
SpinnerColumn(),
104+
"{task.description}",
105+
BarColumn(bar_width=None),
106+
"[progress.percentage]{task.percentage:>6.2f}%",
107+
"-",
108+
DownloadColumn(),
109+
"-",
110+
TransferSpeedColumn(),
111+
"-",
112+
TimeRemainingColumn(compact=True, elapsed_when_finished=True),
113+
transient=True,
114+
)
115+
116+
117+
@contextlib.contextmanager
118+
def _new_rich_task(
119+
progress: Progress, description: str, total: float, kind: Literal["UP", "DOWN"]
120+
) -> Generator[ProgressHook]:
121+
task_id = progress.add_task(description, total=total, kind=kind)
122+
123+
def progress_hook(advance: float) -> None:
124+
progress.advance(task_id, advance)
125+
126+
try:
127+
yield progress_hook
128+
finally:
129+
progress.remove_task(task_id=task_id)
130+
131+
93132
if __name__ == "__main__": # pragma: no coverage
94133
asyncio.run(test())

src/mega/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
from collections.abc import Awaitable, Iterable, Sequence
88
from typing import Literal, TypeVar, overload
99

10-
from rich.logging import RichHandler
11-
1210
from mega.errors import ValidationError
1311

1412
_T = TypeVar("_T")
1513

1614

1715
def setup_logger(name: str = "mega") -> None:
16+
from rich.logging import RichHandler
17+
1818
handler = RichHandler(show_time=False, rich_tracebacks=True)
1919
logger = logging.getLogger(name)
2020
logger.setLevel(10)

0 commit comments

Comments
 (0)