Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 80 additions & 18 deletions src/_nebari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import re
import secrets
import selectors
import signal
import string
import subprocess
Expand Down Expand Up @@ -46,6 +47,77 @@ def change_directory(directory):
os.chdir(current_directory)


# Compiled once at import time: process_streams calls strip_ansi_errors for
# every output line, so the pattern should not be looked up again per call.
# Matches ESC [ <digits and semicolons> followed by a final "m" or "K".
_ANSI_ESCAPE_RE = re.compile(rb"\x1b\[[0-9;]*[mK]")


def strip_ansi_errors(line):
    """Remove ANSI escape sequences ending in ``m`` or ``K`` from a bytes line.

    Args:
        line: Raw ``bytes`` as read from a subprocess pipe. The pattern is a
            bytes pattern, so passing ``str`` raises ``TypeError``.

    Returns:
        A ``bytes`` object with every ``ESC [ ... m`` and ``ESC [ ... K``
        sequence removed. Other escape sequences are left untouched.
    """
    return _ANSI_ESCAPE_RE.sub(b"", line)


def process_streams(
    process, line_prefix, strip_errors, print_stdout=True, print_stderr=True
):
    """Multiplex a subprocess's stdout and stderr pipes line by line.

    Both pipes are watched with a selector so neither can fill up and block
    the child while the other one is being read.

    Args:
        process: A ``subprocess.Popen``-like object with binary ``stdout`` /
            ``stderr`` pipes (either may be ``None``) and a ``poll()`` method.
        line_prefix: ``bytes`` prepended to every line that is printed.
        strip_errors: When true, each line is passed through
            ``strip_ansi_errors`` before being printed or captured.
        print_stdout: When true, the child's stdout lines are written to
            ``sys.stdout``; otherwise they are collected and returned.
        print_stderr: Same as ``print_stdout`` but for stderr / ``sys.stderr``.

    Returns:
        A ``(stdout_lines, stderr_lines)`` tuple of lists of ``bytes``. A list
        only contains lines for a stream that was *not* printed. Every line
        keeps its trailing newline; a final unterminated line is kept as is.

    The pipes are always closed before returning, even on error.
    """
    sel = selectors.DefaultSelector()
    if process.stdout:
        sel.register(process.stdout, selectors.EVENT_READ, data="stdout")
    if process.stderr and process.stderr != process.stdout:
        sel.register(process.stderr, selectors.EVENT_READ, data="stderr")

    outputs = {"stdout": [], "stderr": []}
    # Bytes received after the last newline of each stream, waiting for the
    # rest of their line to arrive.
    partial = {"stdout": b"", "stderr": b""}
    should_print = {"stdout": print_stdout, "stderr": print_stderr}

    def emit(stream_name, payload):
        """Print ``payload`` to the matching sys stream, or capture it."""
        if strip_errors:
            payload = strip_ansi_errors(payload)
        if should_print[stream_name]:
            target = getattr(sys, stream_name)
            target.buffer.write(line_prefix + payload)
            target.flush()
        else:
            outputs[stream_name].append(payload)

    try:
        while True:
            events = sel.select(timeout=0.1)
            if not events:
                if process.poll() is None:
                    # Child is still running and simply has nothing to say.
                    continue
                # The child may have written and exited between the select
                # above and the poll; look once more before giving up so that
                # output is not lost.
                events = sel.select(timeout=0)
                if not events:
                    break

            for key, _ in events:
                data = key.fileobj.read1(8192)
                if not data:
                    # EOF on this pipe: stop watching it.
                    sel.unregister(key.fileobj)
                    continue

                stream_name = key.data
                pieces = (partial[stream_name] + data).split(b"\n")
                # The last piece has no newline yet; keep it for the next read.
                partial[stream_name] = pieces.pop()
                for line in pieces:
                    emit(stream_name, line + b"\n")

        # Flush any final line that was not newline-terminated through the
        # same print-or-capture path as complete lines.
        for stream_name, leftover in partial.items():
            if leftover:
                emit(stream_name, leftover)
    finally:
        sel.close()
        if process.stdout:
            process.stdout.close()
        if process.stderr:
            process.stderr.close()

    return outputs["stdout"], outputs["stderr"]


def run_subprocess_cmd(processargs, prefix=b"", capture_output=False, **kwargs):
"""Runs subprocess command with realtime stdout logging with optional line prefix."""
if prefix:
Expand All @@ -71,6 +143,7 @@ def run_subprocess_cmd(processargs, prefix=b"", capture_output=False, **kwargs):
stderr=stderr_stream,
preexec_fn=os.setsid,
)

# Set timeout thread
timeout_timer = None
if timeout > 0:
Expand All @@ -84,25 +157,14 @@ def kill_process():
timeout_timer = threading.Timer(timeout, kill_process)
timeout_timer.start()

print_stream = process.stderr if capture_output else process.stdout
for line in iter(lambda: print_stream.readline(), b""):
full_line = line_prefix + line
if strip_errors:
full_line = full_line.decode("utf-8")
full_line = re.sub(
r"\x1b\[31m", "", full_line
) # Remove red ANSI escape code
full_line = full_line.encode("utf-8")

sys.stdout.buffer.write(full_line)
sys.stdout.flush()
print_stream.close()

output = []
if capture_output:
for line in iter(lambda: process.stdout.readline(), b""):
output.append(line)
process.stdout.close()
output, _ = process_streams(
process, line_prefix, strip_errors, print_stdout=False, print_stderr=True
)
else:
process_streams(
process, line_prefix, strip_errors, print_stdout=True, print_stderr=True
)

if timeout_timer is not None:
timeout_timer.cancel()
Expand Down
39 changes: 38 additions & 1 deletion tests/tests_unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import sys

import pytest

from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion, deep_merge
from _nebari.utils import (
JsonDiff,
JsonDiffEnum,
byte_unit_conversion,
deep_merge,
run_subprocess_cmd,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -136,3 +144,32 @@ def test_deep_merge_empty():

result = deep_merge()
assert result == expected_result


# Every (output size in KB, line terminator) combination exercised by
# test_run_subprocess_cmd: 1, 64 and 128 KB, each without and with a newline.
# "\\n" is a literal backslash followed by "n"; it is pasted into the child
# script's source, where it is then read as a newline escape.
size_kb_end_args = [
    (size_kb, end) for size_kb in (1, 64, 128) for end in ("", "\\n")
]


@pytest.mark.parametrize(
    "size_kb,end",
    size_kb_end_args,
    ids=[
        f"{params[0]}KB{'_newline' if params[1] else ''}" for params in size_kb_end_args
    ],
)
def test_run_subprocess_cmd(size_kb, end):
    """Captured output of each size must come back complete.

    A child Python interpreter prints ``size_kb`` KB of ``a`` characters,
    optionally followed by a newline, and the test checks the exit code and
    the exact length of what ``run_subprocess_cmd`` captured.
    """
    script = f"print('a' * {size_kb} * 1024, end='{end}')"
    # One extra character is expected when the child appends a newline.
    expected_length = size_kb * 1024 + (1 if end else 0)

    exit_code, output = run_subprocess_cmd(
        [sys.executable, "-c", script],
        capture_output=True,
        strip_errors=True,
        timeout=1,
    )

    assert exit_code == 0
    assert len(output.decode()) == expected_length
Loading