|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import asyncio |
4 | | -from concurrent.futures import ProcessPoolExecutor |
5 | 4 | import sys |
6 | 5 | from typing import TYPE_CHECKING |
7 | 6 |
|
8 | | -from .queue_io import AsyncDequeue, AsyncEnqueue |
9 | | -from ..util.asynchronize import ascynchronize |
| 7 | +from .queue_io import AsyncDequeueFactory, AsyncEnqueueFactory |
10 | 8 | from ..util.sentinel import StopSentinel |
11 | 9 |
|
12 | 10 | if sys.version_info < (3, 11): # pragma: no cover |
|
15 | 13 | from asyncio import TaskGroup |
16 | 14 |
|
17 | 15 | if TYPE_CHECKING: |
18 | | - from ..util.thread_pool import ThreadPool |
19 | 16 | from ..task import Task |
20 | 17 |
|
21 | 18 |
|
class AsyncProducer:
    """Runs the first task of a pipeline and feeds its outputs into an output queue.

    The producer wraps the pipeline's head task, which must be a single-worker,
    non-joining task: it is the sole source of items, so concurrency > 1 or
    `join` semantics make no sense here and are rejected eagerly.
    """

    def __init__(self, task: Task, next_task: Task | None):
        # Validate before storing anything: the head task is the only item
        # source, so extra concurrency or joining previous results is invalid.
        if task.concurrency > 1:
            raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have concurrency greater than 1")
        if task.join:
            raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot join previous results")
        self.task = task
        # Bounded queue throttles production when downstream is slow
        # (maxsize=0 means unbounded if the task sets no throttle).
        self.q_out = asyncio.Queue(maxsize=task.throttle)

        # One StopSentinel per consumer is required for clean shutdown; with
        # no next task, a single consumer (the pipeline output) is assumed.
        self._n_consumers = 1 if next_task is None else next_task.concurrency
        self._enqueue = AsyncEnqueueFactory(self.q_out, self.task)

    async def _worker(self, *args, **kwargs):
        """Run the task to completion, then signal every consumer to stop."""
        await self._enqueue(*args, **kwargs)

        for _ in range(self._n_consumers):
            await self.q_out.put(StopSentinel)

    def start(self, tg: TaskGroup, /, *args, **kwargs):
        """Schedule the single producer worker on the given task group.

        *args/**kwargs are forwarded to the head task's callable.
        """
        tg.create_task(self._worker(*args, **kwargs))
49 | 39 |
|
50 | 40 |
|
class AsyncProducerConsumer:
    """Runs an intermediate pipeline task: consumes from `q_in`, produces to `q_out`.

    Spawns `task.concurrency` workers that each drain the input queue and
    enqueue results; the last worker to finish emits one StopSentinel per
    downstream consumer so the next stage shuts down cleanly.
    """

    def __init__(self, q_in: asyncio.Queue, task: Task, next_task: Task | None):
        # Bounded queue applies back-pressure when downstream is slow.
        self.q_out = asyncio.Queue(maxsize=task.throttle)

        self._n_workers = task.concurrency
        # With no next task, a single consumer (the pipeline output) is assumed.
        self._n_consumers = 1 if next_task is None else next_task.concurrency
        self._dequeue = AsyncDequeueFactory(q_in, task)
        self._enqueue = AsyncEnqueueFactory(self.q_out, task)
        # Count of workers that have exhausted q_in; guarded by the event
        # loop's single-threaded execution, so no lock is needed.
        self._workers_done = 0

    async def _worker(self):
        """Consume inputs until exhausted; the last finisher signals shutdown."""
        async for output in self._dequeue():
            await self._enqueue(output)

        self._workers_done += 1
        if self._workers_done == self._n_workers:
            # Every downstream worker needs its own sentinel to terminate.
            for _ in range(self._n_consumers):
                await self.q_out.put(StopSentinel)

    def start(self, tg: TaskGroup, /):
        """Schedule all concurrent workers for this task on the task group."""
        for _ in range(self._n_workers):
            tg.create_task(self._worker())
0 commit comments