Skip to content

Commit 0b0419d

Browse files
Refactor run query output handling (#1373)
Co-authored-by: Thomas Kunwar <[email protected]>
1 parent 6c17b68 commit 0b0419d

File tree

2 files changed: 80 additions and 25 deletions.

src/datachain/catalog/catalog.py

Lines changed: 58 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -133,19 +133,26 @@ def shutdown_process(
133133
return proc.wait()
134134

135135

136-
def _process_stream(stream: "IO[bytes]", callback: Callable[[str], None]) -> None:
136+
def process_output(stream: IO[bytes], callback: Callable[[str], None]) -> None:
137137
buffer = b""
138-
while byt := stream.read(1): # Read one byte at a time
139-
buffer += byt
140138

141-
if byt in (b"\n", b"\r"): # Check for newline or carriage return
142-
line = buffer.decode("utf-8")
143-
callback(line)
144-
buffer = b"" # Clear buffer for next line
139+
try:
140+
while byt := stream.read(1): # Read one byte at a time
141+
buffer += byt
145142

146-
if buffer: # Handle any remaining data in the buffer
147-
line = buffer.decode("utf-8")
148-
callback(line)
143+
if byt in (b"\n", b"\r"): # Check for newline or carriage return
144+
line = buffer.decode("utf-8", errors="replace")
145+
callback(line)
146+
buffer = b"" # Clear buffer for the next line
147+
148+
if buffer: # Handle any remaining data in the buffer
149+
line = buffer.decode("utf-8", errors="replace")
150+
callback(line)
151+
finally:
152+
try:
153+
stream.close() # Ensure output is closed
154+
except Exception: # noqa: BLE001, S110
155+
pass
149156

150157

151158
class DatasetRowsFetcher(NodesThreadPool):
@@ -1747,13 +1754,13 @@ def clone(
17471754
recursive=recursive,
17481755
)
17491756

1757+
@staticmethod
17501758
def query(
1751-
self,
17521759
query_script: str,
17531760
env: Mapping[str, str] | None = None,
17541761
python_executable: str = sys.executable,
1755-
capture_output: bool = False,
1756-
output_hook: Callable[[str], None] = noop,
1762+
stdout_callback: Callable[[str], None] | None = None,
1763+
stderr_callback: Callable[[str], None] | None = None,
17571764
params: dict[str, str] | None = None,
17581765
job_id: str | None = None,
17591766
reset: bool = False,
@@ -1773,13 +1780,18 @@ def query(
17731780
},
17741781
)
17751782
popen_kwargs: dict[str, Any] = {}
1776-
if capture_output:
1777-
popen_kwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.STDOUT}
1783+
1784+
if stdout_callback is not None:
1785+
popen_kwargs = {"stdout": subprocess.PIPE}
1786+
if stderr_callback is not None:
1787+
popen_kwargs["stderr"] = subprocess.PIPE
17781788

17791789
def raise_termination_signal(sig: int, _: Any) -> NoReturn:
17801790
raise TerminationSignal(sig)
17811791

1782-
thread: Thread | None = None
1792+
stdout_thread: Thread | None = None
1793+
stderr_thread: Thread | None = None
1794+
17831795
with subprocess.Popen(cmd, env=env, **popen_kwargs) as proc: # noqa: S603
17841796
logger.info("Starting process %s", proc.pid)
17851797

@@ -1793,10 +1805,20 @@ def raise_termination_signal(sig: int, _: Any) -> NoReturn:
17931805
orig_sigterm_handler = signal.getsignal(signal.SIGTERM)
17941806
signal.signal(signal.SIGTERM, raise_termination_signal)
17951807
try:
1796-
if capture_output:
1797-
args = (proc.stdout, output_hook)
1798-
thread = Thread(target=_process_stream, args=args, daemon=True)
1799-
thread.start()
1808+
if stdout_callback is not None:
1809+
stdout_thread = Thread(
1810+
target=process_output,
1811+
args=(proc.stdout, stdout_callback),
1812+
daemon=True,
1813+
)
1814+
stdout_thread.start()
1815+
if stderr_callback is not None:
1816+
stderr_thread = Thread(
1817+
target=process_output,
1818+
args=(proc.stderr, stderr_callback),
1819+
daemon=True,
1820+
)
1821+
stderr_thread.start()
18001822

18011823
proc.wait()
18021824
except TerminationSignal as exc:
@@ -1814,8 +1836,22 @@ def raise_termination_signal(sig: int, _: Any) -> NoReturn:
18141836
finally:
18151837
signal.signal(signal.SIGTERM, orig_sigterm_handler)
18161838
signal.signal(signal.SIGINT, orig_sigint_handler)
1817-
if thread:
1818-
thread.join() # wait for the reader thread
1839+
# wait for the reader thread
1840+
thread_join_timeout_seconds = 30
1841+
if stdout_thread is not None:
1842+
stdout_thread.join(timeout=thread_join_timeout_seconds)
1843+
if stdout_thread.is_alive():
1844+
logger.warning(
1845+
"stdout thread is still alive after %s seconds",
1846+
thread_join_timeout_seconds,
1847+
)
1848+
if stderr_thread is not None:
1849+
stderr_thread.join(timeout=thread_join_timeout_seconds)
1850+
if stderr_thread.is_alive():
1851+
logger.warning(
1852+
"stderr thread is still alive after %s seconds",
1853+
thread_join_timeout_seconds,
1854+
)
18191855

18201856
logger.info("Process %s exited with return code %s", proc.pid, proc.returncode)
18211857
if proc.returncode in (

tests/unit/test_query.py

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -53,12 +53,31 @@ def test_args(catalog, mock_popen, reset):
5353
mock_popen.assert_called_once_with(["mypython", "-c", "pass"], env=expected_env)
5454

5555

56+
def test_capture_stdout(catalog, mock_popen):
57+
mock_popen.stdout = io.BytesIO(b"Hello, World!\rLorem Ipsum\nDolor Sit Amet\nconse")
58+
stdout = []
59+
60+
catalog.query("pass", stdout_callback=stdout.append)
61+
assert stdout == ["Hello, World!\r", "Lorem Ipsum\n", "Dolor Sit Amet\n", "conse"]
62+
63+
64+
def test_capture_stderr(catalog, mock_popen):
65+
mock_popen.stderr = io.BytesIO(b"Hello, World!\rLorem Ipsum\nDolor Sit Amet\nconse")
66+
stderr = []
67+
68+
catalog.query("pass", stderr_callback=stderr.append)
69+
assert stderr == ["Hello, World!\r", "Lorem Ipsum\n", "Dolor Sit Amet\n", "conse"]
70+
71+
5672
def test_capture_output(catalog, mock_popen):
5773
mock_popen.stdout = io.BytesIO(b"Hello, World!\rLorem Ipsum\nDolor Sit Amet\nconse")
58-
lines = []
74+
mock_popen.stderr = io.BytesIO(b"foo\nbar")
75+
stdout = []
76+
stderr = []
5977

60-
catalog.query("pass", capture_output=True, output_hook=lines.append)
61-
assert lines == ["Hello, World!\r", "Lorem Ipsum\n", "Dolor Sit Amet\n", "conse"]
78+
catalog.query("pass", stdout_callback=stdout.append, stderr_callback=stderr.append)
79+
assert stdout == ["Hello, World!\r", "Lorem Ipsum\n", "Dolor Sit Amet\n", "conse"]
80+
assert stderr == ["foo\n", "bar"]
6281

6382

6483
def test_canceled_by_user(catalog, mock_popen):

0 commit comments

Comments
 (0)