11from __future__ import annotations
22
3- from typing import TYPE_CHECKING
4- import queue
3+ from typing import TYPE_CHECKING , Union
54
65from .stage import Producer , ProducerConsumer
76from ..util .sentinel import StopSentinel
87from ..util .worker_pool import ProcessPool , ThreadPool
98
109if TYPE_CHECKING :
10+ import multiprocessing as mp
11+ import queue
1112 from ..pipeline import Pipeline
1213
1314
1415class PipelineOutput :
1516 def __init__ (self , pipeline : Pipeline ):
1617 self .pipeline = pipeline
1718
18- def _get_q_out (self , tp : ThreadPool , pp : ProcessPool , * args , ** kwargs ) -> queue .Queue :
19+ def _get_q_out (self , tp : ThreadPool , pp : ProcessPool , * args , ** kwargs ) -> Union [ mp . Queue , queue .Queue ] :
1920 """Feed forward each stage to the next, returning the output queue of the final stage."""
2021 q_out = None
2122 for task , next_task in zip (self .pipeline .tasks , self .pipeline .tasks [1 :] + [None ]):
@@ -34,5 +35,10 @@ def __call__(self, *args, **kwargs):
3435 """Iterate through the pipeline, taking the inputs to the first task, and yielding each output from the last task."""
3536 with ThreadPool () as tp , ProcessPool () as pp :
3637 q_out = self ._get_q_out (tp , pp , * args , ** kwargs )
37- while (data := q_out .get ()) is not StopSentinel :
38- yield data
38+ try :
39+ while (data := q_out .get ()) is not StopSentinel :
40+ yield data
41+ except (KeyboardInterrupt , SystemExit ): # pragma: no cover
42+ tp .shutdown_event .set ()
43+ pp .shutdown_event .set ()
44+ raise
0 commit comments