Skip to content

Commit 8e95ae1

Browse files
committed
tidy up misc code
1 parent 5de1e3b commit 8e95ae1

File tree

6 files changed: +19 additions, −54 deletions

src/pyper/_core/async_helper/stage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
class AsyncProducer:
2020
def __init__(self, task: Task, next_task: Task):
2121
if task.workers > 1:
22-
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have more than 1 worker")
22+
raise RuntimeError(f"The first task in a pipeline ({task.func}) cannot have more than 1 worker")
2323
if task.join:
24-
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot join previous results")
24+
raise RuntimeError(f"The first task in a pipeline ({task.func}) cannot join previous results")
2525
self.task = task
2626
self.q_out = asyncio.Queue(maxsize=task.throttle)
2727

src/pyper/_core/sync_helper/output.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,5 @@ def __call__(self, *args, **kwargs):
3434
"""Iterate through the pipeline, taking the inputs to the first task, and yielding each output from the last task."""
3535
with ThreadPool() as tp, ProcessPool() as pp:
3636
q_out = self._get_q_out(tp, pp, *args, **kwargs)
37-
while True:
38-
try:
39-
# Use the timeout strategy for unblocking main thread without busy waiting
40-
if (data := q_out.get()) is StopSentinel:
41-
break
42-
yield data
43-
except (KeyboardInterrupt, SystemExit): # pragma: no cover
44-
raise
37+
while (data := q_out.get()) is not StopSentinel:
38+
yield data

src/pyper/_core/sync_helper/queue_io.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ def __call__(self):
3030

3131

3232
class _SingleDequeue(_Dequeue):
33-
def __call__(self):
34-
for data in self._input_stream():
33+
def __call__(self):
34+
for data in self._input_stream():
3535
yield data
3636

3737

3838
class _JoiningDequeue(_Dequeue):
39-
def __call__(self):
39+
def __call__(self):
4040
yield self._input_stream()
4141

4242

@@ -56,12 +56,12 @@ def __call__(self, *args, **kwargs):
5656

5757

5858
class _SingleEnqueue(_Enqueue):
59-
def __call__(self, *args, **kwargs):
59+
def __call__(self, *args, **kwargs):
6060
self.q_out.put(self.task.func(*args, **kwargs))
6161

6262

6363
class _BranchingEnqueue(_Enqueue):
64-
def __call__(self, *args, **kwargs):
64+
def __call__(self, *args, **kwargs):
6565
if isinstance(result := self.task.func(*args, **kwargs), Iterable):
6666
for output in result:
6767
self.q_out.put(output)

src/pyper/_core/sync_helper/stage.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import queue
44
import threading
5+
from types import SimpleNamespace
56
from typing import TYPE_CHECKING, Union
67

78
from .queue_io import DequeueFactory, EnqueueFactory
89
from ..util.sentinel import StopSentinel
9-
from ..util.counter import MultiprocessingCounter, ThreadingCounter
1010

1111
if TYPE_CHECKING:
1212
from multiprocessing.managers import SyncManager
@@ -24,9 +24,9 @@ def __init__(
2424
manager: SyncManager,
2525
shutdown_event: Union[mpsync.Event, threading.Event]):
2626
if task.workers > 1:
27-
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot have more than 1 worker")
27+
raise RuntimeError(f"The first task in a pipeline ({task.func}) cannot have more than 1 worker")
2828
if task.join:
29-
raise RuntimeError(f"The first task in a pipeline ({task.func.__qualname__}) cannot join previous results")
29+
raise RuntimeError(f"The first task in a pipeline ({task.func}) cannot join previous results")
3030
self.q_out = manager.Queue(maxsize=task.throttle) \
3131
if task.multiprocess or (next_task is not None and next_task.multiprocess) \
3232
else queue.Queue(maxsize=task.throttle)
@@ -69,7 +69,8 @@ def __init__(
6969
self._n_consumers = 1 if next_task is None else next_task.workers
7070
self._dequeue = DequeueFactory(q_in, task)
7171
self._enqueue = EnqueueFactory(self.q_out, task)
72-
self._workers_done = MultiprocessingCounter(0, manager) if task.multiprocess else ThreadingCounter(0)
72+
self._workers_done = manager.Value('i', 0) if task.multiprocess else SimpleNamespace(value=0)
73+
self._workers_done_lock = manager.Lock() if task.multiprocess else threading.Lock()
7374

7475
def _worker(self):
7576
try:
@@ -80,11 +81,12 @@ def _worker(self):
8081
self._shutdown_event.set()
8182
raise
8283
finally:
83-
if self._workers_done.increment() == self._n_workers:
84-
for _ in range(self._n_consumers):
85-
self.q_out.put(StopSentinel)
84+
with self._workers_done_lock:
85+
self._workers_done.value += 1
86+
if self._workers_done.value == self._n_workers:
87+
for _ in range(self._n_consumers):
88+
self.q_out.put(StopSentinel)
8689

87-
8890
def start(self, pool: WorkerPool, /):
8991
for _ in range(self._n_workers):
9092
pool.submit(self._worker)

src/pyper/_core/util/counter.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

src/pyper/_core/util/worker_pool.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@ def __exit__(self, et, ev, tb):
2727
for future in self._futures:
2828
future.result()
2929

30-
def create_queue(self):
31-
raise NotImplementedError
32-
3330
def submit(self, func, /, *args, **kwargs):
3431
future = self._executor.submit(func, *args, **kwargs)
3532
self._futures.append(future)

0 commit comments

Comments (0)