diff --git a/README.md b/README.md index 412143a9f..a60850a11 100644 --- a/README.md +++ b/README.md @@ -635,6 +635,66 @@ mcp run server.py Note that `mcp run` or `mcp dev` only supports server using FastMCP and not the low-level server variant. +### Socket Transport + +Socket transport provides a simple and efficient communication channel between client and server, similar to stdio but without stdout pollution concerns. Unlike stdio transport which requires clean stdout for message passing, socket transport allows the server to freely use stdout for logging and other purposes. + +The workflow is: +1. Client creates a TCP server and gets an available port +2. Client starts the server process, passing the port number +3. Server connects back to the client's TCP server +4. Client and server exchange messages over the TCP connection +5. When done, client closes the connection and terminates the server process + +This design maintains the simplicity of stdio transport while providing more flexibility for server output handling. + +Example server setup: +```python +from mcp.server.fastmcp import FastMCP + +# Create server with socket transport configuration +mcp = FastMCP( + "SocketServer", + socket_host="127.0.0.1", # Optional, defaults to 127.0.0.1 + socket_port=3000, # Required when using socket transport +) + +# Run with socket transport +mcp.run(transport="socket") +``` + +Client usage: +```python +from mcp.client.session import ClientSession +from mcp.client.socket_transport import SocketServerParameters, socket_client + +# Create server parameters +params = SocketServerParameters( + command="python", # Server process to run + args=["server.py"], # Server script and arguments + # Port 0 means auto-assign an available port + port=0, # Optional, defaults to 0 (auto-assign) + host="127.0.0.1", # Optional, defaults to 127.0.0.1 +) + +# Connect to server (this will start the server process) +async with socket_client(params) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + # Use the session... + await session.initialize() + result = await session.call_tool("echo", {"text": "Hello!"}) +``` + +The socket transport provides: +- Freedom to use stdout without affecting message transport +- Standard TCP socket-based communication +- Automatic port assignment for easy setup +- Connection retry logic for reliability +- Clean process lifecycle management +- Robust error handling + +For a complete example, see [`examples/fastmcp/socket_example.py`](examples/fastmcp/socket_example.py). + ### Streamable HTTP Transport > **Note**: Streamable HTTP transport is superseding SSE transport for production deployments. diff --git a/examples/fastmcp/socket_examples/README.md b/examples/fastmcp/socket_examples/README.md new file mode 100644 index 000000000..f7713e00c --- /dev/null +++ b/examples/fastmcp/socket_examples/README.md @@ -0,0 +1,79 @@ +# Socket Transport Examples + +This directory contains examples demonstrating the socket transport feature of FastMCP. Socket transport provides a simple and efficient communication channel between client and server, similar to stdio but without stdout pollution concerns. + +## Overview + +The socket transport works by: +1. Client creates a TCP server and gets an available port +2. Client starts the server process, passing the port number +3. Server connects back to the client's TCP server +4. Client and server exchange messages over the TCP connection +5. When done, client closes the connection and terminates the server process + +## Files + +- `client.py` - Example client that: + - Creates a TCP server + - Starts the server process + - Establishes MCP session + - Calls example tools + +- `server.py` - Example server that: + - Connects to client's TCP server + - Sets up FastMCP environment + - Provides example tools + - Demonstrates logging usage + +## Usage + +1. Run with auto-assigned port (recommended): +```bash +python client.py +``` + +2. Run with specific host and port: +```bash +python client.py --host localhost --port 3000 +``` + +3. Run server directly (for testing): +```bash +python server.py --name "Echo Server" --host localhost --port 3000 --log-level DEBUG +``` + +## Configuration + +### Client Options +- `--host` - Host to bind to (default: 127.0.0.1) +- `--port` - Port to use (default: 0 for auto-assign) + +### Server Options +- `--name` - Server name +- `--host` - Host to connect to +- `--port` - Port to connect to (required) +- `--log-level` - Logging level (DEBUG/INFO/WARNING/ERROR) + +## Implementation Details + +### Client Features +- Automatic port assignment +- Server process management +- Connection retry logic +- Error handling +- Clean shutdown + +### Server Features +- Connection retry logic +- Custom text encoding support +- Stdout/logging freedom +- Error handling +- Clean shutdown + +### Error Handling +The examples demonstrate handling of: +- Connection failures and retries +- Invalid JSON messages +- Text encoding errors +- Tool execution errors +- Process lifecycle management diff --git a/examples/fastmcp/socket_examples/client.py b/examples/fastmcp/socket_examples/client.py new file mode 100644 index 000000000..401a4c4d5 --- /dev/null +++ b/examples/fastmcp/socket_examples/client.py @@ -0,0 +1,155 @@ +""" +Example of using socket transport with FastMCP. + +This example demonstrates: +1. Creating a FastMCP server that uses socket transport +2. Creating a client that connects to the server using socket transport +3. Exchanging messages between client and server +4. Handling connection errors and retries +5. Using custom encoding and configuration +6. Verifying server process cleanup + +Usage: + python client.py [--host HOST] [--port PORT] [--log-level LEVEL] +""" + +import argparse +import asyncio +import logging +import sys +import psutil +from pathlib import Path + +from mcp.client.session import ClientSession +from mcp.client.socket_transport import SocketServerParameters, socket_client +from mcp.shared.exceptions import McpError + +# Set up logging +logger = logging.getLogger(__name__) + + +async def verify_process_cleanup(pid: int) -> bool: + """ + Verify if a process with given PID exists. + + Args: + pid: Process ID to check + + Returns: + bool: True if process does not exist (cleaned up), False if still running + """ + try: + process = psutil.Process(pid) + return False # Process still exists + except psutil.NoSuchProcess: + return True # Process has been cleaned up + + +async def main(host: str = "127.0.0.1", port: int = 0, log_level: str = "INFO"): + """ + Run the client which will start and connect to the server. + + Args: + host: The host to use for socket communication (default: 127.0.0.1) + port: The port to use for socket communication (default: 0 for auto-assign) + log_level: Logging level (default: INFO) + """ + # Configure logging + logging.basicConfig( + level=log_level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + server_pid = None + try: + # Create server parameters with custom configuration + params = SocketServerParameters( + # The command to run the server + command=sys.executable, # Use current Python interpreter + # Arguments to start the server script + args=[ + str(Path(__file__).parent / "server.py"), # Updated path + "--name", + "Echo Server", + "--host", + host, + "--port", + str(port), + "--log-level", + log_level, + ], + # Socket configuration + host=host, + port=port, + # Optional: customize encoding (defaults shown) + encoding="utf-8", + encoding_error_handler="strict", + ) + + # Connect to server (this will start the server process) + async with socket_client(params) as (read_stream, write_stream): + # Create client session + async with ClientSession(read_stream, write_stream) as session: + try: + # Initialize the session + await session.initialize() + logger.info("Session initialized successfully") + + # Get server process PID for verification + result = await session.call_tool("get_pid_tool", {}) + server_pid = result.structuredContent["result"]["pid"] + logger.info(f"Server process PID: {server_pid}") + + # List available tools + tools = await session.list_tools() + logger.info(f"Available tools: {[t.name for t in tools.tools]}") + + # Call the echo tool with different inputs + messages = [ + "Hello from socket transport!", + "Testing special chars: 世界, мир, ♥", + "Testing long message: " + "x" * 1000, + ] + + for msg in messages: + try: + result = await session.call_tool("echo_tool", {"text": msg}) + logger.info(f"Echo result: {result}") + except McpError as e: + logger.error(f"Tool call failed: {e}") + + except McpError as e: + logger.error(f"Session error: {e}") + sys.exit(1) + + # After session ends, verify server process cleanup + if server_pid: + await asyncio.sleep(0.5) # Give some time for cleanup + is_cleaned = await verify_process_cleanup(server_pid) + if is_cleaned: + logger.info( + f"Server process (PID: {server_pid}) was successfully cleaned up" + ) + else: + logger.warning(f"Server process (PID: {server_pid}) is still running!") + + except Exception as e: + logger.error(f"Connection failed: {e}", exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Socket transport example client") + parser.add_argument("--host", default="127.0.0.1", help="Host to use") + parser.add_argument("--port", type=int, default=0, help="Port to use (0 for auto)") + parser.add_argument( + "--log-level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging level", + ) + + args = parser.parse_args() + + # Run everything + asyncio.run(main(host=args.host, port=args.port, log_level=args.log_level)) diff --git a/examples/fastmcp/socket_examples/server.py b/examples/fastmcp/socket_examples/server.py new file mode 100644 index 000000000..a4d802ae6 --- /dev/null +++ b/examples/fastmcp/socket_examples/server.py @@ -0,0 +1,106 @@ +""" +Socket transport server script. + +This script demonstrates: +1. Creating a FastMCP server with socket transport +2. Connecting back to the client's socket using the provided port +3. Running the server until the connection is closed +4. Handling connection errors and encoding +5. Supporting command-line configuration + +Usage: + python server.py --name NAME [--host HOST] [--port PORT] + +Note: + This server is typically started by the client (client.py). + Direct execution is mainly for testing purposes. +""" + +import argparse +import logging +import os +import sys +from typing import Dict + +from mcp.server.fastmcp import FastMCP +from mcp.shared.exceptions import McpError + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def echo_tool(text: str) -> str: + """ + A simple echo tool that returns the input text. + + Args: + text: The text to echo back + + Returns: + The same text that was provided + """ + logger.info(f"Echo tool called with: {text}") + return text + + +async def get_pid_tool() -> Dict[str, int]: + """ + A tool that returns the server's process ID. + + Returns: + Dict[str, int]: A dictionary containing the server's process ID + """ + return {"pid": os.getpid()} + + +def main(): + """Parse arguments and run the server with socket transport.""" + parser = argparse.ArgumentParser(description="Socket transport server example") + parser.add_argument("--host", default="127.0.0.1", help="Host to connect to") + parser.add_argument("--port", type=int, required=True, help="Port to connect to") + parser.add_argument("--name", default="Socket Server", help="Server name") + parser.add_argument( + "--log-level", + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging level", + ) + + args = parser.parse_args() + + # Set logging level from arguments + logging.getLogger().setLevel(args.log_level) + + try: + # Create FastMCP server with socket settings + server = FastMCP( + name=args.name, + socket_host=args.host, + socket_port=args.port, + ) + + # Add our tools + server.add_tool(echo_tool) + server.add_tool(get_pid_tool) + + logger.info(f"Starting server {args.name} with socket transport") + logger.info(f"Will connect to client at {args.host}:{args.port}") + + try: + # Use the socket transport + server.run(transport="socket") + except McpError as e: + logger.error(f"MCP error: {e}") + sys.exit(1) + except Exception as e: + logger.error(f"Server error: {e}") + sys.exit(1) + + except Exception as e: + logger.error(f"Failed to create server: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/mcp/client/socket_transport.py b/src/mcp/client/socket_transport.py new file mode 100644 index 000000000..6551eca26 --- /dev/null +++ b/src/mcp/client/socket_transport.py @@ -0,0 +1,244 @@ +""" +Socket Transport Module + +This module implements a socket-based transport for MCP that provides +1-to-1 client-server communication over TCP sockets. +""" + +import logging +import os +import sys +from contextlib import asynccontextmanager +from pathlib import Path +from typing import AsyncGenerator, TextIO + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import BaseModel, Field + +import mcp.types as types +from mcp.shared.message import SessionMessage + +logger = logging.getLogger(__name__) + + +class SocketServerParameters(BaseModel): + """Configuration parameters for socket-based transport.""" + + command: str + """The executable to run to start the server.""" + + args: list[str] = Field(default_factory=list) + """Command line arguments to pass to the executable.""" + + env: dict[str, str] | None = None + """ + The environment to use when spawning the process. + If not specified, the current environment will be used. + """ + + cwd: str | Path | None = None + """The working directory to use when spawning the process.""" + + host: str = Field(default="127.0.0.1") + """The host to bind to for socket communication.""" + + port: int = Field(default=0) + """ + The port to bind to for socket communication. + If 0, a random available port will be used. + """ + + encoding: str = "utf-8" + """The text encoding used when sending/receiving messages.""" + + encoding_error_handler: str = "strict" + """The text encoding error handler.""" + + connection_timeout: float = Field(default=5.0) + """Timeout in seconds for connection acceptance.""" + + +@asynccontextmanager +async def socket_client( + server: SocketServerParameters, errlog: TextIO = sys.stderr +) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + ], + None, +]: + """ + Client transport for socket-based communication. + + This will: + 1. Start a server process using the provided command + 2. Create a socket connection to that server + 3. Communicate using JSON-RPC messages over the socket connection + + Args: + server: Socket server parameters + errlog: Where to send server process stderr (defaults to sys.stderr) + + Yields: + A tuple containing: + - read_stream: Stream for reading messages from the server + - write_stream: Stream for sending messages to the server + + Raises: + TimeoutError: If connection acceptance times out + OSError: If process startup fails + Exception: For other errors + """ + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + # Create a TCP listener first to get the port + listener = await anyio.create_tcp_listener( + local_host=server.host, local_port=server.port + ) + actual_port = listener.extra(anyio.abc.SocketAttribute.local_port) + logger.info(f"Listening on port {actual_port}") + + try: + # Start the server process with the port as an argument + process_args = [*server.args, "--port", str(actual_port)] + process = await anyio.open_process( + [server.command, *process_args], + env=server.env or os.environ, + stderr=errlog, + cwd=server.cwd, + ) + + try: + # Accept connection from the server with timeout + stream = None + connection_event = anyio.Event() + + async def handle_connection(client_stream): + nonlocal stream + stream = client_stream + logger.info(f"Accepted connection from server") + connection_event.set() + + async def run_listener(): + try: + async with listener: + await listener.serve(handle_connection) + except anyio.get_cancelled_exc_class(): + # Normal cancellation, just exit + pass + except Exception as e: + logger.error(f"Error in listener: {e}") + raise + + async def socket_reader(): + """Reads messages from the socket and forwards them to read_stream.""" + try: + async with read_stream_writer: + buffer = "" + async for data in stream: + text = data.decode( + server.encoding, server.encoding_error_handler + ) + lines = (buffer + text).split("\n") + buffer = lines.pop() + + for line in lines: + try: + message = types.JSONRPCMessage.model_validate_json( + line + ) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + except Exception as exc: + await read_stream_writer.send(exc) + continue + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + except anyio.get_cancelled_exc_class(): + # Handle cancellation gracefully + logger.info("Socket reader cancelled") + return + except Exception as e: + logger.error(f"Error in socket reader: {e}") + raise + + async def socket_writer(): + """Reads messages from write_stream and sends them over the socket.""" + try: + async with write_stream_reader: + async for session_message in write_stream_reader: + json = session_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) + data = (json + "\n").encode( + server.encoding, server.encoding_error_handler + ) + await stream.send(data) + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + except anyio.get_cancelled_exc_class(): + # Handle cancellation gracefully + logger.info("Socket writer cancelled") + return + except Exception as e: + logger.error(f"Error in socket writer: {e}") + raise + + async with anyio.create_task_group() as tg: + # Start the listener task + tg.start_soon(run_listener) + + # Wait for connection with timeout + with anyio.fail_after(server.connection_timeout): + await connection_event.wait() + + # Start reader and writer tasks + tg.start_soon(socket_reader) + tg.start_soon(socket_writer) + + try: + async with process, stream: + yield read_stream, write_stream + finally: + # Cancel all tasks and clean up with timeout + tg.cancel_scope.cancel() + + # Force cleanup with timeout to prevent hanging + try: + with anyio.fail_after(5.0): # 5 second timeout for cleanup + # Clean up process to prevent any dangling orphaned processes + try: + process.terminate() + except ProcessLookupError: + # Process already exited, which is fine + pass + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + except anyio.get_cancelled_exc_class(): + # If cleanup times out, force kill the process + logger.warning("Cleanup timed out, force killing process") + try: + process.kill() + except ProcessLookupError: + pass + + finally: + # Clean up process + if process.returncode is None: + process.terminate() + await process.aclose() + + finally: + # Clean up listener + await listener.aclose() diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 956a8aa78..77bf6c0af 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -42,6 +42,7 @@ from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan from mcp.server.session import ServerSession, ServerSessionT +from mcp.server.socket_transport import socket_server from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore @@ -90,6 +91,12 @@ class Settings(BaseSettings, Generic[LifespanResultT]): message_path: str = "/messages/" streamable_http_path: str = "/mcp" + # Socket settings + socket_host: str = "127.0.0.1" + socket_port: int = 0 # Use 0 for placeholder, must be set when running on socket transport + socket_encoding: str = "utf-8" + socket_encoding_error_handler: str = "strict" + # StreamableHTTP settings json_response: bool = False stateless_http: bool = False # If True, uses true stateless mode (new transport per request) @@ -208,16 +215,16 @@ def session_manager(self) -> StreamableHTTPSessionManager: def run( self, - transport: Literal["stdio", "sse", "streamable-http"] = "stdio", + transport: Literal["stdio", "sse", "streamable-http", "socket"] = "stdio", mount_path: str | None = None, ) -> None: """Run the FastMCP server. Note this is a synchronous function. Args: - transport: Transport protocol to use ("stdio", "sse", or "streamable-http") + transport: Transport protocol to use ("stdio", "sse", "streamable-http", or "socket") mount_path: Optional mount path for SSE transport """ - TRANSPORTS = Literal["stdio", "sse", "streamable-http"] + TRANSPORTS = Literal["stdio", "sse", "streamable-http", "socket"] if transport not in TRANSPORTS.__args__: # type: ignore raise ValueError(f"Unknown transport: {transport}") @@ -228,6 +235,8 @@ def run( anyio.run(lambda: self.run_sse_async(mount_path)) case "streamable-http": anyio.run(self.run_streamable_http_async) + case "socket": + anyio.run(self.run_socket_async) def _setup_handlers(self) -> None: """Set up core MCP protocol handlers.""" @@ -636,6 +645,20 @@ async def run_stdio_async(self) -> None: self._mcp_server.create_initialization_options(), ) + async def run_socket_async(self) -> None: + """Run the server using socket transport.""" + async with socket_server( + host=self.settings.socket_host, + port=self.settings.socket_port, + encoding=self.settings.socket_encoding, + encoding_error_handler=self.settings.socket_encoding_error_handler, + ) as (read_stream, write_stream): + await self._mcp_server.run( + read_stream, + write_stream, + self._mcp_server.create_initialization_options(), + ) + async def run_sse_async(self, mount_path: str | None = None) -> None: """Run the server using SSE transport.""" import uvicorn diff --git a/src/mcp/server/socket_transport.py b/src/mcp/server/socket_transport.py new file mode 100644 index 000000000..6bb9ed3bb --- /dev/null +++ b/src/mcp/server/socket_transport.py @@ -0,0 +1,165 @@ +""" +Socket Transport Server Module + +This module implements a socket-based transport for MCP that provides +1-to-1 client-server communication over TCP sockets. +""" + +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +import mcp.types as types +from mcp.shared.message import SessionMessage + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def socket_server( + port: int, + host: str = "127.0.0.1", + encoding: str = "utf-8", + encoding_error_handler: str = "strict", +) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + ], + None, +]: + """ + Server transport for socket-based communication. + + This will connect to a client's TCP socket and communicate using + JSON-RPC messages over the socket connection. + + Args: + port: The port to connect to (required, must not be 0) + host: The host to connect to (defaults to "127.0.0.1") + encoding: Text encoding to use (defaults to "utf-8") + encoding_error_handler: Text encoding error handler (defaults to "strict") + + Yields: + A tuple containing: + - read_stream: Stream for reading messages from the client + - write_stream: Stream for sending messages to the client + + Raises: + ValueError: If port is 0 + """ + if port == 0: + raise ValueError( + "Port cannot be 0 when connecting to client. A specific port must be provided." + ) + + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + try: + # Connect to the client's TCP server with retry logic + stream = None + for attempt in range(5): # Try 5 times + try: + stream = await anyio.connect_tcp(host, port) + logger.info(f"Connected to client at {host}:{port}") + break + except OSError as e: + if attempt == 4: # Last attempt + logger.error(f"Failed to connect to client at {host}:{port}") + raise e + logger.info(f"Connection attempt {attempt + 1} failed, retrying...") + await anyio.sleep(1) # Wait a bit before retrying + + if not stream: + raise RuntimeError("Failed to connect to client") + + async def socket_reader(): + """Reads messages from the socket and forwards them to read_stream.""" + try: + async with read_stream_writer: + buffer = "" + async for data in stream: + text = data.decode(encoding, encoding_error_handler) + lines = (buffer + text).split("\n") + buffer = lines.pop() + + for line in lines: + try: + message = types.JSONRPCMessage.model_validate_json(line) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + except Exception as exc: + await read_stream_writer.send(exc) + continue + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + except anyio.get_cancelled_exc_class(): + # Handle cancellation gracefully + logger.info("Socket reader cancelled") + return + except Exception as e: + logger.error(f"Error in socket reader: {e}") + raise + finally: + await stream.aclose() + + async def socket_writer(): + """Reads messages from write_stream and sends them over the socket.""" + try: + async with write_stream_reader: + async for session_message in write_stream_reader: + json = session_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) + data = (json + "\n").encode(encoding, encoding_error_handler) + await stream.send(data) + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + except anyio.get_cancelled_exc_class(): + # Handle cancellation gracefully + logger.info("Socket writer cancelled") + return + except Exception as e: + logger.error(f"Error in socket writer: {e}") + raise + finally: + await stream.aclose() + + async with anyio.create_task_group() as tg: + tg.start_soon(socket_reader) + tg.start_soon(socket_writer) + + try: + yield read_stream, write_stream + finally: + # Cancel all tasks and clean up with timeout + tg.cancel_scope.cancel() + + # Force cleanup with timeout to prevent hanging + try: + with anyio.fail_after(5.0): # 5 second timeout for cleanup + await stream.aclose() + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + except anyio.get_cancelled_exc_class(): + # If cleanup times out, log warning + logger.warning("Server cleanup timed out") + + except Exception: + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + raise diff --git a/tests/client/test_socket_transport.py b/tests/client/test_socket_transport.py new file mode 100644 index 000000000..a51fdc785 --- /dev/null +++ b/tests/client/test_socket_transport.py @@ -0,0 +1,469 @@ +""" +Tests for socket transport functionality. + +This module tests both client and server sides of the socket transport, +including error handling, encoding, and FastMCP integration. +""" + +import shutil +import socket + +import anyio +import pytest + +from mcp.client.session import ClientSession +from mcp.client.socket_transport import SocketServerParameters, socket_client +from mcp.shared.exceptions import McpError +from mcp.shared.message import SessionMessage +from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse + +python = shutil.which("python") or "python" + + +@pytest.mark.anyio +async def test_socket_context_manager_exiting(): + """Test that the socket client context manager exits cleanly.""" + server_params = SocketServerParameters( + command=python, + args=[ + "-c", + "import socket; s=socket.socket(); s.connect(('127.0.0.1', int(__import__('sys').argv[2]))); s.close()", + ], + ) + async with socket_client(server_params) as (_, _): + pass + + +@pytest.mark.anyio +async def test_socket_client(): + """Test basic message sending and receiving over socket transport.""" + server_params = SocketServerParameters( + command=python, + args=[ + "-c", + """ +import socket, sys +s = socket.socket() +s.connect(('127.0.0.1', int(sys.argv[2]))) +while True: + data = s.recv(1024) + if not data: + break + s.send(data) +s.close() + """, + ], + ) + + async with socket_client(server_params) as (read_stream, write_stream): + # Test sending and receiving messages + messages = [ + JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")), + JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})), + ] + + async with write_stream: + for message in messages: + session_message = SessionMessage(message) + await write_stream.send(session_message) + + read_messages = [] + async with read_stream: + async for message in read_stream: + if isinstance(message, Exception): + raise message + + read_messages.append(message.message) + if len(read_messages) == 2: + break + + assert len(read_messages) == 2 + assert read_messages[0] == JSONRPCMessage( + root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + ) + assert read_messages[1] == JSONRPCMessage( + root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) + ) + + +@pytest.mark.anyio +async def test_socket_client_bad_path(): + """Check that the connection doesn't hang if process errors.""" + server_params = SocketServerParameters( + command="python", args=["-c", "non-existent-file.py"] + ) + with pytest.raises(Exception) as exc_info: + async with socket_client(server_params) as (read_stream, write_stream): + pass + + # The error should be a TimeoutError wrapped in an ExceptionGroup + assert isinstance(exc_info.value, ExceptionGroup) + assert isinstance(exc_info.value.exceptions[0], TimeoutError) + + +@pytest.mark.anyio +async def test_socket_client_nonexistent_command(): + """Test that socket_client raises an error for non-existent commands.""" + # Create a server with a non-existent command + server_params = SocketServerParameters( + command="/path/to/nonexistent/command", + args=["--help"], + ) + + # Should raise an error when trying to start the process + with pytest.raises(OSError) as exc_info: + async with socket_client(server_params) as (_, _): + pass + + # The error should indicate the command was not found + error_message = str(exc_info.value) + assert ( + "nonexistent" in error_message + or "not found" in error_message.lower() + or "cannot find the file" in error_message.lower() # Windows error message + or "no such file" in error_message.lower() # Unix/macOS error message + ) + + +@pytest.mark.anyio +async def test_socket_client_connection_timeout(): + """Test that socket_client handles connection timeout gracefully.""" + # Create a server that doesn't connect back + server_params = SocketServerParameters( + command=python, + args=["-c", "import time; time.sleep(10)"], + connection_timeout=1.0, # Set a short timeout for testing + ) + + # Should raise an error when connection times out + with pytest.raises(Exception) as exc_info: + async with socket_client(server_params) as (_, _): + pass + + # The error should be a TimeoutError wrapped in an ExceptionGroup + assert isinstance(exc_info.value, ExceptionGroup) + assert isinstance(exc_info.value.exceptions[0], TimeoutError) + + +@pytest.mark.anyio +async def test_socket_client_connection_refused(): + """Test that socket_client handles connection refused gracefully.""" + # Create a server that exits immediately + server_params = SocketServerParameters( + command=python, + args=["-c", "exit(0)"], + connection_timeout=1.0, # Set a short timeout for testing + ) + + # Should raise an error when connection is refused + with pytest.raises(Exception) as exc_info: + async with socket_client(server_params) as (_, _): + pass + + # The error should be a TimeoutError wrapped in an ExceptionGroup + assert isinstance(exc_info.value, ExceptionGroup) + assert isinstance(exc_info.value.exceptions[0], TimeoutError) + + +@pytest.mark.anyio +async def test_socket_client_port_zero(): + """Test that port 0 works correctly for client (auto-assigns port).""" + server_params = SocketServerParameters( + command=python, + args=[ + "-c", + """ +import socket, sys +s = socket.socket() +s.connect(('127.0.0.1', int(sys.argv[2]))) +s.send(b'test') +s.close() + """, + ], + port=0, # Should auto-assign + ) + + async with socket_client(server_params) as (read_stream, write_stream): + # The connection should succeed with an auto-assigned port + assert True + + +@pytest.mark.anyio +async def test_socket_client_encoding(): + """Test message encoding/decoding with different character sets.""" + server_params = SocketServerParameters( + command=python, + args=[ + "-c", + """ +import socket, sys +s = socket.socket() +s.connect(('127.0.0.1', int(sys.argv[2]))) +while True: + data = s.recv(1024) + if not data: + break + s.send(data) +s.close() + """, + ], + encoding="utf-8", + encoding_error_handler="strict", + ) + + async with socket_client(server_params) as (read_stream, write_stream): + # Test messages with special characters + messages = [ + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", id=1, method="echo", params={"text": "Hello, 世界!"} + ) + ), + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", id=2, method="echo", params={"text": "¡Hola, мир!"} + ) + ), + ] + + async with write_stream: + for message in messages: + session_message = SessionMessage(message) + await write_stream.send(session_message) + + read_messages = [] + async with read_stream: + async for message in read_stream: + if isinstance(message, Exception): + raise message + + read_messages.append(message.message) + if len(read_messages) == 2: + break + + assert len(read_messages) == 2 + assert read_messages[0].root.params["text"] == "Hello, 世界!" + assert read_messages[1].root.params["text"] == "¡Hola, мир!" + + +@pytest.mark.anyio +async def test_socket_client_invalid_json(): + """Test handling of invalid JSON messages.""" + server_params = SocketServerParameters( + command=python, + args=[ + "-c", + """ +import socket, sys +s = socket.socket() +s.connect(('127.0.0.1', int(sys.argv[2]))) +s.send(b'invalid json\\n') +s.close() + """, + ], + ) + + async with socket_client(server_params) as (read_stream, write_stream): + async for message in read_stream: + assert isinstance(message, Exception) + break + + +@pytest.mark.anyio +async def test_socket_client_cancellation_handling(): + """Test that socket_client handles cancellation gracefully.""" + server_params = SocketServerParameters( + command=python, + args=[ + "-c", + """ +import socket, sys, time +s = socket.socket() +s.connect(('127.0.0.1', int(sys.argv[2]))) +# Keep connection alive for a bit +time.sleep(2) +s.close() + """, + ], + ) + + # Test that cancellation works properly + with anyio.move_on_after(0.5) as cancel_scope: + async with socket_client(server_params) as (read_stream, write_stream): + # Wait a bit, then the move_on_after should cancel + await anyio.sleep(1) + + # The cancellation should have occurred + assert cancel_scope.cancelled_caught + + +@pytest.mark.anyio +async def test_socket_client_cleanup_timeout(): + """Test that socket_client cleanup has timeout protection.""" + server_params = SocketServerParameters( + command=python, + args=[ + "-c", + """ +import socket, sys, time +s = socket.socket() +s.connect(('127.0.0.1', int(sys.argv[2]))) + +# Send a message and keep connection alive briefly +s.send(b'{"jsonrpc": "2.0", "id": 1, "method": "test"}\\n') +time.sleep(1) # Brief delay, then exit normally +s.close() + """, + ], + ) + + # Test that cleanup completes within reasonable time + start_time = anyio.current_time() + + async with socket_client(server_params) as (read_stream, write_stream): + # Do some work + await anyio.sleep(0.1) + + end_time = anyio.current_time() + + # Normal cleanup should complete quickly (within 3 seconds) + # This tests that the cleanup mechanism works without hanging + assert end_time - start_time < 3.0 + + +@pytest.mark.anyio +async def test_socket_client_cleanup_mechanism(): + """Test that socket_client cleanup mechanism is robust.""" + server_params = SocketServerParameters( + command=python, + args=[ + "-c", + """ +import socket, sys, time +s = socket.socket() +s.connect(('127.0.0.1', int(sys.argv[2]))) + +# Send a test message +s.send(b'{"jsonrpc": "2.0", "id": 1, "method": "test"}\\n') + +# Close after brief delay +time.sleep(0.2) +s.close() + """, + ], + ) + + # Test that cleanup works correctly + async with socket_client(server_params) as (read_stream, write_stream): + # Process at least one message + async for message in read_stream: + if isinstance(message, Exception): + continue + # Exit after first valid message + break + + # If we reach here, cleanup worked properly + assert True + + +@pytest.mark.anyio +async def test_socket_client_reader_writer_exception_handling(): + """Test that socket reader/writer handle exceptions properly.""" + server_params = SocketServerParameters( + command=python, + args=[ + "-c", + """ +import socket, sys, time +s = socket.socket() +s.connect(('127.0.0.1', int(sys.argv[2]))) + +# Send some data then close abruptly +s.send(b'{"jsonrpc": "2.0", "id": 1, "method": "test"}\\n') +time.sleep(0.1) +s.close() # Close connection abruptly + """, + ], + ) + + async with socket_client(server_params) as (read_stream, write_stream): + # Should handle the abrupt connection close gracefully + messages_received = 0 + async for message in read_stream: + if isinstance(message, Exception): + # Exceptions in the stream are expected + continue + messages_received += 1 + if messages_received >= 1: + break + + assert messages_received >= 1 + + +@pytest.mark.anyio +async def test_socket_client_process_cleanup(): + """Test that socket_client cleans up processes properly.""" + server_params = SocketServerParameters( + command=python, + args=[ + "-c", + """ +import socket, sys, time, os +pid = os.getpid() +print(f"Process PID: {pid}", file=sys.stderr) + +s = socket.socket() +s.connect(('127.0.0.1', int(sys.argv[2]))) +time.sleep(0.5) +s.close() + """, + ], + ) + + async with socket_client(server_params) as (read_stream, write_stream): + # Brief interaction + await anyio.sleep(0.1) + + # Process should be cleaned up after context exit + # This is mainly to ensure no zombie processes remain + await anyio.sleep(0.1) # Give cleanup time to complete + + +@pytest.mark.anyio +async def test_socket_client_multiple_messages_with_cancellation(): + """Test handling multiple messages with cancellation.""" + server_params = SocketServerParameters( + command=python, + args=[ + "-c", + """ +import socket, sys, time, json +s = socket.socket() +s.connect(('127.0.0.1', int(sys.argv[2]))) + +# Send multiple messages +for i in range(10): + msg = {"jsonrpc": "2.0", "id": i, "method": "test", "params": {"counter": i}} + s.send((json.dumps(msg) + '\\n').encode()) + time.sleep(0.01) # Small delay between messages + +s.close() + """, + ], + ) + + messages_received = 0 + + with anyio.move_on_after(1.0) as cancel_scope: + async with socket_client(server_params) as (read_stream, write_stream): + async for message in read_stream: + if isinstance(message, Exception): + continue + messages_received += 1 + if messages_received >= 5: + # Cancel after receiving some messages + cancel_scope.cancel() + + # Should have received some messages before cancellation + assert messages_received >= 5 + assert cancel_scope.cancelled_caught