Commit cabfe8d

add multiprocessing functionality for sync api
1 parent 1afc821 commit cabfe8d

File tree: 5 files changed, +150 -123 lines changed

src/pyper/_core/sync_helper/output.py

Lines changed: 13 additions & 12 deletions

@@ -5,7 +5,7 @@
 
 from .stage import Producer, ProducerConsumer
 from ..util.sentinel import StopSentinel
-from ..util.thread_pool import ThreadPool
+from ..util.worker_pool import ProcessPool, ThreadPool
 
 if TYPE_CHECKING:
     from ..pipeline import Pipeline
@@ -15,32 +15,33 @@ class PipelineOutput:
     def __init__(self, pipeline: Pipeline):
         self.pipeline = pipeline
 
-    def _get_q_out(self, tp: ThreadPool, *args, **kwargs) -> queue.Queue:
+    def _get_q_out(self, tp: ThreadPool, pp: ProcessPool, *args, **kwargs) -> queue.Queue:
         """Feed forward each stage to the next, returning the output queue of the final stage."""
         q_out = None
         for task, next_task in zip(self.pipeline.tasks, self.pipeline.tasks[1:] + [None]):
-            n_consumers = 1 if next_task is None else next_task.concurrency
+            pool = pp if task.multiprocess else tp
             if q_out is None:
-                stage = Producer(task=self.pipeline.tasks[0], tp=tp, n_consumers=n_consumers)
-                stage.start(*args, **kwargs)
+                stage = Producer(task=task, next_task=next_task, q_err=pool.error_queue)
+                stage.start(pool, *args, **kwargs)
             else:
-                stage = ProducerConsumer(q_in=q_out, task=task, tp=tp, n_consumers=n_consumers)
-                stage.start()
+                stage = ProducerConsumer(q_in=q_out, task=task, next_task=next_task, q_err=pool.error_queue)
+                stage.start(pool)
             q_out = stage.q_out
 
         return q_out
 
     def __call__(self, *args, **kwargs):
-        """Call the pipeline, taking the inputs to the first task, and returning the output from the last task."""
-        with ThreadPool() as tp:
-            q_out = self._get_q_out(tp, *args, **kwargs)
+        """Iterate through the pipeline, taking the inputs to the first task, and yielding each output from the last task."""
+        with ThreadPool() as tp, ProcessPool() as pp:
+            q_out = self._get_q_out(tp, pp, *args, **kwargs)
             while True:
-                tp.raise_error_if_exists()
                 try:
                     # Use the timeout strategy for unblocking main thread without busy waiting
                     if (data := q_out.get(timeout=0.1)) is StopSentinel:
                         tp.raise_error_if_exists()
+                        pp.raise_error_if_exists()
                        break
                    yield data
                except queue.Empty:
-                    continue
+                    tp.raise_error_if_exists()
+                    pp.raise_error_if_exists()
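
The reworked `__call__` checks both pools for errors at the two points where the main thread wakes up: when the sentinel arrives and whenever `q_out.get` times out. A self-contained sketch of this polling pattern (names are illustrative, not pyper's public API):

    import queue
    import threading

    StopSentinel = object()  # stand-in for pyper's sentinel

    def results(q_out, error_queues):
        def raise_error_if_exists():
            for q_err in error_queues:
                if not q_err.empty():
                    raise q_err.get() from None

        while True:
            try:
                # The timeout unblocks the main thread without busy waiting
                if (data := q_out.get(timeout=0.1)) is StopSentinel:
                    raise_error_if_exists()
                    break
                yield data
            except queue.Empty:
                raise_error_if_exists()

    # A worker that produces one result, then fails
    q_out, q_err = queue.Queue(), queue.Queue(1)

    def worker():
        q_out.put(1)
        q_err.put(ValueError("boom"))
        q_out.put(StopSentinel)

    threading.Thread(target=worker).start()
    for item in results(q_out, [q_err]):
        print(item)  # prints 1; the sentinel check then raises ValueError("boom")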

src/pyper/_core/sync_helper/queue_io.py

Lines changed: 24 additions & 30 deletions

@@ -1,27 +1,24 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 
 from ..util.sentinel import StopSentinel
 
 if TYPE_CHECKING:
+    import multiprocessing as mp
     import queue
     from ..task import Task
 
 
-class Dequeue:
+def DequeueFactory(q_in: Union[mp.Queue, queue.Queue], task: Task):
+    return _JoiningDequeue(q_in=q_in) if task.join \
+        else _SingleDequeue(q_in=q_in)
+
+
+class _Dequeue:
     """Pulls data from an input queue."""
-    def __new__(self, q_in: queue.Queue, task: Task):
-        if task.join:
-            instance = object.__new__(_JoiningDequeue)
-        else:
-            instance = object.__new__(_SingleDequeue)
-        instance.__init__(q_in=q_in, task=task)
-        return instance
-
-    def __init__(self, q_in: queue.Queue, task: Task):
-        self.q_in = q_in
-        self.task = task
+    def __init__(self, q_in: Union[mp.Queue, queue.Queue]):
+        self.q_in = q_in
 
     def _input_stream(self):
         while (data := self.q_in.get()) is not StopSentinel:
@@ -31,41 +28,38 @@ def __call__(self):
         raise NotImplementedError
 
 
-class _SingleDequeue(Dequeue):
+class _SingleDequeue(_Dequeue):
     def __call__(self):
         for data in self._input_stream():
             yield data
 
 
-class _JoiningDequeue(Dequeue):
+class _JoiningDequeue(_Dequeue):
     def __call__(self):
         yield self._input_stream()
 
 
-class Enqueue:
+def EnqueueFactory(q_out: Union[mp.Queue, queue.Queue], task: Task):
+    return _BranchingEnqueue(q_out=q_out, task=task) if task.is_gen \
+        else _SingleEnqueue(q_out=q_out, task=task)
+
+
+class _Enqueue:
     """Puts output from a task onto an output queue."""
-    def __new__(cls, q_out: queue.Queue, task: Task):
-        if task.is_gen:
-            instance = object.__new__(_BranchingEnqueue)
-        else:
-            instance = object.__new__(_SingleEnqueue)
-        instance.__init__(q_out=q_out, task=task)
-        return instance
-
-    def __init__(self, q_out: queue.Queue, task: Task):
-        self.q_out = q_out
-        self.task = task
+    def __init__(self, q_out: Union[mp.Queue, queue.Queue], task: Task):
+        self.q_out = q_out
+        self.task = task
 
     def __call__(self, *args, **kwargs):
         raise NotImplementedError
 
 
-class _SingleEnqueue(Enqueue):
+class _SingleEnqueue(_Enqueue):
     def __call__(self, *args, **kwargs):
         self.q_out.put(self.task.func(*args, **kwargs))
 
 
-class _BranchingEnqueue(Enqueue):
+class _BranchingEnqueue(_Enqueue):
     def __call__(self, *args, **kwargs):
-        for output in self.task.func(*args, **kwargs):
+        for output in self.task.func(*args, **kwargs):
             self.q_out.put(output)
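
`DequeueFactory` and `EnqueueFactory` keep the old dispatch rules (a joining task gets `_JoiningDequeue`, a generator task gets `_BranchingEnqueue`) but as plain functions rather than `__new__` overrides, and `_Dequeue` no longer stores a `task` reference it never used. A sketch of how a stage composes the two helpers, assuming this commit's module layout and a toy stand-in for pyper's `Task`:

    import queue
    from types import SimpleNamespace

    from pyper._core.sync_helper.queue_io import DequeueFactory, EnqueueFactory
    from pyper._core.util.sentinel import StopSentinel

    # Toy stand-in exposing only the attributes the factories read
    task = SimpleNamespace(func=lambda x: x + 1, join=False, is_gen=False)

    q_in, q_out = queue.Queue(), queue.Queue()
    dequeue = DequeueFactory(q_in, task)   # returns a _SingleDequeue
    enqueue = EnqueueFactory(q_out, task)  # returns a _SingleEnqueue

    for item in (1, 2, 3):
        q_in.put(item)
    q_in.put(StopSentinel)

    # The same loop ProducerConsumer._worker runs
    for output in dequeue():
        enqueue(output)

    print([q_out.get() for _ in range(3)])  # [2, 3, 4]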

src/pyper/_core/sync_helper/stage.py

Lines changed: 56 additions & 34 deletions

@@ -1,69 +1,91 @@
 from __future__ import annotations
 
+import multiprocessing as mp
 import queue
 import threading
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 
-from .queue_io import Dequeue, Enqueue
+from .queue_io import DequeueFactory, EnqueueFactory
 from ..util.sentinel import StopSentinel
 
 if TYPE_CHECKING:
-    from ..util.thread_pool import ThreadPool
+    from ..util.worker_pool import WorkerPool
     from ..task import Task
 
 
 class Producer:
-    def __init__(self, task: Task, tp: ThreadPool, n_consumers: int):
-        self.task = task
+    def __init__(self, task: Task, next_task: Task, q_err: Union[mp.Queue, queue.Queue]):
         if task.concurrency > 1:
             raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have concurrency greater than 1")
         if task.join:
             raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot join previous results")
-        self.tp = tp
-        self.n_consumers = n_consumers
-        self.q_out = queue.Queue(maxsize=task.throttle)
+        self.q_out = mp.Queue(maxsize=task.throttle) \
+            if task.multiprocess or (next_task is not None and next_task.multiprocess) \
+            else queue.Queue(maxsize=task.throttle)
 
-        self._enqueue = Enqueue(self.q_out, self.task)
+        self._q_err = q_err
+        self._n_workers = task.concurrency
+        self._n_consumers = 1 if next_task is None else next_task.concurrency
+        self._enqueue = EnqueueFactory(self.q_out, task)
+
+        self._multiprocess = task.multiprocess
 
     def _worker(self, *args, **kwargs):
         try:
             self._enqueue(*args, **kwargs)
         except Exception as e:
-            self.tp.put_error(e)
+            self._q_err.put(e)
         finally:
-            for _ in range(self.n_consumers):
+            for _ in range(self._n_consumers):
                 self.q_out.put(StopSentinel)
 
-    def start(self, *args, **kwargs):
-        self.tp.submit(self._worker, args, kwargs, daemon=self.task.daemon)
+    def start(self, pool: WorkerPool, /, *args, **kwargs):
+        pool.submit(self._worker, *args, **kwargs)
 
 
 class ProducerConsumer:
-    def __init__(self, q_in: queue.Queue, task: Task, tp: ThreadPool, n_consumers: int):
-        self.q_in = q_in
-        self.task = task
-        self.tp = tp
-        self.n_consumers = n_consumers
-        self.q_out = queue.Queue(maxsize=task.throttle)
+    def __init__(self, q_in: Union[mp.Queue, queue.Queue], task: Task, next_task: Task, q_err: Union[mp.Queue, queue.Queue]):
+        self.q_out = mp.Queue(maxsize=task.throttle) \
+            if task.multiprocess or (next_task is not None and next_task.multiprocess) \
+            else queue.Queue(maxsize=task.throttle)
+
+        self._q_err = q_err
+        self._n_workers = task.concurrency
+        self._n_consumers = 1 if next_task is None else next_task.concurrency
+        self._dequeue = DequeueFactory(q_in, task)
+        self._enqueue = EnqueueFactory(self.q_out, task)
+
+        self._multiprocess = task.multiprocess
+        if self._multiprocess:
+            self._workers_done = mp.Value('i', 0)
+            self._lock = self._workers_done.get_lock()
+        else:
+            self._workers_done = 0
+            self._lock = threading.Lock()
+
+    def _increment_workers_done(self):
+        with self._lock:
+            if self._multiprocess:
+                self._workers_done.value += 1
+                return self._workers_done.value
+            else:
+                self._workers_done += 1
+                return self._workers_done
+
+    def _finish(self):
+        if self._increment_workers_done() == self._n_workers:
+            for _ in range(self._n_consumers):
+                self.q_out.put(StopSentinel)
 
-        self._workers_done = 0
-        self._workers_done_lock = threading.Lock()
-        self._dequeue = Dequeue(self.q_in, self.task)
-        self._enqueue = Enqueue(self.q_out, self.task)
-
     def _worker(self):
         try:
             for output in self._dequeue():
                 self._enqueue(output)
         except Exception as e:
-            self.tp.put_error(e)
+            self._q_err.put(e)
         finally:
-            with self._workers_done_lock:
-                self._workers_done += 1
-                if self._workers_done == self.task.concurrency:
-                    for _ in range(self.n_consumers):
-                        self.q_out.put(StopSentinel)
-
-    def start(self):
-        for _ in range(self.task.concurrency):
-            self.tp.submit(self._worker, daemon=self.task.daemon)
+            self._finish()
+
+    def start(self, pool: WorkerPool, /):
+        for _ in range(self._n_workers):
+            pool.submit(self._worker)
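
The counter logic is the delicate part of `ProducerConsumer`: a plain `int` guarded by a `threading.Lock` would be copied into each child process rather than shared, so the multiprocess path swaps in `mp.Value('i', 0)` and the lock that comes with it, and the last worker to finish fans the sentinels out to the consumers. A self-contained sketch of that mechanism (simplified names):

    import multiprocessing as mp

    SENTINEL = "stop"  # stand-in for StopSentinel

    def worker(workers_done, n_workers, n_consumers, q_out):
        q_out.put("result")
        with workers_done.get_lock():      # the same lock stage.py takes via get_lock()
            workers_done.value += 1
            done = workers_done.value
        if done == n_workers:              # only the last finisher emits sentinels
            for _ in range(n_consumers):
                q_out.put(SENTINEL)

    if __name__ == "__main__":
        n_workers, n_consumers = 3, 2
        workers_done = mp.Value('i', 0)    # a shared int in shared memory
        q_out = mp.Queue()
        procs = [
            mp.Process(target=worker, args=(workers_done, n_workers, n_consumers, q_out))
            for _ in range(n_workers)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        items = [q_out.get() for _ in range(n_workers + n_consumers)]
        print(items.count("result"), items.count(SENTINEL))  # 3 2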

src/pyper/_core/util/thread_pool.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

src/pyper/_core/util/worker_pool.py

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import multiprocessing as mp
+import queue
+import threading
+from typing import List, Union
+
+
+class WorkerPool:
+    """A context for fine-grained thread/process management and error handling.
+
+    1. Spins up thread/process workers and maintains a reference to each
+    2. Provides a mechanism to capture and propagate errors to the main thread/process
+    3. Ensures safe tear-down of all workers
+    """
+    worker_type: type = None
+
+    error_queue: Union[mp.Queue, queue.Queue]
+    _workers: List[Union[mp.Process, threading.Thread]]
+
+    def __init__(self):
+        self.error_queue = mp.Queue(1) if self.worker_type is mp.Process else queue.Queue(1)
+        self._workers = []
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, et, ev, tb):
+        for worker in self._workers:
+            worker.join()
+
+    @property
+    def has_error(self):
+        return not self.error_queue.empty()
+
+    def get_error(self) -> Exception:
+        return self.error_queue.get()
+
+    def put_error(self, e: Exception):
+        self.error_queue.put(e)
+
+    def raise_error_if_exists(self):
+        if self.has_error:
+            raise self.get_error() from None
+
+    def submit(self, func, /, *args, **kwargs):
+        w = self.worker_type(target=func, args=args, kwargs=kwargs)
+        w.start()
+        self._workers.append(w)
+
+
+class ThreadPool(WorkerPool):
+    worker_type = threading.Thread
+
+
+class ProcessPool(WorkerPool):
+    worker_type = mp.Process
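
`WorkerPool` replaces the deleted `thread_pool.py` with one implementation covering both worker types; the subclasses differ only in `worker_type`. A short sketch of the contract, assuming this commit's module layout; as in the stages above, workers report their own errors to `error_queue`:

    from pyper._core.util.worker_pool import ThreadPool

    def ok():
        pass

    def boom(q_err):
        # Stages catch their own exceptions and report them like this
        q_err.put(ValueError("worker failed"))

    with ThreadPool() as tp:        # __exit__ joins every submitted worker
        tp.submit(ok)
        tp.submit(boom, tp.error_queue)

    tp.raise_error_if_exists()      # re-raises ValueError("worker failed")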
