Skip to content

Commit 247c5c0

Browse files
live stream of command output in Shell driver
streams stdout and stderr for long-running methods
1 parent 4503fda commit 247c5c0

File tree

4 files changed

+181
-51
lines changed

4 files changed

+181
-51
lines changed

packages/jumpstarter-driver-shell/jumpstarter_driver_shell/client.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from dataclasses import dataclass
23

34
import click
@@ -13,8 +14,8 @@ class ShellClient(DriverClient):
1314
Client interface for Shell driver.
1415
1516
This client dynamically checks that the method is configured
16-
on the driver, and if it is, it will call it and get the results
17-
in the form of (stdout, stderr, returncode).
17+
on the driver, and if it is, it will call it with live streaming output.
18+
Output chunks are displayed as they arrive.
1819
"""
1920

2021
def _check_method_exists(self, method):
@@ -26,7 +27,17 @@ def _check_method_exists(self, method):
2627
## capture any method calls dynamically
2728
def __getattr__(self, name):
2829
self._check_method_exists(name)
29-
return lambda *args, **kwargs: tuple(self.call("call_method", name, kwargs, *args))
30+
def execute(*args, **kwargs):
31+
returncode = 0
32+
for stdout, stderr, code in self.streamingcall("call_method", name, kwargs, *args):
33+
if stdout:
34+
print(stdout, end='', flush=True)
35+
if stderr:
36+
print(stderr, end='', file=sys.stderr, flush=True)
37+
if code is not None:
38+
returncode = code
39+
return returncode
40+
return execute
3041

3142
def cli(self):
3243
"""Create CLI interface for dynamically configured shell methods"""
@@ -64,14 +75,7 @@ def method_command(args, env):
6475
else:
6576
raise click.BadParameter(f"Invalid --env value '{env_var}'. Use KEY=VALUE.")
6677

67-
# Call the method
68-
stdout, stderr, returncode = self.call("call_method", method_name, env_dict, *args)
69-
70-
# Display results
71-
if stdout:
72-
click.echo(stdout, nl=not stdout.endswith("\n"))
73-
if stderr:
74-
click.echo(stderr, err=True, nl=not stderr.endswith("\n"))
78+
returncode = getattr(self, method_name)(*args, **env_dict)
7579

7680
# Exit with the same return code as the shell command
7781
if returncode != 0:
Lines changed: 126 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import asyncio
12
import os
3+
import signal
24
import subprocess
35
from dataclasses import dataclass, field
6+
from typing import AsyncGenerator
47

58
from jumpstarter.driver import Driver, export
69

@@ -27,41 +30,38 @@ def get_methods(self) -> list[str]:
2730
return methods
2831

2932
@export
30-
def call_method(self, method: str, env, *args):
33+
async def call_method(self, method: str, env, *args) -> AsyncGenerator[tuple[str, str, int | None], None]:
34+
"""
35+
Execute a shell method with live streaming output.
36+
Yields (stdout_chunk, stderr_chunk, returncode) tuples.
37+
returncode is None until the process completes, then it's the final return code.
38+
"""
3139
self.logger.info(f"calling {method} with args: {args} and kwargs as env: {env}")
3240
if method not in self.methods:
3341
raise ValueError(f"Method '{method}' not found in available methods: {list(self.methods.keys())}")
3442
script = self.methods[method]
3543
self.logger.debug(f"running script: {script}")
44+
3645
try:
37-
result = self._run_inline_shell_script(method, script, *args, env_vars=env)
38-
if result.returncode != 0:
39-
self.logger.info(f"{method} return code: {result.returncode}")
40-
if result.stderr != "":
41-
stderr = result.stderr.rstrip("\n")
42-
self.logger.debug(f"{method} stderr:\n{stderr}")
43-
if result.stdout != "":
44-
stdout = result.stdout.rstrip("\n")
45-
self.logger.debug(f"{method} stdout:\n{stdout}")
46-
return result.stdout, result.stderr, result.returncode
46+
async for stdout_chunk, stderr_chunk, returncode in self._run_inline_shell_script(
47+
method, script, *args, env_vars=env
48+
):
49+
if stdout_chunk:
50+
self.logger.debug(f"{method} stdout:\n{stdout_chunk.rstrip()}")
51+
if stderr_chunk:
52+
self.logger.debug(f"{method} stderr:\n{stderr_chunk.rstrip()}")
53+
54+
if returncode is not None and returncode != 0:
55+
self.logger.info(f"{method} return code: {returncode}")
56+
57+
yield stdout_chunk, stderr_chunk, returncode
4758
except subprocess.TimeoutExpired as e:
4859
self.logger.error(f"Timeout expired while running {method}: {e}")
49-
return "", f"Timeout expired while running {method}: {e}", 199
50-
51-
def _run_inline_shell_script(self, method, script, *args, env_vars=None):
52-
"""
53-
Run the given shell script (as a string) with optional arguments and
54-
environment variables. Returns a CompletedProcess with stdout, stderr, and returncode.
55-
56-
:param script: The shell script contents as a string.
57-
:param args: Arguments to pass to the script (mapped to $1, $2, etc. in the script).
58-
:param env_vars: A dict of environment variables to make available to the script.
59-
60-
:return: A subprocess.CompletedProcess object (Python 3.5+).
61-
"""
60+
yield "", f"\nTimeout expired while running {method}: {e}\n", 199
6261

62+
def _validate_script_params(self, script, args, env_vars):
63+
"""Validate script parameters and return combined environment."""
6364
# Merge parent environment with the user-supplied env_vars
64-
# so that we don't lose existing environment variables.
6565
combined_env = os.environ.copy()
6666
if env_vars:
6767
# Validate environment variable names
@@ -82,16 +82,108 @@ def _run_inline_shell_script(self, method, script, *args, env_vars=None):
8282
if self.cwd and not os.path.isdir(self.cwd):
8383
raise ValueError(f"Working directory does not exist: {self.cwd}")
8484

85+
return combined_env
86+
87+
async def _read_process_output(self, process, read_all=False):
88+
"""Read data from stdout and stderr streams.
89+
90+
:param process: The subprocess to read from
91+
:param read_all: If True, read all remaining data. If False, read with timeout.
92+
:return: Tuple of (stdout_data, stderr_data)
93+
"""
94+
stdout_data = ""
95+
stderr_data = ""
96+
97+
# Read from stdout
98+
if process.stdout:
99+
try:
100+
if read_all:
101+
chunk = await process.stdout.read()
102+
else:
103+
chunk = await asyncio.wait_for(process.stdout.read(1024), timeout=0.01)
104+
if chunk:
105+
stdout_data = chunk.decode('utf-8', errors='replace')
106+
except (asyncio.TimeoutError, Exception):
107+
pass
108+
109+
# Read from stderr
110+
if process.stderr:
111+
try:
112+
if read_all:
113+
chunk = await process.stderr.read()
114+
else:
115+
chunk = await asyncio.wait_for(process.stderr.read(1024), timeout=0.01)
116+
if chunk:
117+
stderr_data = chunk.decode('utf-8', errors='replace')
118+
except (asyncio.TimeoutError, Exception):
119+
pass
120+
121+
return stdout_data, stderr_data
122+
123+
async def _run_inline_shell_script(
124+
self, method, script, *args, env_vars=None
125+
) -> AsyncGenerator[tuple[str, str, int | None], None]:
126+
"""
127+
Run the given shell script with live streaming output.
128+
129+
:param method: The method name (for logging).
130+
:param script: The shell script contents as a string.
131+
:param args: Arguments to pass to the script (mapped to $1, $2, etc. in the script).
132+
:param env_vars: A dict of environment variables to make available to the script.
133+
134+
:yields: Tuples of (stdout_chunk, stderr_chunk, returncode).
135+
returncode is None until the process completes.
136+
"""
137+
combined_env = self._validate_script_params(script, args, env_vars)
85138
cmd = self.shell + [script, method] + list(args)
86139

87-
# Run the command
88-
result = subprocess.run(
89-
cmd,
90-
capture_output=True, # Captures stdout and stderr
91-
text=True, # Returns stdout/stderr as strings (not bytes)
92-
env=combined_env, # Pass our merged environment
93-
cwd=self.cwd, # Run in the working directory (if set)
94-
timeout=self.timeout,
140+
# Start the process with pipes for streaming and new process group
141+
process = await asyncio.create_subprocess_exec(
142+
*cmd,
143+
stdout=asyncio.subprocess.PIPE,
144+
stderr=asyncio.subprocess.PIPE,
145+
env=combined_env,
146+
cwd=self.cwd,
147+
start_new_session=True, # Create new process group
95148
)
96149

97-
return result
150+
# Create a task to monitor the process timeout
151+
start_time = asyncio.get_event_loop().time()
152+
153+
# Read output in real-time
154+
while process.returncode is None:
155+
self.logger.debug(f"running {method} with cmd: {cmd} and env: {combined_env} and args: {args}")
156+
if asyncio.get_event_loop().time() - start_time > self.timeout:
157+
# Send SIGTERM to entire process group for graceful termination
158+
try:
159+
os.killpg(process.pid, signal.SIGTERM)
160+
except (ProcessLookupError, OSError):
161+
# Process group might already be gone
162+
pass
163+
try:
164+
await asyncio.wait_for(process.wait(), timeout=5.0)
165+
except asyncio.TimeoutError:
166+
try:
167+
os.killpg(process.pid, signal.SIGKILL)
168+
self.logger.warning(f"SIGTERM failed to terminate {process.pid}, sending SIGKILL")
169+
except (ProcessLookupError, OSError):
170+
pass
171+
raise subprocess.TimeoutExpired(cmd, self.timeout) from None
172+
173+
try:
174+
stdout_data, stderr_data = await self._read_process_output(process, read_all=False)
175+
176+
# Yield any data we got
177+
if stdout_data or stderr_data:
178+
yield stdout_data, stderr_data, None
179+
180+
# Small delay to prevent busy waiting
181+
await asyncio.sleep(0.1)
182+
183+
except Exception:
184+
break
185+
186+
# Process completed, get return code and final output
187+
returncode = process.returncode
188+
remaining_stdout, remaining_stderr = await self._read_process_output(process, read_all=True)
189+
yield remaining_stdout, remaining_stderr, returncode

packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver_test.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,28 @@
1+
12
import pytest
23

34
from .driver import Shell
45
from jumpstarter.common.utils import serve
56

67

8+
def _collect_streaming_output(client, method_name, env_vars=None, *args):
9+
"""Helper function to collect streaming output for testing"""
10+
stdout_parts = []
11+
stderr_parts = []
12+
final_returncode = None
13+
14+
env_vars = env_vars or {}
15+
for stdout_chunk, stderr_chunk, returncode in client.streamingcall("call_method", method_name, env_vars, *args):
16+
if stdout_chunk:
17+
stdout_parts.append(stdout_chunk)
18+
if stderr_chunk:
19+
stderr_parts.append(stderr_chunk)
20+
if returncode is not None:
21+
final_returncode = returncode
22+
23+
return "".join(stdout_parts), "".join(stderr_parts), final_returncode
24+
25+
726
@pytest.fixture
827
def client():
928
instance = Shell(
@@ -21,23 +40,38 @@ def client():
2140

2241

2342
def test_normal_args(client):
24-
assert client.echo("hello") == ("hello\n", "", 0)
43+
stdout, stderr, returncode = _collect_streaming_output(client, "echo", {}, "hello")
44+
assert stdout == "hello\n"
45+
assert stderr == ""
46+
assert returncode == 0
2547

2648

2749
def test_env_vars(client):
28-
assert client.env(ENV1="world") == ("world\n", "", 0)
50+
stdout, stderr, returncode = _collect_streaming_output(client, "env", {"ENV1": "world"})
51+
assert stdout == "world\n"
52+
assert stderr == ""
53+
assert returncode == 0
2954

3055

3156
def test_multi_line_scripts(client):
32-
assert client.multi_line("a", "b", "c") == ("a\nb\nc\n", "", 0)
57+
stdout, stderr, returncode = _collect_streaming_output(client, "multi_line", {}, "a", "b", "c")
58+
assert stdout == "a\nb\nc\n"
59+
assert stderr == ""
60+
assert returncode == 0
3361

3462

3563
def test_return_codes(client):
36-
assert client.exit1() == ("", "", 1)
64+
stdout, stderr, returncode = _collect_streaming_output(client, "exit1")
65+
assert stdout == ""
66+
assert stderr == ""
67+
assert returncode == 1
3768

3869

3970
def test_stderr(client):
40-
assert client.stderr("error") == ("", "error\n", 0)
71+
stdout, stderr, returncode = _collect_streaming_output(client, "stderr", {}, "error")
72+
assert stdout == ""
73+
assert stderr == "error\n"
74+
assert returncode == 0
4175

4276

4377
def test_unknown_method(client):

packages/jumpstarter-driver-shell/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ readme = "README.md"
66
authors = [{ name = "Miguel Angel Ajo", email = "miguelangel@ajo.es" }]
77
requires-python = ">=3.11"
88
license = "Apache-2.0"
9-
dependencies = ["anyio>=4.6.2.post1", "jumpstarter", "click>=8.1.7.2"]
9+
dependencies = ["anyio>=4.6.2.post1", "jumpstarter", "click>=8.1.8"]
1010

1111
[project.entry-points."jumpstarter.drivers"]
1212
Shell = "jumpstarter_driver_shell.driver:Shell"

0 commit comments

Comments
 (0)