11from __future__ import annotations
22
33import asyncio
4+ import sys
45from typing import TYPE_CHECKING
56
67from .stage import AsyncProducer , AsyncProducerConsumer
78from ..util .sentinel import StopSentinel
89from ..util .thread_pool import ThreadPool
910
11+ if sys .version_info < (3 , 11 ): # pragma: no cover
12+ from ..util .task_group import TaskGroup , ExceptionGroup
13+ else :
14+ from asyncio import TaskGroup
15+
1016if TYPE_CHECKING :
1117 from ..pipeline import AsyncPipeline
1218
@@ -15,7 +21,7 @@ class AsyncPipelineOutput:
1521 def __init__ (self , pipeline : AsyncPipeline ):
1622 self .pipeline = pipeline
1723
18- def _get_q_out (self , tg : asyncio . TaskGroup , tp : ThreadPool , pp : ProcessPool , * args , ** kwargs ) -> asyncio .Queue :
24+ def _get_q_out (self , tg : TaskGroup , tp : ThreadPool , pp : ProcessPool , * args , ** kwargs ) -> asyncio .Queue :
1925 """Feed forward each stage to the next, returning the output queue of the final stage."""
2026 q_out = None
2127 for task , next_task in zip (self .pipeline .tasks , self .pipeline .tasks [1 :] + [None ]):
@@ -33,12 +39,20 @@ def _get_q_out(self, tg: asyncio.TaskGroup, tp: ThreadPool, pp: ProcessPool, *ar
3339 async def __call__ (self , * args , ** kwargs ):
3440 """Call the pipeline, taking the inputs to the first task, and returning the output from the last task."""
3541 try :
42+ < << << << HEAD
3643 # Unify async, threaded, and multiprocessed work by:
3744 # 1. using TaskGroup to execute asynchronous tasks
3845 # 2. using ThreadPool to execute threaded synchronous tasks
3946 # 3. using ProcessPool to execute multiprocessed synchronous tasks
4047 async with asyncio .TaskGroup () as tg , ThreadPool () as tp , ProcessPool as pp :
4148 q_out = self ._get_q_out (tg , tp , pp , * args , ** kwargs )
49+ == == == =
50+ # We unify async and thread-based concurrency by
51+ # 1. using TaskGroup to spin up asynchronous tasks
52+ # 2. using ThreadPool to spin up synchronous tasks
53+ async with TaskGroup () as tg , ThreadPool () as tp :
54+ q_out = self ._get_q_out (tg , tp , * args , ** kwargs )
55+ >> >> >> > dev
4256 while (data := await q_out .get ()) is not StopSentinel :
4357 yield data
4458 except ExceptionGroup as eg :
0 commit comments