|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import multiprocessing as mp |
3 | 4 | import queue |
4 | 5 | import threading |
5 | | -from typing import TYPE_CHECKING |
| 6 | +from typing import TYPE_CHECKING, Union |
6 | 7 |
|
7 | | -from .queue_io import Dequeue, Enqueue |
| 8 | +from .queue_io import DequeueFactory, EnqueueFactory |
8 | 9 | from ..util.sentinel import StopSentinel |
9 | 10 |
|
10 | 11 | if TYPE_CHECKING: |
11 | | - from ..util.thread_pool import ThreadPool |
| 12 | + from ..util.worker_pool import WorkerPool |
12 | 13 | from ..task import Task |
13 | 14 |
|
14 | 15 |
|
class Producer:
    """First stage of a pipeline: runs the initial task once and feeds its
    results onto ``q_out`` for the next stage's consumers.

    The producer owns ``q_out``; the queue flavor (``mp.Queue`` vs
    ``queue.Queue``) is chosen so that a multiprocess stage on either end of
    the queue gets a process-safe channel.
    """

    def __init__(self, task: Task, next_task: Union[Task, None], q_err: Union[mp.Queue, queue.Queue]):
        """
        Args:
            task: the first task in the pipeline. Must have concurrency == 1
                and must not join previous results (there are none).
            next_task: the following task, or None when ``task`` is the only
                stage in the pipeline.
            q_err: shared queue onto which worker exceptions are pushed.

        Raises:
            RuntimeError: if ``task`` has concurrency > 1 or ``task.join`` is set.
        """
        if task.concurrency > 1:
            raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have concurrency greater than 1")
        if task.join:
            raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot join previous results")
        # A multiprocessing queue is required as soon as either side of the
        # queue lives in another process.
        self.q_out = mp.Queue(maxsize=task.throttle) \
            if task.multiprocess or (next_task is not None and next_task.multiprocess) \
            else queue.Queue(maxsize=task.throttle)

        self._q_err = q_err
        self._n_workers = task.concurrency  # always 1 — validated above
        self._n_consumers = 1 if next_task is None else next_task.concurrency
        self._enqueue = EnqueueFactory(self.q_out, task)

        self._multiprocess = task.multiprocess

    def _worker(self, *args, **kwargs):
        """Run the task once; forward any exception to ``q_err``, then always
        signal downstream termination."""
        try:
            self._enqueue(*args, **kwargs)
        except Exception as e:
            self._q_err.put(e)
        finally:
            # One sentinel per consumer so every reader of q_out terminates.
            for _ in range(self._n_consumers):
                self.q_out.put(StopSentinel)

    def start(self, pool: WorkerPool, /, *args, **kwargs):
        """Submit the single producer worker to ``pool`` with the pipeline's
        initial call arguments."""
        pool.submit(self._worker, *args, **kwargs)
39 | 44 |
|
40 | 45 |
|
class ProducerConsumer:
    """Intermediate/terminal pipeline stage: dequeues results from the
    previous stage via ``q_in``, applies the task, and enqueues outputs onto
    ``q_out`` for the next stage.

    Completion tracking works in both threading and multiprocessing modes:
    the last worker to finish emits one StopSentinel per downstream consumer.
    """

    def __init__(self, q_in: Union[mp.Queue, queue.Queue], task: Task, next_task: Union[Task, None], q_err: Union[mp.Queue, queue.Queue]):
        """
        Args:
            q_in: queue produced by the previous stage.
            task: the task this stage executes (``task.concurrency`` workers).
            next_task: the following task, or None when this is the last stage.
            q_err: shared queue onto which worker exceptions are pushed.
        """
        # A multiprocessing queue is required as soon as either side of the
        # queue lives in another process.
        self.q_out = mp.Queue(maxsize=task.throttle) \
            if task.multiprocess or (next_task is not None and next_task.multiprocess) \
            else queue.Queue(maxsize=task.throttle)

        self._q_err = q_err
        self._n_workers = task.concurrency
        self._n_consumers = 1 if next_task is None else next_task.concurrency
        self._dequeue = DequeueFactory(q_in, task)
        self._enqueue = EnqueueFactory(self.q_out, task)

        self._multiprocess = task.multiprocess
        if self._multiprocess:
            # Process-shared counter; reuse its built-in lock for updates so
            # increments are atomic across worker processes.
            self._workers_done = mp.Value('i', 0)
            self._lock = self._workers_done.get_lock()
        else:
            self._workers_done = 0
            self._lock = threading.Lock()

    def _increment_workers_done(self):
        """Atomically increment the finished-worker count and return the new
        value (valid for both the int and mp.Value representations)."""
        with self._lock:
            if self._multiprocess:
                self._workers_done.value += 1
                return self._workers_done.value
            else:
                self._workers_done += 1
                return self._workers_done

    def _finish(self):
        """Called by every worker on exit; only the last one to finish emits
        one StopSentinel per downstream consumer."""
        if self._increment_workers_done() == self._n_workers:
            for _ in range(self._n_consumers):
                self.q_out.put(StopSentinel)

    def _worker(self):
        """Drain ``q_in`` through the task; forward any exception to ``q_err``
        and always participate in shutdown accounting."""
        try:
            for output in self._dequeue():
                self._enqueue(output)
        except Exception as e:
            self._q_err.put(e)
        finally:
            self._finish()

    def start(self, pool: WorkerPool, /):
        """Submit ``task.concurrency`` workers to ``pool``."""
        for _ in range(self._n_workers):
            pool.submit(self._worker)
0 commit comments