Skip to content

Commit b919ced

Browse files
live stream of command output in Shell driver
streams stdout and stderr for long-running methods
1 parent cfd5b99 commit b919ced

File tree

3 files changed

+165
-50
lines changed

3 files changed

+165
-50
lines changed

packages/jumpstarter-driver-shell/jumpstarter_driver_shell/client.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from dataclasses import dataclass
23

34
import click
@@ -13,8 +14,8 @@ class ShellClient(DriverClient):
1314
Client interface for Shell driver.
1415
1516
This client dynamically checks that the method is configured
16-
on the driver, and if it is, it will call it and get the results
17-
in the form of (stdout, stderr, returncode).
17+
on the driver, and if it is, it will call it with live streaming output.
18+
Output chunks are displayed as they arrive.
1819
"""
1920

2021
def _check_method_exists(self, method):
@@ -26,7 +27,17 @@ def _check_method_exists(self, method):
2627
## capture any method calls dynamically
2728
def __getattr__(self, name):
2829
self._check_method_exists(name)
29-
return lambda *args, **kwargs: tuple(self.call("call_method", name, kwargs, *args))
30+
def execute(*args, **kwargs):
31+
returncode = 0
32+
for stdout, stderr, code in self.streamingcall("call_method", name, kwargs, *args):
33+
if stdout:
34+
print(stdout, end='', flush=True)
35+
if stderr:
36+
print(stderr, end='', file=sys.stderr, flush=True)
37+
if code is not None:
38+
returncode = code
39+
return returncode
40+
return execute
3041

3142
def cli(self):
3243
"""Create CLI interface for dynamically configured shell methods"""
@@ -61,14 +72,7 @@ def method_command(args, env):
6172
else:
6273
raise click.BadParameter(f"Invalid --env value '{env_var}'. Use KEY=VALUE.")
6374

64-
# Call the method
65-
stdout, stderr, returncode = self.call("call_method", method_name, env_dict, *args)
66-
67-
# Display results
68-
if stdout:
69-
click.echo(stdout, nl=not stdout.endswith("\n"))
70-
if stderr:
71-
click.echo(stderr, err=True, nl=not stderr.endswith("\n"))
75+
returncode = getattr(self, method_name)(*args, **env_dict)
7276

7377
# Exit with the same return code as the shell command
7478
if returncode != 0:
Lines changed: 111 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
12
import os
23
import subprocess
34
from dataclasses import dataclass, field
5+
from typing import AsyncGenerator
46

57
from jumpstarter.driver import Driver, export
68

@@ -27,41 +29,38 @@ def get_methods(self) -> list[str]:
2729
return methods
2830

2931
@export
30-
def call_method(self, method: str, env, *args):
32+
async def call_method(self, method: str, env, *args) -> AsyncGenerator[tuple[str, str, int | None], None]:
33+
"""
34+
Execute a shell method with live streaming output.
35+
Yields (stdout_chunk, stderr_chunk, returncode) tuples.
36+
returncode is None until the process completes, then it's the final return code.
37+
"""
3138
self.logger.info(f"calling {method} with args: {args} and kwargs as env: {env}")
3239
if method not in self.methods:
3340
raise ValueError(f"Method '{method}' not found in available methods: {list(self.methods.keys())}")
3441
script = self.methods[method]
3542
self.logger.debug(f"running script: {script}")
43+
3644
try:
37-
result = self._run_inline_shell_script(method, script, *args, env_vars=env)
38-
if result.returncode != 0:
39-
self.logger.info(f"{method} return code: {result.returncode}")
40-
if result.stderr != "":
41-
stderr = result.stderr.rstrip("\n")
42-
self.logger.debug(f"{method} stderr:\n{stderr}")
43-
if result.stdout != "":
44-
stdout = result.stdout.rstrip("\n")
45-
self.logger.debug(f"{method} stdout:\n{stdout}")
46-
return result.stdout, result.stderr, result.returncode
45+
async for stdout_chunk, stderr_chunk, returncode in self._run_inline_shell_script(
46+
method, script, *args, env_vars=env
47+
):
48+
if stdout_chunk:
49+
self.logger.debug(f"{method} stdout:\n{stdout_chunk.rstrip()}")
50+
if stderr_chunk:
51+
self.logger.debug(f"{method} stderr:\n{stderr_chunk.rstrip()}")
52+
53+
if returncode is not None and returncode != 0:
54+
self.logger.info(f"{method} return code: {returncode}")
55+
56+
yield stdout_chunk, stderr_chunk, returncode
4757
except subprocess.TimeoutExpired as e:
4858
self.logger.error(f"Timeout expired while running {method}: {e}")
49-
return "", f"Timeout expired while running {method}: {e}", 199
50-
51-
def _run_inline_shell_script(self, method, script, *args, env_vars=None):
52-
"""
53-
Run the given shell script (as a string) with optional arguments and
54-
environment variables. Returns a CompletedProcess with stdout, stderr, and returncode.
55-
56-
:param script: The shell script contents as a string.
57-
:param args: Arguments to pass to the script (mapped to $1, $2, etc. in the script).
58-
:param env_vars: A dict of environment variables to make available to the script.
59-
60-
:return: A subprocess.CompletedProcess object (Python 3.5+).
61-
"""
59+
yield "", f"\nTimeout expired while running {method}: {e}\n", 199
6260

61+
def _validate_script_params(self, script, args, env_vars):
62+
"""Validate script parameters and return combined environment."""
6363
# Merge parent environment with the user-supplied env_vars
64-
# so that we don't lose existing environment variables.
6564
combined_env = os.environ.copy()
6665
if env_vars:
6766
# Validate environment variable names
@@ -82,16 +81,94 @@ def _run_inline_shell_script(self, method, script, *args, env_vars=None):
8281
if self.cwd and not os.path.isdir(self.cwd):
8382
raise ValueError(f"Working directory does not exist: {self.cwd}")
8483

84+
return combined_env
85+
86+
async def _read_process_output(self, process, read_all=False):
87+
"""Read data from stdout and stderr streams.
88+
89+
:param process: The subprocess to read from
90+
:param read_all: If True, read all remaining data. If False, read with timeout.
91+
:return: Tuple of (stdout_data, stderr_data)
92+
"""
93+
stdout_data = ""
94+
stderr_data = ""
95+
96+
# Read from stdout
97+
if process.stdout:
98+
try:
99+
if read_all:
100+
chunk = await process.stdout.read()
101+
else:
102+
chunk = await asyncio.wait_for(process.stdout.read(1024), timeout=0.01)
103+
if chunk:
104+
stdout_data = chunk.decode('utf-8', errors='replace')
105+
except (asyncio.TimeoutError, Exception):
106+
pass
107+
108+
# Read from stderr
109+
if process.stderr:
110+
try:
111+
if read_all:
112+
chunk = await process.stderr.read()
113+
else:
114+
chunk = await asyncio.wait_for(process.stderr.read(1024), timeout=0.01)
115+
if chunk:
116+
stderr_data = chunk.decode('utf-8', errors='replace')
117+
except (asyncio.TimeoutError, Exception):
118+
pass
119+
120+
return stdout_data, stderr_data
121+
122+
async def _run_inline_shell_script(
123+
self, method, script, *args, env_vars=None
124+
) -> AsyncGenerator[tuple[str, str, int | None], None]:
125+
"""
126+
Run the given shell script with live streaming output.
127+
128+
:param method: The method name (for logging).
129+
:param script: The shell script contents as a string.
130+
:param args: Arguments to pass to the script (mapped to $1, $2, etc. in the script).
131+
:param env_vars: A dict of environment variables to make available to the script.
132+
133+
:yields: Tuples of (stdout_chunk, stderr_chunk, returncode).
134+
returncode is None until the process completes.
135+
"""
136+
combined_env = self._validate_script_params(script, args, env_vars)
85137
cmd = self.shell + [script, method] + list(args)
86138

87-
# Run the command
88-
result = subprocess.run(
89-
cmd,
90-
capture_output=True, # Captures stdout and stderr
91-
text=True, # Returns stdout/stderr as strings (not bytes)
92-
env=combined_env, # Pass our merged environment
93-
cwd=self.cwd, # Run in the working directory (if set)
94-
timeout=self.timeout,
139+
# Start the process with pipes for streaming
140+
process = await asyncio.create_subprocess_exec(
141+
*cmd,
142+
stdout=asyncio.subprocess.PIPE,
143+
stderr=asyncio.subprocess.PIPE,
144+
env=combined_env,
145+
cwd=self.cwd,
95146
)
96147

97-
return result
148+
# Create a task to monitor the process timeout
149+
start_time = asyncio.get_event_loop().time()
150+
151+
# Read output in real-time
152+
while process.returncode is None:
153+
if asyncio.get_event_loop().time() - start_time > self.timeout:
154+
process.kill()
155+
await process.wait()
156+
raise subprocess.TimeoutExpired(cmd, self.timeout) from None
157+
158+
try:
159+
stdout_data, stderr_data = await self._read_process_output(process, read_all=False)
160+
161+
# Yield any data we got
162+
if stdout_data or stderr_data:
163+
yield stdout_data, stderr_data, None
164+
165+
# Small delay to prevent busy waiting
166+
await asyncio.sleep(0.1)
167+
168+
except Exception:
169+
break
170+
171+
# Process completed, get return code and final output
172+
returncode = process.returncode
173+
remaining_stdout, remaining_stderr = await self._read_process_output(process, read_all=True)
174+
yield remaining_stdout, remaining_stderr, returncode

packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver_test.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,28 @@
1+
12
import pytest
23

34
from .driver import Shell
45
from jumpstarter.common.utils import serve
56

67

8+
def _collect_streaming_output(client, method_name, env_vars=None, *args):
9+
"""Helper function to collect streaming output for testing"""
10+
stdout_parts = []
11+
stderr_parts = []
12+
final_returncode = None
13+
14+
env_vars = env_vars or {}
15+
for stdout_chunk, stderr_chunk, returncode in client.streamingcall("call_method", method_name, env_vars, *args):
16+
if stdout_chunk:
17+
stdout_parts.append(stdout_chunk)
18+
if stderr_chunk:
19+
stderr_parts.append(stderr_chunk)
20+
if returncode is not None:
21+
final_returncode = returncode
22+
23+
return "".join(stdout_parts), "".join(stderr_parts), final_returncode
24+
25+
726
@pytest.fixture
827
def client():
928
instance = Shell(
@@ -21,23 +40,38 @@ def client():
2140

2241

2342
def test_normal_args(client):
24-
assert client.echo("hello") == ("hello\n", "", 0)
43+
stdout, stderr, returncode = _collect_streaming_output(client, "echo", {}, "hello")
44+
assert stdout == "hello\n"
45+
assert stderr == ""
46+
assert returncode == 0
2547

2648

2749
def test_env_vars(client):
28-
assert client.env(ENV1="world") == ("world\n", "", 0)
50+
stdout, stderr, returncode = _collect_streaming_output(client, "env", {"ENV1": "world"})
51+
assert stdout == "world\n"
52+
assert stderr == ""
53+
assert returncode == 0
2954

3055

3156
def test_multi_line_scripts(client):
32-
assert client.multi_line("a", "b", "c") == ("a\nb\nc\n", "", 0)
57+
stdout, stderr, returncode = _collect_streaming_output(client, "multi_line", {}, "a", "b", "c")
58+
assert stdout == "a\nb\nc\n"
59+
assert stderr == ""
60+
assert returncode == 0
3361

3462

3563
def test_return_codes(client):
36-
assert client.exit1() == ("", "", 1)
64+
stdout, stderr, returncode = _collect_streaming_output(client, "exit1")
65+
assert stdout == ""
66+
assert stderr == ""
67+
assert returncode == 1
3768

3869

3970
def test_stderr(client):
40-
assert client.stderr("error") == ("", "error\n", 0)
71+
stdout, stderr, returncode = _collect_streaming_output(client, "stderr", {}, "error")
72+
assert stdout == ""
73+
assert stderr == "error\n"
74+
assert returncode == 0
4175

4276

4377
def test_unknown_method(client):

0 commit comments

Comments
 (0)