Skip to content

Commit 728ec85

Browse files
committed
Fix issue #428: ClientSession.initialize gets stuck if the MCP server process exits
1 parent c2ca8e0 commit 728ec85

File tree

3 files changed

+101
-6
lines changed

3 files changed

+101
-6
lines changed

src/mcp/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from .client.session import ClientSession
2-
from .client.stdio import StdioServerParameters, stdio_client
2+
from .client.stdio import (
3+
ProcessTerminatedEarlyError,
4+
StdioServerParameters,
5+
stdio_client,
6+
)
37
from .server.session import ServerSession
48
from .server.stdio import stdio_server
59
from .shared.exceptions import McpError
@@ -101,6 +105,7 @@
101105
"ServerResult",
102106
"ServerSession",
103107
"SetLevelRequest",
108+
"ProcessTerminatedEarlyError",
104109
"StdioServerParameters",
105110
"StopReason",
106111
"SubscribeRequest",

src/mcp/client/stdio/__init__.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
terminate_windows_process,
1919
)
2020

21+
__all__ = [
22+
"ProcessTerminatedEarlyError",
23+
"StdioServerParameters",
24+
"stdio_client",
25+
"get_default_environment",
26+
]
27+
2128
# Environment variables to inherit by default
2229
DEFAULT_INHERITED_ENV_VARS = (
2330
[
@@ -38,6 +45,13 @@
3845
)
3946

4047

48+
class ProcessTerminatedEarlyError(Exception):
49+
"""Raised when a process terminates unexpectedly."""
50+
51+
def __init__(self, message: str):
52+
super().__init__(message)
53+
54+
4155
def get_default_environment() -> dict[str, str]:
4256
"""
4357
Returns a default environment object including only environment variables deemed
@@ -163,20 +177,60 @@ async def stdin_writer():
163177
except anyio.ClosedResourceError:
164178
await anyio.lowlevel.checkpoint()
165179

180+
process_error: str | None = None
181+
166182
async with (
167183
anyio.create_task_group() as tg,
168184
process,
169185
):
170186
tg.start_soon(stdout_reader)
171187
tg.start_soon(stdin_writer)
188+
189+
# Add a task to monitor the process and detect early termination
190+
async def monitor_process():
191+
nonlocal process_error
192+
try:
193+
await process.wait()
194+
# Only consider it an error if the process exits with a non-zero code
195+
# during normal operation (not when we explicitly terminate it)
196+
if process.returncode != 0 and not tg.cancel_scope.cancel_called:
197+
process_error = f"Process exited with code {process.returncode}."
198+
# Cancel the task group to stop other tasks
199+
tg.cancel_scope.cancel()
200+
except anyio.get_cancelled_exc_class():
201+
# Task was cancelled, which is expected when we're done
202+
pass
203+
204+
tg.start_soon(monitor_process)
205+
172206
try:
173207
yield read_stream, write_stream
174208
finally:
209+
# Set a flag to indicate we're explicitly terminating the process
210+
# This prevents the monitor_process from treating our termination
211+
# as an error when we explicitly terminate it
212+
tg.cancel_scope.cancel()
213+
214+
# Close all streams to prevent resource leaks
215+
await read_stream.aclose()
216+
await write_stream.aclose()
217+
await read_stream_writer.aclose()
218+
await write_stream_reader.aclose()
219+
175220
# Clean up process to prevent any dangling orphaned processes
176-
if sys.platform == "win32":
177-
await terminate_windows_process(process)
178-
else:
179-
process.terminate()
221+
try:
222+
if sys.platform == "win32":
223+
await terminate_windows_process(process)
224+
else:
225+
process.terminate()
226+
except ProcessLookupError:
227+
# Process has already exited, which is fine
228+
pass
229+
230+
if process_error:
231+
# Raise outside the task group so that the error is not wrapped in an
232+
# ExceptionGroup
233+
raise ProcessTerminatedEarlyError(process_error)
180234

181235

182236
def _get_executable_command(command: str) -> str:

tests/client/test_stdio.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
import shutil
22

33
import pytest
4+
from anyio import fail_after
45

5-
from mcp.client.stdio import StdioServerParameters, stdio_client
6+
from mcp.client.session import ClientSession
7+
from mcp.client.stdio import (
8+
ProcessTerminatedEarlyError,
9+
StdioServerParameters,
10+
stdio_client,
11+
)
612
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
713

814
tee: str = shutil.which("tee") # type: ignore
15+
python: str = shutil.which("python") # type: ignore
916

1017

1118
@pytest.mark.anyio
@@ -41,3 +48,32 @@ async def test_stdio_client():
4148
assert read_messages[1] == JSONRPCMessage(
4249
root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})
4350
)
51+
52+
53+
@pytest.mark.anyio
54+
@pytest.mark.skipif(python is None, reason="could not find python command")
55+
async def test_initialize_with_exiting_server():
56+
"""
57+
Test that ClientSession.initialize raises an error if the server process exits.
58+
"""
59+
# Create a server that will exit during initialization
60+
server_params = StdioServerParameters(
61+
command="python",
62+
args=[
63+
"-c",
64+
"import sys; print('Error: Missing API key', file=sys.stderr); sys.exit(1)",
65+
],
66+
)
67+
68+
with pytest.raises(ProcessTerminatedEarlyError):
69+
try:
70+
# Set a timeout to avoid hanging indefinitely if the test fails
71+
with fail_after(5):
72+
async with stdio_client(server_params) as (read_stream, write_stream):
73+
# Create a client session
74+
session = ClientSession(read_stream, write_stream)
75+
76+
# This should fail because the server process has exited
77+
await session.initialize()
78+
except TimeoutError:
79+
pytest.fail("The connection hung and timed out.")

0 commit comments

Comments
 (0)