Skip to content

Commit d72a0f7

Browse files
committed
Add startup wait time to stdio client
1 parent 689c54c commit d72a0f7

File tree

4 files changed

+67
-21
lines changed

4 files changed

+67
-21
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"sse-starlette>=1.6.1",
3131
"pydantic-settings>=2.5.2",
3232
"uvicorn>=0.23.1",
33+
"exceptiongroup>=1.2.2",
3334
]
3435

3536
[project.optional-dependencies]

src/mcp/client/stdio.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class StdioServerParameters(BaseModel):
8383

8484

8585
@asynccontextmanager
86-
async def stdio_client(server: StdioServerParameters):
86+
async def stdio_client(server: StdioServerParameters, startup_wait_time: float = 0.0):
8787
"""
8888
Client transport for stdio: this will connect to a server by spawning a
8989
process and communicating with it over stdin/stdout.
@@ -97,11 +97,17 @@ async def stdio_client(server: StdioServerParameters):
9797
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
9898
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
9999

100-
process = await anyio.open_process(
101-
[server.command, *server.args],
102-
env=server.env if server.env is not None else get_default_environment(),
103-
stderr=sys.stderr,
104-
)
100+
try:
101+
process = await anyio.open_process(
102+
[server.command, *server.args],
103+
env=server.env or get_default_environment(),
104+
stderr=sys.stderr,
105+
)
106+
except OSError as exc:
107+
raise RuntimeError(
108+
f"Failed to spawn process: {server.command} {server.args}. "
109+
f"Check that the binary exists and is executable."
110+
) from exc
105111

106112
async def stdout_reader():
107113
assert process.stdout, "Opened process is missing stdout"
@@ -144,10 +150,21 @@ async def stdin_writer():
144150
except anyio.ClosedResourceError:
145151
await anyio.lowlevel.checkpoint()
146152

147-
async with (
148-
anyio.create_task_group() as tg,
149-
process,
150-
):
153+
async def watch_process_exit():
154+
returncode = await process.wait()
155+
if returncode != 0:
156+
raise RuntimeError(
157+
f"Subprocess exited with code {returncode}. "
158+
f"Command: {server.command}, {server.args}"
159+
)
160+
161+
async with anyio.create_task_group() as tg, process:
151162
tg.start_soon(stdout_reader)
152163
tg.start_soon(stdin_writer)
164+
tg.start_soon(watch_process_exit)
165+
166+
if startup_wait_time > 0:
167+
with anyio.move_on_after(startup_wait_time):
168+
await anyio.sleep_forever()
169+
153170
yield read_stream, write_stream

tests/client/test_stdio.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1+
import re
12
import shutil
3+
import sys
4+
5+
if sys.version_info < (3, 11):
6+
from exceptiongroup import ExceptionGroup
7+
else:
8+
ExceptionGroup = ExceptionGroup
29

310
import pytest
411

@@ -14,7 +21,6 @@ async def test_stdio_client():
1421
server_parameters = StdioServerParameters(command=tee)
1522

1623
async with stdio_client(server_parameters) as (read_stream, write_stream):
17-
# Test sending and receiving messages
1824
messages = [
1925
JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")),
2026
JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})),
@@ -31,13 +37,31 @@ async def test_stdio_client():
3137
raise message
3238

3339
read_messages.append(message)
34-
if len(read_messages) == 2:
40+
if len(read_messages) == len(messages):
3541
break
3642

37-
assert len(read_messages) == 2
38-
assert read_messages[0] == JSONRPCMessage(
39-
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
40-
)
41-
assert read_messages[1] == JSONRPCMessage(
42-
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
43-
)
43+
assert read_messages == messages
44+
45+
46+
@pytest.mark.anyio
47+
async def test_stdio_client_spawn_failure():
48+
server_parameters = StdioServerParameters(command="/does/not/exist")
49+
50+
with pytest.raises(RuntimeError, match="Failed to spawn process"):
51+
async with stdio_client(server_parameters):
52+
pytest.fail("Should never be reached.")
53+
54+
55+
@pytest.mark.anyio
56+
async def test_stdio_client_nonzero_exit():
57+
server_parameters = StdioServerParameters(
58+
command="python", args=["-c", "import sys; sys.exit(2)"]
59+
)
60+
61+
with pytest.raises(ExceptionGroup) as eg_info:
62+
async with stdio_client(server_parameters, startup_wait_time=0.2):
63+
pytest.fail("Should never be reached.")
64+
65+
exc = eg_info.value.exceptions[0]
66+
assert isinstance(exc, RuntimeError)
67+
assert re.search(r"exited with code 2", str(exc))

uv.lock

Lines changed: 6 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)