 from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
+from queue import Empty
 from sys import stdin
+from time import monotonic, sleep
 from typing import TYPE_CHECKING, Literal

+import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-from multiprocess import get_context
+from multiprocess.context import Process
+from multiprocess.queues import Queue as MultiprocessQueue

 from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
 from datachain.utils import batched, flatten, safe_closing

 if TYPE_CHECKING:
-    import multiprocess
     from sqlalchemy import Select, Table

     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
@@ -101,8 +104,8 @@ def udf_worker_entrypoint(fd: int | None = None) -> int:

 class UDFDispatcher:
     _catalog: Catalog | None = None
-    task_queue: "multiprocess.Queue | None" = None
-    done_queue: "multiprocess.Queue | None" = None
+    task_queue: MultiprocessQueue | None = None
+    done_queue: MultiprocessQueue | None = None

     def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.udf_data = udf_info["udf_data"]
@@ -121,7 +124,7 @@ def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.buffer_size = buffer_size
         self.task_queue = None
         self.done_queue = None
-        self.ctx = get_context("spawn")
+        self.ctx = multiprocess.get_context("spawn")

     @property
     def catalog(self) -> "Catalog":
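Note for reviewers: the previous `import multiprocess` lived under `TYPE_CHECKING`, so it existed only for the type checker; the qualified `multiprocess.get_context("spawn")` call needs the new top-level import to work at runtime. A minimal sanity check, assuming the `multiprocess` fork mirrors the stdlib `multiprocessing` API (which it is designed to do):

    import multiprocess

    if __name__ == "__main__":  # required by the spawn start method
        ctx = multiprocess.get_context("spawn")
        p = ctx.Process(target=print, args=("hello from a spawned worker",))
        p.start()
        p.join()
        assert p.exitcode == 0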
@@ -259,8 +262,6 @@ def run_udf_parallel( # noqa: C901, PLR0912
         for p in pool:
             p.start()

-        # Will be set to True if all tasks complete normally
-        normal_completion = False
         try:
             # Will be set to True when the input is exhausted
             input_finished = False
@@ -283,10 +284,20 @@ def run_udf_parallel( # noqa: C901, PLR0912

             # Process all tasks
             while n_workers > 0:
-                try:
-                    result = get_from_queue(self.done_queue)
-                except KeyboardInterrupt:
-                    break
+                while True:
+                    try:
+                        result = self.done_queue.get_nowait()
+                        break
+                    except Empty:
+                        for p in pool:
+                            exitcode = p.exitcode
+                            if exitcode not in (None, 0):
+                                message = (
+                                    f"Worker {p.name} exited unexpectedly with "
+                                    f"code {exitcode}"
+                                )
+                                raise RuntimeError(message) from None
+                        sleep(0.01)

                 if bytes_downloaded := result.get("bytes_downloaded"):
                     download_cb.relative_update(bytes_downloaded)
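The switch from a blocking queue read to a `get_nowait()` poll is what lets the dispatcher notice dead workers: a blocking get waits forever on a result that a crashed worker will never send. A self-contained sketch of the pattern against stdlib `multiprocessing` (`crash` and `wait_for_result` are illustrative names, not part of the diff):

    import time
    from multiprocessing import Process, Queue
    from queue import Empty


    def crash(done):
        raise SystemExit(1)  # simulate a worker dying without sending a result


    def wait_for_result(done, pool):
        while True:
            try:
                return done.get_nowait()
            except Empty:
                # Check worker health between polls so a dead worker surfaces
                # as an error instead of a hang.
                for p in pool:
                    if p.exitcode not in (None, 0):
                        raise RuntimeError(f"worker {p.name} died: {p.exitcode}")
                time.sleep(0.01)


    if __name__ == "__main__":
        done = Queue()
        p = Process(target=crash, args=(done,))
        p.start()
        try:
            wait_for_result(done, [p])
        except RuntimeError as exc:
            print(exc)  # e.g. "worker Process-1 died: 1"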
@@ -313,39 +324,50 @@ def run_udf_parallel( # noqa: C901, PLR0912
                     put_into_queue(self.task_queue, next(input_data))
                 except StopIteration:
                     input_finished = True
-
-            # Finished with all tasks normally
-            normal_completion = True
         finally:
-            if not normal_completion:
-                # Stop all workers if there is an unexpected exception
-                for _ in pool:
-                    put_into_queue(self.task_queue, STOP_SIGNAL)
-
-            # This allows workers (and this process) to exit without
-            # consuming any remaining data in the queues.
-            # (If they exit due to an exception.)
-            self.task_queue.close()
-            self.task_queue.join_thread()
-
-            # Flush all items from the done queue.
-            # This is needed if any workers are still running.
-            while n_workers > 0:
-                result = get_from_queue(self.done_queue)
-                status = result["status"]
-                if status != OK_STATUS:
-                    n_workers -= 1
-
-            self.done_queue.close()
-            self.done_queue.join_thread()
+            self._shutdown_workers(pool)
+
+    def _shutdown_workers(self, pool: list[Process]) -> None:
+        self._terminate_pool(pool)
+        self._drain_queue(self.done_queue)
+        self._drain_queue(self.task_queue)
+        self._close_queue(self.done_queue)
+        self._close_queue(self.task_queue)
+
+    def _terminate_pool(self, pool: list[Process]) -> None:
+        for proc in pool:
+            if proc.is_alive():
+                proc.terminate()
+
+        deadline = monotonic() + 1.0
+        for proc in pool:
+            if not proc.is_alive():
+                continue
+            remaining = deadline - monotonic()
+            if remaining > 0:
+                proc.join(remaining)
+            if proc.is_alive():
+                proc.kill()
+                proc.join(timeout=0.2)
+
+    def _drain_queue(self, queue: MultiprocessQueue) -> None:
+        while True:
+            try:
+                queue.get_nowait()
+            except Empty:
+                return
+            except (OSError, ValueError):
+                return

-        # Wait for workers to stop
-        for p in pool:
-            p.join()
+    def _close_queue(self, queue: MultiprocessQueue) -> None:
+        with contextlib.suppress(OSError, ValueError):
+            queue.close()
+        with contextlib.suppress(RuntimeError, AssertionError, ValueError):
+            queue.join_thread()


 class DownloadCallback(Callback):
-    def __init__(self, queue: "multiprocess.Queue") -> None:
+    def __init__(self, queue: MultiprocessQueue) -> None:
         self.queue = queue
         super().__init__()

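The ordering inside `_shutdown_workers` is the subtle part: the queues are drained before they are closed because `Queue.join_thread()` waits for the feeder thread to flush everything it has buffered into the pipe, and with no reader left on the other end that wait can hang. A condensed sketch of the same drain-then-close sequence against stdlib `multiprocessing` (the exception set is copied from the diff; `shutdown_queue` is an illustrative name):

    import contextlib
    from queue import Empty


    def shutdown_queue(q):
        # Drain first: join_thread() blocks until the feeder thread has
        # flushed its buffered items into the pipe, which never finishes
        # if the pipe is full and nobody reads it.
        while True:
            try:
                q.get_nowait()
            except (Empty, OSError, ValueError):
                break
        # Then close, suppressing errors from queues already torn down.
        with contextlib.suppress(OSError, ValueError):
            q.close()
        with contextlib.suppress(RuntimeError, AssertionError, ValueError):
            q.join_thread()


    if __name__ == "__main__":
        from multiprocessing import Queue

        q = Queue()
        q.put("leftover")
        shutdown_queue(q)  # drains "leftover", then closes cleanly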
@@ -360,7 +382,7 @@ class ProcessedCallback(Callback):
     def __init__(
         self,
         name: Literal["processed", "generated"],
-        queue: "multiprocess.Queue",
+        queue: MultiprocessQueue,
     ) -> None:
         self.name = name
         self.queue = queue
@@ -375,8 +397,8 @@ def __init__(
         self,
         catalog: "Catalog",
         udf: "UDFAdapter",
-        task_queue: "multiprocess.Queue",
-        done_queue: "multiprocess.Queue",
+        task_queue: MultiprocessQueue,
+        done_queue: MultiprocessQueue,
         query: "Select",
         table: "Table",
         cache: bool,
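On the annotation change that runs through this diff: the old `"multiprocess.Queue"` strings were forced by the `TYPE_CHECKING`-only import, and `multiprocess.Queue` is in any case a context-bound factory function rather than a class. `multiprocess.queues.Queue` is the class that factory returns, so the `MultiprocessQueue` alias works unquoted and is checkable at runtime. A quick check, assuming `multiprocess` keeps stdlib `multiprocessing`'s module layout (it is a fork with the same API):

    import multiprocess
    from multiprocess.queues import Queue as MultiprocessQueue

    ctx = multiprocess.get_context("spawn")
    q = ctx.Queue()
    assert isinstance(q, MultiprocessQueue)  # the class behind the Queue() factory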