Skip to content

Commit 13af52c

Browse files
authored
fix buffer full deadlock (#2929)
1 parent 2bb5669 commit 13af52c

File tree

2 files changed

+118
-19
lines changed

2 files changed

+118
-19
lines changed

src/_nebari/utils.py

Lines changed: 80 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import re
77
import secrets
8+
import selectors
89
import signal
910
import string
1011
import subprocess
@@ -47,6 +48,77 @@ def change_directory(directory):
4748
os.chdir(current_directory)
4849

4950

51+
def strip_ansi_errors(line):
52+
"""Strips ANSI escape codes from a string."""
53+
ansi_escape = re.compile(rb"\x1b\[[0-9;]*[mK]")
54+
return ansi_escape.sub(b"", line)
55+
56+
57+
def process_streams(
58+
process, line_prefix, strip_errors, print_stdout=True, print_stderr=True
59+
):
60+
sel = selectors.DefaultSelector()
61+
sel.register(process.stdout, selectors.EVENT_READ, data="stdout")
62+
if process.stderr and process.stderr != process.stdout:
63+
sel.register(process.stderr, selectors.EVENT_READ, data="stderr")
64+
65+
outputs = {"stdout": [], "stderr": []}
66+
partial = {"stdout": b"", "stderr": b""}
67+
68+
try:
69+
while True:
70+
events = sel.select(timeout=0.1)
71+
if not events and process.poll() is not None:
72+
# Handle any remaining partial output
73+
for stream_name in ["stdout", "stderr"]:
74+
if partial[stream_name]:
75+
line = partial[stream_name]
76+
if strip_errors:
77+
line = strip_ansi_errors(line)
78+
outputs[stream_name].append(line)
79+
break
80+
81+
for key, _ in events:
82+
data = key.fileobj.read1(8192)
83+
if not data:
84+
sel.unregister(key.fileobj)
85+
continue
86+
87+
stream_name = key.data
88+
chunk = partial[stream_name] + data
89+
lines = chunk.split(b"\n")
90+
partial[stream_name] = lines[-1]
91+
92+
for line in lines[:-1]:
93+
line_w_newline = line + b"\n"
94+
if strip_errors:
95+
line_w_newline = strip_ansi_errors(line_w_newline)
96+
97+
# Handle stdout
98+
if stream_name == "stdout":
99+
if print_stdout:
100+
sys.stdout.buffer.write(line_prefix + line_w_newline)
101+
sys.stdout.flush()
102+
else:
103+
outputs["stdout"].append(line_w_newline)
104+
105+
# Handle stderr
106+
if stream_name == "stderr":
107+
if print_stderr:
108+
sys.stderr.buffer.write(line_prefix + line_w_newline)
109+
sys.stderr.flush()
110+
else:
111+
outputs["stderr"].append(line_w_newline)
112+
finally:
113+
sel.close()
114+
if process.stdout:
115+
process.stdout.close()
116+
if process.stderr:
117+
process.stderr.close()
118+
119+
return outputs["stdout"], outputs["stderr"]
120+
121+
50122
def run_subprocess_cmd(processargs, prefix=b"", capture_output=False, **kwargs):
51123
"""Runs subprocess command with realtime stdout logging with optional line prefix."""
52124
if prefix:
@@ -72,6 +144,7 @@ def run_subprocess_cmd(processargs, prefix=b"", capture_output=False, **kwargs):
72144
stderr=stderr_stream,
73145
preexec_fn=os.setsid,
74146
)
147+
75148
# Set timeout thread
76149
timeout_timer = None
77150
if timeout > 0:
@@ -85,25 +158,14 @@ def kill_process():
85158
timeout_timer = threading.Timer(timeout, kill_process)
86159
timeout_timer.start()
87160

88-
print_stream = process.stderr if capture_output else process.stdout
89-
for line in iter(lambda: print_stream.readline(), b""):
90-
full_line = line_prefix + line
91-
if strip_errors:
92-
full_line = full_line.decode("utf-8")
93-
full_line = re.sub(
94-
r"\x1b\[31m", "", full_line
95-
) # Remove red ANSI escape code
96-
full_line = full_line.encode("utf-8")
97-
98-
sys.stdout.buffer.write(full_line)
99-
sys.stdout.flush()
100-
print_stream.close()
101-
102-
output = []
103161
if capture_output:
104-
for line in iter(lambda: process.stdout.readline(), b""):
105-
output.append(line)
106-
process.stdout.close()
162+
output, _ = process_streams(
163+
process, line_prefix, strip_errors, print_stdout=False, print_stderr=True
164+
)
165+
else:
166+
process_streams(
167+
process, line_prefix, strip_errors, print_stdout=True, print_stderr=True
168+
)
107169

108170
if timeout_timer is not None:
109171
timeout_timer.cancel()

tests/tests_unit/test_utils.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
1+
import sys
2+
13
import pytest
24

3-
from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion, deep_merge
5+
from _nebari.utils import (
6+
JsonDiff,
7+
JsonDiffEnum,
8+
byte_unit_conversion,
9+
deep_merge,
10+
run_subprocess_cmd,
11+
)
412

513

614
@pytest.mark.parametrize(
@@ -136,3 +144,32 @@ def test_deep_merge_empty():
136144

137145
result = deep_merge()
138146
assert result == expected_result
147+
148+
149+
size_kb_end_args = [
150+
(1, ""), # 1KB no newline
151+
(1, "\\n"), # 1KB with newline
152+
(64, ""), # 64KB no newline
153+
(64, "\\n"), # 64KB with newline
154+
(128, ""), # 128KB no newline
155+
(128, "\\n"), # 128KB with newline
156+
]
157+
158+
159+
@pytest.mark.parametrize(
160+
"size_kb,end",
161+
size_kb_end_args,
162+
ids=[
163+
f"{params[0]}KB{'_newline' if params[1] else ''}" for params in size_kb_end_args
164+
],
165+
)
166+
def test_run_subprocess_cmd(size_kb, end):
167+
"""Test large output handling using current Python interpreter"""
168+
python_exe = sys.executable
169+
command = [python_exe, "-c", f"print('a' * {size_kb} * 1024, end='{end}')"]
170+
171+
exit_code, output = run_subprocess_cmd(
172+
command, capture_output=True, strip_errors=True, timeout=1
173+
)
174+
assert exit_code == 0
175+
assert len(output.decode()) == size_kb * 1024 + (1 if end else 0)

0 commit comments

Comments
 (0)