Skip to content

Commit 5de1e3b

Browse files
committed
tentative sync refactor to poolexecutors
1 parent e7e8edb commit 5de1e3b

File tree

5 files changed

+78
-82
lines changed

5 files changed

+78
-82
lines changed

src/pyper/_core/sync_helper/output.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ def _get_q_out(self, tp: ThreadPool, pp: ProcessPool, *args, **kwargs) -> queue.
2121
for task, next_task in zip(self.pipeline.tasks, self.pipeline.tasks[1:] + [None]):
2222
pool = pp if task.multiprocess else tp
2323
if q_out is None:
24-
stage = Producer(task=task, next_task=next_task, q_err=pool.error_queue, shutdown_event=pool.shutdown_event)
24+
stage = Producer(task=task, next_task=next_task, manager=pp.manager, shutdown_event=pool.shutdown_event)
2525
stage.start(pool, *args, **kwargs)
2626
else:
27-
stage = ProducerConsumer(q_in=q_out, task=task, next_task=next_task, q_err=pool.error_queue, shutdown_event=pool.shutdown_event)
27+
stage = ProducerConsumer(q_in=q_out, task=task, next_task=next_task, manager=pp.manager, shutdown_event=pool.shutdown_event)
2828
stage.start(pool)
2929
q_out = stage.q_out
3030

@@ -37,15 +37,8 @@ def __call__(self, *args, **kwargs):
3737
while True:
3838
try:
3939
# Use the timeout strategy for unblocking main thread without busy waiting
40-
if (data := q_out.get(timeout=0.1)) is StopSentinel:
41-
tp.raise_error_if_exists()
42-
pp.raise_error_if_exists()
40+
if (data := q_out.get()) is StopSentinel:
4341
break
4442
yield data
45-
except queue.Empty:
46-
tp.raise_error_if_exists()
47-
pp.raise_error_if_exists()
4843
except (KeyboardInterrupt, SystemExit): # pragma: no cover
49-
tp.shutdown_event.set()
50-
pp.shutdown_event.set()
5144
raise

src/pyper/_core/sync_helper/stage.py

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from __future__ import annotations
22

3-
import multiprocessing as mp
43
import queue
4+
import threading
55
from typing import TYPE_CHECKING, Union
66

77
from .queue_io import DequeueFactory, EnqueueFactory
88
from ..util.sentinel import StopSentinel
9-
from ..util.value import ThreadingValue
9+
from ..util.counter import MultiprocessingCounter, ThreadingCounter
1010

1111
if TYPE_CHECKING:
12-
from multiprocessing.synchronize import Event as ProcessEvent
13-
from threading import Event as ThreadEvent
12+
from multiprocessing.managers import SyncManager
13+
import multiprocessing.queues as mpq
14+
import multiprocessing.synchronize as mpsync
1415
from ..util.worker_pool import WorkerPool
1516
from ..task import Task
1617

@@ -20,17 +21,16 @@ def __init__(
2021
self,
2122
task: Task,
2223
next_task: Task,
23-
q_err: Union[mp.Queue, queue.Queue],
24-
shutdown_event: Union[ProcessEvent, ThreadEvent]):
24+
manager: SyncManager,
25+
shutdown_event: Union[mpsync.Event, threading.Event]):
2526
if task.workers > 1:
2627
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have more than 1 worker")
2728
if task.join:
2829
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot join previous results")
29-
self.q_out = mp.Queue(maxsize=task.throttle) \
30+
self.q_out = manager.Queue(maxsize=task.throttle) \
3031
if task.multiprocess or (next_task is not None and next_task.multiprocess) \
3132
else queue.Queue(maxsize=task.throttle)
3233

33-
self._q_err = q_err
3434
self._shutdown_event = shutdown_event
3535
self._n_workers = task.workers
3636
self._n_consumers = 1 if next_task is None else next_task.workers
@@ -39,9 +39,9 @@ def __init__(
3939
def _worker(self, *args, **kwargs):
4040
try:
4141
self._enqueue(*args, **kwargs)
42-
except Exception as e:
42+
except Exception:
4343
self._shutdown_event.set()
44-
self._q_err.put(e)
44+
raise
4545
finally:
4646
for _ in range(self._n_consumers):
4747
self.q_out.put(StopSentinel)
@@ -53,45 +53,37 @@ def start(self, pool: WorkerPool, /, *args, **kwargs):
5353
class ProducerConsumer:
5454
def __init__(
5555
self,
56-
q_in: Union[mp.Queue, queue.Queue],
56+
q_in: Union[mpq.Queue, queue.Queue],
5757
task: Task,
5858
next_task: Task,
59-
q_err: Union[mp.Queue, queue.Queue],
60-
shutdown_event: Union[ProcessEvent, ThreadEvent]):
59+
manager: SyncManager,
60+
shutdown_event: Union[mpsync.Event, threading.Event]):
6161
# The output queue is shared between this task and the next. We optimize here by using queue.Queue wherever possible
62-
# and only using multiprocess.Queue when the current task or the next task are multiprocessed
63-
self.q_out = mp.Queue(maxsize=task.throttle) \
62+
# and only using a multiprocess Queue when the current task or the next task are multiprocessed
63+
self.q_out = manager.Queue(maxsize=task.throttle) \
6464
if task.multiprocess or (next_task is not None and next_task.multiprocess) \
6565
else queue.Queue(maxsize=task.throttle)
6666

