|
| 1 | +""" |
| 2 | +A class to manage [workers](/guide/workers) for an app. |
| 3 | +
|
| 4 | +You access this object via [App.workers][textual.app.App.workers] or [Widget.workers][textual.dom.DOMNode.workers]. |
| 5 | +""" |
| 6 | + |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +import asyncio |
| 10 | +from collections import Counter |
| 11 | +from operator import attrgetter |
| 12 | +from typing import TYPE_CHECKING, Any, Iterable, Iterator |
| 13 | + |
| 14 | +import rich.repr |
| 15 | + |
| 16 | +from .worker import Worker, WorkerState, WorkType |
| 17 | + |
| 18 | +if TYPE_CHECKING: |
| 19 | + from .app import App |
| 20 | + from .dom import DOMNode |
| 21 | + |
| 22 | + |
| 23 | +@rich.repr.auto(angular=True) |
| 24 | +class WorkerManager: |
| 25 | + """An object to manager a number of workers. |
| 26 | +
|
| 27 | + You will not have to construct this class manually, as widgets, screens, and apps |
| 28 | + have a worker manager accessibly via a `workers` attribute. |
| 29 | + """ |
| 30 | + |
| 31 | + def __init__(self, app: App) -> None: |
| 32 | + """Initialize a worker manager. |
| 33 | +
|
| 34 | + Args: |
| 35 | + app: An App instance. |
| 36 | + """ |
| 37 | + self._app = app |
| 38 | + """A reference to the app.""" |
| 39 | + self._workers: set[Worker] = set() |
| 40 | + """The workers being managed.""" |
| 41 | + |
| 42 | + def __rich_repr__(self) -> rich.repr.Result: |
| 43 | + counter: Counter[WorkerState] = Counter() |
| 44 | + counter.update(worker.state for worker in self._workers) |
| 45 | + for state, count in sorted(counter.items()): |
| 46 | + yield state.name, count |
| 47 | + |
| 48 | + def __iter__(self) -> Iterator[Worker[Any]]: |
| 49 | + return iter(sorted(self._workers, key=attrgetter("_created_time"))) |
| 50 | + |
| 51 | + def __reversed__(self) -> Iterator[Worker[Any]]: |
| 52 | + return iter( |
| 53 | + sorted(self._workers, key=attrgetter("_created_time"), reverse=True) |
| 54 | + ) |
| 55 | + |
| 56 | + def __bool__(self) -> bool: |
| 57 | + return bool(self._workers) |
| 58 | + |
| 59 | + def __len__(self) -> int: |
| 60 | + return len(self._workers) |
| 61 | + |
| 62 | + def __contains__(self, worker: object) -> bool: |
| 63 | + return worker in self._workers |
| 64 | + |
| 65 | + def add_worker( |
| 66 | + self, worker: Worker, start: bool = True, exclusive: bool = True |
| 67 | + ) -> None: |
| 68 | + """Add a new worker. |
| 69 | +
|
| 70 | + Args: |
| 71 | + worker: A Worker instance. |
| 72 | + start: Start the worker if True, otherwise the worker must be started manually. |
| 73 | + exclusive: Cancel all workers in the same group as `worker`. |
| 74 | + """ |
| 75 | + if exclusive and worker.group: |
| 76 | + self.cancel_group(worker.node, worker.group) |
| 77 | + self._workers.add(worker) |
| 78 | + if start: |
| 79 | + worker._start(self._app, self._remove_worker) |
| 80 | + |
| 81 | + def _new_worker( |
| 82 | + self, |
| 83 | + work: WorkType, |
| 84 | + node: DOMNode, |
| 85 | + *, |
| 86 | + name: str | None = "", |
| 87 | + group: str = "default", |
| 88 | + description: str = "", |
| 89 | + exit_on_error: bool = True, |
| 90 | + start: bool = True, |
| 91 | + exclusive: bool = False, |
| 92 | + thread: bool = False, |
| 93 | + ) -> Worker: |
| 94 | + """Create a worker from a function, coroutine, or awaitable. |
| 95 | +
|
| 96 | + Args: |
| 97 | + work: A callable, a coroutine, or other awaitable. |
| 98 | + name: A name to identify the worker. |
| 99 | + group: The worker group. |
| 100 | + description: A description of the worker. |
| 101 | + exit_on_error: Exit the app if the worker raises an error. Set to `False` to suppress exceptions. |
| 102 | + start: Automatically start the worker. |
| 103 | + exclusive: Cancel all workers in the same group. |
| 104 | + thread: Mark the worker as a thread worker. |
| 105 | +
|
| 106 | + Returns: |
| 107 | + A Worker instance. |
| 108 | + """ |
| 109 | + worker: Worker[Any] = Worker( |
| 110 | + node, |
| 111 | + work, |
| 112 | + name=name or getattr(work, "__name__", "") or "", |
| 113 | + group=group, |
| 114 | + description=description or repr(work), |
| 115 | + exit_on_error=exit_on_error, |
| 116 | + thread=thread, |
| 117 | + ) |
| 118 | + self.add_worker(worker, start=start, exclusive=exclusive) |
| 119 | + return worker |
| 120 | + |
| 121 | + def _remove_worker(self, worker: Worker) -> None: |
| 122 | + """Remove a worker from the manager. |
| 123 | +
|
| 124 | + Args: |
| 125 | + worker: A Worker instance. |
| 126 | + """ |
| 127 | + self._workers.discard(worker) |
| 128 | + |
| 129 | + def start_all(self) -> None: |
| 130 | + """Start all the workers.""" |
| 131 | + for worker in self._workers: |
| 132 | + worker._start(self._app, self._remove_worker) |
| 133 | + |
| 134 | + def cancel_all(self) -> None: |
| 135 | + """Cancel all workers.""" |
| 136 | + for worker in self._workers: |
| 137 | + worker.cancel() |
| 138 | + |
| 139 | + def cancel_group(self, node: DOMNode, group: str) -> list[Worker]: |
| 140 | + """Cancel a single group. |
| 141 | +
|
| 142 | + Args: |
| 143 | + node: Worker DOM node. |
| 144 | + group: A group name. |
| 145 | +
|
| 146 | + Returns: |
| 147 | + A list of workers that were cancelled. |
| 148 | + """ |
| 149 | + workers = [ |
| 150 | + worker |
| 151 | + for worker in self._workers |
| 152 | + if (worker.group == group and worker.node == node) |
| 153 | + ] |
| 154 | + for worker in workers: |
| 155 | + worker.cancel() |
| 156 | + return workers |
| 157 | + |
| 158 | + def cancel_node(self, node: DOMNode) -> list[Worker]: |
| 159 | + """Cancel all workers associated with a given node |
| 160 | +
|
| 161 | + Args: |
| 162 | + node: A DOM node (widget, screen, or App). |
| 163 | +
|
| 164 | + Returns: |
| 165 | + List of cancelled workers. |
| 166 | + """ |
| 167 | + workers = [worker for worker in self._workers if worker.node == node] |
| 168 | + for worker in workers: |
| 169 | + worker.cancel() |
| 170 | + return workers |
| 171 | + |
| 172 | + async def wait_for_complete(self, workers: Iterable[Worker] | None = None) -> None: |
| 173 | + """Wait for workers to complete. |
| 174 | +
|
| 175 | + Args: |
| 176 | + workers: An iterable of workers or None to wait for all workers in the manager. |
| 177 | + """ |
| 178 | + try: |
| 179 | + await asyncio.gather(*[worker.wait() for worker in (workers or self)]) |
| 180 | + except asyncio.CancelledError: |
| 181 | + pass |
0 commit comments