11from __future__ import annotations
22
33import asyncio
4+ from concurrent .futures import ProcessPoolExecutor
45import sys
56from typing import TYPE_CHECKING
67
@@ -21,7 +22,7 @@ class AsyncPipelineOutput:
2122 def __init__ (self , pipeline : AsyncPipeline ):
2223 self .pipeline = pipeline
2324
24- def _get_q_out (self , tg : TaskGroup , tp : ThreadPool , pp : ProcessPool , * args , ** kwargs ) -> asyncio .Queue :
25+ def _get_q_out (self , tg : TaskGroup , tp : ThreadPool , pp : ProcessPoolExecutor , * args , ** kwargs ) -> asyncio .Queue :
2526 """Feed forward each stage to the next, returning the output queue of the final stage."""
2627 q_out = None
2728 for task , next_task in zip (self .pipeline .tasks , self .pipeline .tasks [1 :] + [None ]):
@@ -39,22 +40,15 @@ def _get_q_out(self, tg: TaskGroup, tp: ThreadPool, pp: ProcessPool, *args, **kw
3940 async def __call__ (self , * args , ** kwargs ):
4041 """Call the pipeline, taking the inputs to the first task, and returning the output from the last task."""
4142 try :
42- < << << << HEAD
4343 # Unify async, threaded, and multiprocessed work by:
4444 # 1. using TaskGroup to execute asynchronous tasks
4545 # 2. using ThreadPool to execute threaded synchronous tasks
46- # 3. using ProcessPool to execute multiprocessed synchronous tasks
47- async with asyncio .TaskGroup () as tg , ThreadPool () as tp , ProcessPool as pp :
48- q_out = self ._get_q_out (tg , tp , pp , * args , ** kwargs )
49- == == == =
50- # We unify async and thread-based concurrency by
51- # 1. using TaskGroup to spin up asynchronous tasks
52- # 2. using ThreadPool to spin up synchronous tasks
53- async with TaskGroup () as tg , ThreadPool () as tp :
54- q_out = self ._get_q_out (tg , tp , * args , ** kwargs )
55- >> >> >> > dev
56- while (data := await q_out .get ()) is not StopSentinel :
57- yield data
46+ # 3. using ProcessPoolExecutor to execute multiprocessed synchronous tasks
47+ async with TaskGroup () as tg :
48+ with ThreadPool () as tp , ProcessPoolExecutor () as pp :
49+ q_out = self ._get_q_out (tg , tp , pp , * args , ** kwargs )
50+ while (data := await q_out .get ()) is not StopSentinel :
51+ yield data
5852 except ExceptionGroup as eg :
59- raise eg .exceptions [0 ]
53+ raise eg .exceptions [0 ] from None
6054
0 commit comments