
Commit 406a0a1

fix(dispatcher): simplify and fix termination handling (#1384)
1 parent b2d9fcc commit 406a0a1

File tree: 3 files changed, +182 −43 lines


src/datachain/query/dispatch.py

Lines changed: 64 additions & 42 deletions
@@ -2,12 +2,16 @@
 from collections.abc import Iterable, Sequence
 from itertools import chain
 from multiprocessing import cpu_count
+from queue import Empty
 from sys import stdin
+from time import monotonic, sleep
 from typing import TYPE_CHECKING, Literal

+import multiprocess
 from cloudpickle import load, loads
 from fsspec.callbacks import DEFAULT_CALLBACK, Callback
-from multiprocess import get_context
+from multiprocess.context import Process
+from multiprocess.queues import Queue as MultiprocessQueue

 from datachain.catalog import Catalog
 from datachain.catalog.catalog import clone_catalog_with_cache
@@ -25,7 +29,6 @@
 from datachain.utils import batched, flatten, safe_closing

 if TYPE_CHECKING:
-    import multiprocess
     from sqlalchemy import Select, Table

     from datachain.data_storage import AbstractMetastore, AbstractWarehouse
@@ -101,8 +104,8 @@ def udf_worker_entrypoint(fd: int | None = None) -> int:

 class UDFDispatcher:
     _catalog: Catalog | None = None
-    task_queue: "multiprocess.Queue | None" = None
-    done_queue: "multiprocess.Queue | None" = None
+    task_queue: MultiprocessQueue | None = None
+    done_queue: MultiprocessQueue | None = None

     def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.udf_data = udf_info["udf_data"]
@@ -121,7 +124,7 @@ def __init__(self, udf_info: UdfInfo, buffer_size: int = DEFAULT_BATCH_SIZE):
         self.buffer_size = buffer_size
         self.task_queue = None
         self.done_queue = None
-        self.ctx = get_context("spawn")
+        self.ctx = multiprocess.get_context("spawn")

     @property
     def catalog(self) -> "Catalog":
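
Editor's note: moving `import multiprocess` out of the TYPE_CHECKING block makes the module available at runtime, so one import now serves both the queue annotations and `get_context("spawn")`. A minimal sketch of the pattern (not DataChain code; names are illustrative), assuming the `multiprocess` fork of the stdlib multiprocessing package:

import multiprocess
from multiprocess.queues import Queue as MultiprocessQueue

ctx = multiprocess.get_context("spawn")  # each worker starts a fresh interpreter
task_queue: MultiprocessQueue = ctx.Queue()
done_queue: MultiprocessQueue = ctx.Queue()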
@@ -259,8 +262,6 @@ def run_udf_parallel(  # noqa: C901, PLR0912
         for p in pool:
             p.start()

-        # Will be set to True if all tasks complete normally
-        normal_completion = False
         try:
             # Will be set to True when the input is exhausted
             input_finished = False
@@ -283,10 +284,20 @@ def run_udf_parallel(  # noqa: C901, PLR0912

             # Process all tasks
             while n_workers > 0:
-                try:
-                    result = get_from_queue(self.done_queue)
-                except KeyboardInterrupt:
-                    break
+                while True:
+                    try:
+                        result = self.done_queue.get_nowait()
+                        break
+                    except Empty:
+                        for p in pool:
+                            exitcode = p.exitcode
+                            if exitcode not in (None, 0):
+                                message = (
+                                    f"Worker {p.name} exited unexpectedly with "
+                                    f"code {exitcode}"
+                                )
+                                raise RuntimeError(message) from None
+                        sleep(0.01)

                 if bytes_downloaded := result.get("bytes_downloaded"):
                     download_cb.relative_update(bytes_downloaded)
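
Editor's note: the old code blocked in get_from_queue(), which could hang forever if a worker died without ever writing a result. The new loop polls with get_nowait() and inspects each worker's exitcode between attempts (None means still running). A minimal self-contained sketch of the pattern, using stdlib multiprocessing for illustration (the function name is hypothetical):

from multiprocessing import Process, Queue
from queue import Empty
from time import sleep

def wait_for_result(done_queue: Queue, pool: list[Process]):
    """Block until a result arrives, but fail fast if any worker dies."""
    while True:
        try:
            return done_queue.get_nowait()
        except Empty:
            for p in pool:
                # exitcode is None while running, 0 on clean exit,
                # > 0 for error exits, -N if killed by signal N
                if p.exitcode not in (None, 0):
                    raise RuntimeError(
                        f"Worker {p.name} exited unexpectedly with code {p.exitcode}"
                    )
            sleep(0.01)  # short nap to avoid a busy-wait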
@@ -313,39 +324,50 @@ def run_udf_parallel(  # noqa: C901, PLR0912
                         put_into_queue(self.task_queue, next(input_data))
                     except StopIteration:
                         input_finished = True
-
-            # Finished with all tasks normally
-            normal_completion = True
         finally:
-            if not normal_completion:
-                # Stop all workers if there is an unexpected exception
-                for _ in pool:
-                    put_into_queue(self.task_queue, STOP_SIGNAL)
-
-                # This allows workers (and this process) to exit without
-                # consuming any remaining data in the queues.
-                # (If they exit due to an exception.)
-                self.task_queue.close()
-                self.task_queue.join_thread()
-
-                # Flush all items from the done queue.
-                # This is needed if any workers are still running.
-                while n_workers > 0:
-                    result = get_from_queue(self.done_queue)
-                    status = result["status"]
-                    if status != OK_STATUS:
-                        n_workers -= 1
-
-                self.done_queue.close()
-                self.done_queue.join_thread()
+            self._shutdown_workers(pool)
+
+    def _shutdown_workers(self, pool: list[Process]) -> None:
+        self._terminate_pool(pool)
+        self._drain_queue(self.done_queue)
+        self._drain_queue(self.task_queue)
+        self._close_queue(self.done_queue)
+        self._close_queue(self.task_queue)
+
+    def _terminate_pool(self, pool: list[Process]) -> None:
+        for proc in pool:
+            if proc.is_alive():
+                proc.terminate()
+
+        deadline = monotonic() + 1.0
+        for proc in pool:
+            if not proc.is_alive():
+                continue
+            remaining = deadline - monotonic()
+            if remaining > 0:
+                proc.join(remaining)
+            if proc.is_alive():
+                proc.kill()
+                proc.join(timeout=0.2)
+
+    def _drain_queue(self, queue: MultiprocessQueue) -> None:
+        while True:
+            try:
+                queue.get_nowait()
+            except Empty:
+                return
+            except (OSError, ValueError):
+                return

-        # Wait for workers to stop
-        for p in pool:
-            p.join()
+    def _close_queue(self, queue: MultiprocessQueue) -> None:
+        with contextlib.suppress(OSError, ValueError):
+            queue.close()
+        with contextlib.suppress(RuntimeError, AssertionError, ValueError):
+            queue.join_thread()


 class DownloadCallback(Callback):
-    def __init__(self, queue: "multiprocess.Queue") -> None:
+    def __init__(self, queue: MultiprocessQueue) -> None:
         self.queue = queue
         super().__init__()

@@ -360,7 +382,7 @@ class ProcessedCallback(Callback):
     def __init__(
         self,
         name: Literal["processed", "generated"],
-        queue: "multiprocess.Queue",
+        queue: MultiprocessQueue,
     ) -> None:
        self.name = name
        self.queue = queue
@@ -375,8 +397,8 @@ def __init__(
         self,
         catalog: "Catalog",
         udf: "UDFAdapter",
-        task_queue: "multiprocess.Queue",
-        done_queue: "multiprocess.Queue",
+        task_queue: MultiprocessQueue,
+        done_queue: MultiprocessQueue,
         query: "Select",
         table: "Table",
         cache: bool,
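
Editor's note: the finally block now always runs a single shutdown path instead of branching on a normal_completion flag, and _terminate_pool escalates from terminate() to kill() under one shared deadline; both queues are drained before being closed so no exiting worker blocks on a full pipe. A minimal sketch of the escalation (stdlib multiprocessing, illustrative names, not the committed code):

from multiprocessing import Process
from time import monotonic

def terminate_pool(pool: list[Process], grace: float = 1.0) -> None:
    for proc in pool:
        if proc.is_alive():
            proc.terminate()  # polite request (SIGTERM on POSIX)

    deadline = monotonic() + grace  # one deadline shared by the whole pool
    for proc in pool:
        if not proc.is_alive():
            continue
        remaining = deadline - monotonic()
        if remaining > 0:
            proc.join(remaining)
        if proc.is_alive():
            proc.kill()  # force (SIGKILL on POSIX); cannot be trapped
            proc.join(timeout=0.2)

Sharing one deadline keeps total shutdown time bounded near `grace` regardless of pool size, rather than paying the timeout once per worker.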

src/datachain/query/queue.py

Lines changed: 2 additions & 1 deletion
@@ -1,11 +1,12 @@
 import datetime
 from collections.abc import Iterable, Iterator
-from queue import Empty, Full, Queue
+from queue import Empty, Full
 from struct import pack, unpack
 from time import sleep
 from typing import Any

 import msgpack
+from multiprocess.queues import Queue

 from datachain.query.batch import RowsOutput

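Editor's note: only the Queue class moves to multiprocess.queues; Empty and Full stay in the stdlib queue module because multiprocess(-ing) queues raise exactly those exceptions on timeouts and non-blocking calls. A small illustrative sketch (not DataChain code):

from queue import Empty, Full

import multiprocess

ctx = multiprocess.get_context("spawn")
q = ctx.Queue(maxsize=1)

q.put("item")
try:
    q.put_nowait("overflow")
except Full:
    print("queue full, as expected")
print(q.get(timeout=1))  # -> "item"
try:
    q.get_nowait()
except Empty:
    print("queue drained")
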
tests/func/test_udf.py

Lines changed: 116 additions & 0 deletions
@@ -2,7 +2,10 @@
 import os
 import pickle
 import posixpath
+import sys
+import time

+import multiprocess as mp
 import pytest

 import datachain as dc
@@ -560,6 +563,119 @@ def name_len_error(_name):
         chain.show()


+@pytest.mark.parametrize(
+    "failure_mode,expected_exit_code,error_marker",
+    [
+        ("exception", 1, "Worker 1 failure!"),
+        ("keyboard_interrupt", -2, "KeyboardInterrupt"),
+        ("sys_exit", 1, None),
+        ("os_exit", 1, None),  # os._exit - immediate termination
+    ],
+)
+def test_udf_parallel_worker_failure_exits_peers(
+    test_session_tmpfile,
+    tmp_path,
+    capfd,
+    failure_mode,
+    expected_exit_code,
+    error_marker,
+):
+    """
+    Test that when one worker fails, all other workers exit immediately.
+
+    Tests different failure modes:
+    - exception: Worker raises RuntimeError (normal exception)
+    - keyboard_interrupt: Worker raises KeyboardInterrupt (simulates Ctrl+C)
+    - sys_exit: Worker calls sys.exit() (clean Python exit)
+    - os_exit: Worker calls os._exit() (immediate process termination)
+    """
+    import platform
+
+    # Windows uses different exit codes for KeyboardInterrupt:
+    # 3221225786 (0xC000013A) is STATUS_CONTROL_C_EXIT on Windows,
+    # while POSIX systems use -2 (SIGINT)
+    if platform.system() == "Windows" and failure_mode == "keyboard_interrupt":
+        expected_exit_code = 3221225786
+
+    vals = list(range(100))
+
+    barrier_dir = tmp_path / "udf_workers_barrier"
+    barrier_dir_str = str(barrier_dir)
+    os.makedirs(barrier_dir_str, exist_ok=True)
+    expected_workers = 3
+
+    def slow_process(val: int) -> int:
+        proc_name = mp.current_process().name
+        with open(os.path.join(barrier_dir_str, f"{proc_name}.started"), "w") as f:
+            f.write(str(time.time()))
+
+        # Wait until all expected workers have written their markers
+        deadline = time.time() + 1.0
+        while time.time() < deadline:
+            try:
+                count = len(
+                    [n for n in os.listdir(barrier_dir_str) if n.endswith(".started")]
+                )
+            except FileNotFoundError:
+                count = 0
+            if count >= expected_workers:
+                break
+            time.sleep(0.01)
+
+        if proc_name == "Worker-UDF-1":
+            if failure_mode == "exception":
+                raise RuntimeError("Worker 1 failure!")
+            if failure_mode == "keyboard_interrupt":
+                raise KeyboardInterrupt("Worker interrupted")
+            if failure_mode == "sys_exit":
+                sys.exit(1)
+            if failure_mode == "os_exit":
+                os._exit(1)
+        time.sleep(5)
+        return val * 2
+
+    chain = (
+        dc.read_values(val=vals, session=test_session_tmpfile)
+        .settings(parallel=3)
+        .map(slow_process, output={"result": int})
+    )
+
+    start = time.time()
+    with pytest.raises(RuntimeError, match="UDF Execution Failed!") as exc_info:
+        list(chain.to_iter("result"))
+    elapsed = time.time() - start
+
+    # Verify timing: should exit immediately when a worker fails
+    assert elapsed < 10, f"took {elapsed:.1f}s, should exit immediately"
+
+    # Verify multiple workers were started via barrier markers
+    try:
+        started_files = [
+            n for n in os.listdir(barrier_dir_str) if n.endswith(".started")
+        ]
+    except FileNotFoundError:
+        started_files = []
+    assert len(started_files) == 3, (
+        f"Expected all 3 workers to start, but saw markers for: {started_files}"
+    )
+
+    captured = capfd.readouterr()
+
+    # Verify the RuntimeError has a meaningful message with exit code
+    error_message = str(exc_info.value)
+    assert f"UDF Execution Failed! Exit code: {expected_exit_code}" in error_message, (
+        f"Expected exit code {expected_exit_code}, got: {error_message}"
+    )
+
+    if error_marker:
+        assert error_marker in captured.err, (
+            f"Expected '{error_marker}' in stderr for {failure_mode} mode. "
+            f"stderr output: {captured.err[:500]}"
+        )
+
+    assert "semaphore" not in captured.err
+
+
 @pytest.mark.parametrize(
     "cloud_type,version_aware",
     [("s3", True)],
