diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 0d3c372ce..e8d49382b 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,5 +1,9 @@ from .client.session import ClientSession -from .client.stdio import StdioServerParameters, stdio_client +from .client.stdio import ( + ProcessTerminatedEarlyError, + StdioServerParameters, + stdio_client, +) from .server.session import ServerSession from .server.stdio import stdio_server from .shared.exceptions import McpError @@ -101,6 +105,7 @@ "ServerResult", "ServerSession", "SetLevelRequest", + "ProcessTerminatedEarlyError", "StdioServerParameters", "StopReason", "SubscribeRequest", diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 83de57a2b..f518624a7 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -18,6 +18,13 @@ terminate_windows_process, ) +__all__ = [ + "ProcessTerminatedEarlyError", + "StdioServerParameters", + "stdio_client", + "get_default_environment", +] + # Environment variables to inherit by default DEFAULT_INHERITED_ENV_VARS = ( [ @@ -38,6 +45,13 @@ ) +class ProcessTerminatedEarlyError(Exception): + """Raised when a process terminates unexpectedly.""" + + def __init__(self, message: str): + super().__init__(message) + + def get_default_environment() -> dict[str, str]: """ Returns a default environment object including only environment variables deemed @@ -163,20 +177,60 @@ async def stdin_writer(): except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() + process_error: str | None = None + async with ( anyio.create_task_group() as tg, process, ): tg.start_soon(stdout_reader) tg.start_soon(stdin_writer) + + # Add a task to monitor the process and detect early termination + async def monitor_process(): + nonlocal process_error + try: + await process.wait() + # Only consider it an error if the process exits with a non-zero code + # during normal operation (not when we explicitly terminate it) + if process.returncode != 0 and not tg.cancel_scope.cancel_called: + process_error = f"Process exited with code {process.returncode}." + # Cancel the task group to stop other tasks + tg.cancel_scope.cancel() + except anyio.get_cancelled_exc_class(): + # Task was cancelled, which is expected when we're done + pass + + tg.start_soon(monitor_process) + try: yield read_stream, write_stream finally: + # Set a flag to indicate we're explicitly terminating the process + # This prevents the monitor_process from treating our termination + # as an error when we explicitly terminate it + tg.cancel_scope.cancel() + + # Close all streams to prevent resource leaks + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + # Clean up process to prevent any dangling orphaned processes - if sys.platform == "win32": - await terminate_windows_process(process) - else: - process.terminate() + try: + if sys.platform == "win32": + await terminate_windows_process(process) + else: + process.terminate() + except ProcessLookupError: + # Process has already exited, which is fine + pass + + if process_error: + # Raise outside the task group so that the error is not wrapped in an + # ExceptionGroup + raise ProcessTerminatedEarlyError(process_error) def _get_executable_command(command: str) -> str: diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 95747ffd1..55e11cbca 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,11 +1,18 @@ import shutil import pytest +from anyio import fail_after -from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.client.session import ClientSession +from mcp.client.stdio import ( + ProcessTerminatedEarlyError, + StdioServerParameters, + stdio_client, +) from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore +python: str = shutil.which("python") # type: ignore @pytest.mark.anyio @@ -41,3 +48,32 @@ async def test_stdio_client(): assert read_messages[1] == JSONRPCMessage( root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) ) + + +@pytest.mark.anyio +@pytest.mark.skipif(python is None, reason="could not find python command") +async def test_initialize_with_exiting_server(): + """ + Test that ClientSession.initialize raises an error if the server process exits. + """ + # Create a server that will exit during initialization + server_params = StdioServerParameters( + command="python", + args=[ + "-c", + "import sys; print('Error: Missing API key', file=sys.stderr); sys.exit(1)", + ], + ) + + with pytest.raises(ProcessTerminatedEarlyError): + try: + # Set a timeout to avoid hanging indefinitely if the test fails + with fail_after(5): + async with stdio_client(server_params) as (read_stream, write_stream): + # Create a client session + session = ClientSession(read_stream, write_stream) + + # This should fail because the server process has exited + await session.initialize() + except TimeoutError: + pytest.fail("The connection hung and timed out.")