67-
self._q_err = q_err
6867
self._shutdown_event = shutdown_event
6968
self._n_workers = task.workers
7069
self._n_consumers = 1 if next_task is None else next_task.workers
7170
self._dequeue = DequeueFactory(q_in, task)
7271
self._enqueue = EnqueueFactory(self.q_out, task)
73-
self._workers_done = mp.Value('i', 0) if task.multiprocess else ThreadingValue(0)
74-
75-
def _increment_workers_done(self):
76-
with self._workers_done.get_lock():
77-
self._workers_done.value += 1
78-
return self._workers_done.value
79-
80-
def _finish(self):
81-
if self._increment_workers_done() == self._n_workers:
82-
for _ in range(self._n_consumers):
83-
self.q_out.put(StopSentinel)
72+
self._workers_done = MultiprocessingCounter(0, manager) if task.multiprocess else ThreadingCounter(0)
8473

8574
def _worker(self):
8675
try:
8776
for output in self._dequeue():
8877
if not self._shutdown_event.is_set():
8978
self._enqueue(output)
90-
except Exception as e:
79+
except Exception:
9180
self._shutdown_event.set()
92-
self._q_err.put(e)
81+
raise
9382
finally:
94-
self._finish()
83+
if self._workers_done.increment() == self._n_workers:
84+
for _ in range(self._n_consumers):
85+
self.q_out.put(StopSentinel)
86+
9587

9688
def start(self, pool: WorkerPool, /):
9789
for _ in range(self._n_workers):

src/pyper/_core/util/counter.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from __future__ import annotations
2+
3+
from multiprocessing.managers import SyncManager
4+
import threading
5+
6+
7+
class MultiprocessingCounter:
8+
"""Utility class to manage process based access to an integer."""
9+
def __init__(self, value: int, manager: SyncManager):
10+
self._manager_value = manager.Value('i', value)
11+
self._lock = manager.Lock()
12+
13+
def increment(self):
14+
with self._lock:
15+
self._manager_value.value += 1
16+
return self._manager_value.value
17+
18+
19+
class ThreadingCounter:
20+
"""Utility class to manage thread based access to an integer."""
21+
def __init__(self, value: int):
22+
self._value = value
23+
self._lock = threading.Lock()
24+
25+
def increment(self):
26+
with self._lock:
27+
self._value += 1
28+
return self._value

src/pyper/_core/util/value.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

src/pyper/_core/util/worker_pool.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
import concurrent.futures as cf
34
import multiprocessing as mp
4-
import queue
5+
import multiprocessing.synchronize as mpsync
56
import threading
67
from typing import List, Union
78

@@ -13,41 +14,40 @@ class WorkerPool:
1314
2. Provides a mechanism to capture and propagate errors to the main thread/process
1415
3. Ensures safe tear-down of all workers
1516
"""
16-
worker_type = None
17-
18-
def __init__(self):
19-
self.error_queue = mp.Queue() if self.worker_type is mp.Process else queue.Queue()
20-
self.shutdown_event = mp.Event() if self.worker_type is mp.Process else threading.Event()
21-
22-
self._workers: List[Union[mp.Process, threading.Thread]] = []
17+
shutdown_event: Union[mpsync.Event, threading.Event]
18+
_executor: Union[cf.ProcessPoolExecutor, cf.ThreadPoolExecutor]
19+
_futures: List[cf.Future]
2320

2421
def __enter__(self):
22+
self._executor.__enter__()
2523
return self
2624

2725
def __exit__(self, et, ev, tb):
28-
for worker in self._workers:
29-
worker.join()
30-
31-
@property
32-
def has_error(self):
33-
return not self.error_queue.empty()
34-
35-
def get_error(self) -> Exception:
36-
return self.error_queue.get()
26+
self._executor.__exit__(et, ev, tb)
27+
for future in self._futures:
28+
future.result()
3729

38-
def raise_error_if_exists(self):
39-
if self.has_error:
40-
raise self.get_error() from None
30+
def create_queue(self):
31+
raise NotImplementedError
4132

4233
def submit(self, func, /, *args, **kwargs):
43-
w = self.worker_type(target=func, args=args, kwargs=kwargs)
44-
w.start()
45-
self._workers.append(w)
34+
future = self._executor.submit(func, *args, **kwargs)
35+
self._futures.append(future)
36+
return future
4637

4738

4839
class ThreadPool(WorkerPool):
49-
worker_type = threading.Thread
40+
def __init__(self):
41+
self.shutdown_event = threading.Event()
42+
43+
self._executor = cf.ThreadPoolExecutor()
44+
self._futures = []
5045

5146

5247
class ProcessPool(WorkerPool):
53-
worker_type = mp.Process
48+
def __init__(self):
49+
self.manager = mp.Manager()
50+
self.shutdown_event = self.manager.Event()
51+
52+
self._executor = cf.ProcessPoolExecutor()
53+
self._futures = []

0 commit comments

Comments
 (0)