11from __future__ import annotations
22
33import asyncio
4+ import sys
45from typing import TYPE_CHECKING
56
67from .stage import AsyncProducer , AsyncProducerConsumer
78from ..util .sentinel import StopSentinel
89from ..util .thread_pool import ThreadPool
910
11+ if sys .version_info < (3 , 11 ): # pragma: no cover
12+ from ..util .task_group import TaskGroup , ExceptionGroup
13+ else :
14+ from asyncio import TaskGroup
15+
1016if TYPE_CHECKING :
1117 from ..pipeline import AsyncPipeline
1218
@@ -15,7 +21,7 @@ class AsyncPipelineOutput:
1521 def __init__ (self , pipeline : AsyncPipeline ):
1622 self .pipeline = pipeline
1723
18- def _get_q_out (self , tg : asyncio . TaskGroup , tp : ThreadPool ,* args , ** kwargs ) -> asyncio .Queue :
24+ def _get_q_out (self , tg : TaskGroup , tp : ThreadPool ,* args , ** kwargs ) -> asyncio .Queue :
1925 """Feed forward each stage to the next, returning the output queue of the final stage."""
2026 q_out = None
2127 for task , next_task in zip (self .pipeline .tasks , self .pipeline .tasks [1 :] + [None ]):
@@ -36,7 +42,7 @@ async def __call__(self, *args, **kwargs):
3642 # We unify async and thread-based concurrency by
3743 # 1. using TaskGroup to spin up asynchronous tasks
3844 # 2. using ThreadPool to spin up synchronous tasks
39- async with asyncio . TaskGroup () as tg , ThreadPool () as tp :
45+ async with TaskGroup () as tg , ThreadPool () as tp :
4046 q_out = self ._get_q_out (tg , tp , * args , ** kwargs )
4147 while (data := await q_out .get ()) is not StopSentinel :
4248 yield data
0 commit comments