diff --git a/mcp_python/client/session.py b/mcp_python/client/session.py index 266e741c9..82ec0b2bd 100644 --- a/mcp_python/client/session.py +++ b/mcp_python/client/session.py @@ -1,3 +1,5 @@ +from datetime import timedelta + from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl @@ -36,8 +38,15 @@ def __init__( self, read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_timeout_seconds: timedelta | None = None, ) -> None: - super().__init__(read_stream, write_stream, ServerRequest, ServerNotification) + super().__init__( + read_stream, + write_stream, + ServerRequest, + ServerNotification, + read_timeout_seconds=read_timeout_seconds, + ) async def initialize(self) -> InitializeResult: from mcp_python.types import ( diff --git a/mcp_python/server/__init__.py b/mcp_python/server/__init__.py index 71b373d63..ff8900fc7 100644 --- a/mcp_python/server/__init__.py +++ b/mcp_python/server/__init__.py @@ -18,6 +18,7 @@ ClientNotification, ClientRequest, CompleteRequest, + EmptyResult, ErrorData, JSONRPCMessage, ListPromptsRequest, @@ -27,6 +28,7 @@ ListToolsRequest, ListToolsResult, LoggingLevel, + PingRequest, ProgressNotification, Prompt, PromptReference, @@ -52,9 +54,11 @@ class Server: def __init__(self, name: str): self.name = name - self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {} + self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = { + PingRequest: _ping_handler, + } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} - logger.info(f"Initializing server '{name}'") + logger.debug(f"Initializing server '{name}'") def create_initialization_options(self) -> types.InitializationOptions: """Create initialization options from this server instance.""" @@ -63,9 +67,13 @@ def pkg_version(package: str) -> str: try: from importlib.metadata import version - return version(package) + v = version(package) + if v is not None: + return v except Exception: - return "unknown" + pass + + return "unknown" return types.InitializationOptions( server_name=self.name, @@ -330,6 +338,11 @@ async def run( read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[JSONRPCMessage], initialization_options: types.InitializationOptions, + # When True, exceptions are returned as messages to the client. + # When False, exceptions are raised, which will cause the server to shut down + # but also make tracing exceptions much easier during testing and when using + # in-process servers. + raise_exceptions: bool = False, ): with warnings.catch_warnings(record=True) as w: async with ServerSession( @@ -349,6 +362,7 @@ async def run( f"Dispatching request of type {type(req).__name__}" ) + token = None try: # Set our global state that can be retrieved via # app.get_request_context() @@ -360,12 +374,16 @@ async def run( ) ) response = await handler(req) - # Reset the global state after we are done - request_ctx.reset(token) except Exception as err: + if raise_exceptions: + raise err response = ErrorData( code=0, message=str(err), data=None ) + finally: + # Reset the global state after we are done + if token is not None: + request_ctx.reset(token) await message.respond(response) else: @@ -399,3 +417,7 @@ async def run( logger.info( f"Warning: {warning.category.__name__}: {warning.message}" ) + + +async def _ping_handler(request: PingRequest) -> ServerResult: + return ServerResult(EmptyResult()) diff --git a/mcp_python/shared/memory.py b/mcp_python/shared/memory.py new file mode 100644 index 000000000..a2917499a --- /dev/null +++ b/mcp_python/shared/memory.py @@ -0,0 +1,87 @@ +""" +In-memory transports +""" + +from contextlib import asynccontextmanager +from datetime import timedelta +from typing import AsyncGenerator + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +from mcp_python.client.session import ClientSession +from mcp_python.server import Server +from mcp_python.types import JSONRPCMessage + +MessageStream = tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage] +] + +@asynccontextmanager +async def create_client_server_memory_streams() -> AsyncGenerator[ + tuple[MessageStream, MessageStream], + None +]: + """ + Creates a pair of bidirectional memory streams for client-server communication. + + Returns: + A tuple of (client_streams, server_streams) where each is a tuple of + (read_stream, write_stream) + """ + # Create streams for both directions + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + JSONRPCMessage | Exception + ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + JSONRPCMessage | Exception + ](1) + + client_streams = (server_to_client_receive, client_to_server_send) + server_streams = (client_to_server_receive, server_to_client_send) + + async with ( + server_to_client_receive, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + ): + yield client_streams, server_streams + + +@asynccontextmanager +async def create_connected_server_and_client_session( + server: Server, + read_timeout_seconds: timedelta | None = None, + raise_exceptions: bool = False, +) -> AsyncGenerator[ClientSession, None]: + """Creates a ClientSession that is connected to a running MCP server.""" + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + client_read, client_write = client_streams + server_read, server_write = server_streams + + # Create a cancel scope for the server task + async with anyio.create_task_group() as tg: + tg.start_soon( + lambda: server.run( + server_read, + server_write, + server.create_initialization_options(), + raise_exceptions=raise_exceptions, + ) + ) + + try: + async with ClientSession( + read_stream=client_read, + write_stream=client_write, + read_timeout_seconds=read_timeout_seconds, + ) as client_session: + await client_session.initialize() + yield client_session + finally: + tg.cancel_scope.cancel() diff --git a/mcp_python/shared/session.py b/mcp_python/shared/session.py index 3bc66fcd0..f063a33bd 100644 --- a/mcp_python/shared/session.py +++ b/mcp_python/shared/session.py @@ -1,8 +1,10 @@ from contextlib import AbstractAsyncContextManager +from datetime import timedelta from typing import Generic, TypeVar import anyio import anyio.lowlevel +import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel @@ -87,6 +89,8 @@ def __init__( write_stream: MemoryObjectSendStream[JSONRPCMessage], receive_request_type: type[ReceiveRequestT], receive_notification_type: type[ReceiveNotificationT], + # If none, reading will never time out + read_timeout_seconds: timedelta | None = None, ) -> None: self._read_stream = read_stream self._write_stream = write_stream @@ -94,6 +98,7 @@ def __init__( self._request_id = 0 self._receive_request_type = receive_request_type self._receive_notification_type = receive_notification_type + self._read_timeout_seconds = read_timeout_seconds self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( anyio.create_memory_object_stream[ @@ -147,7 +152,25 @@ async def send_request( await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) - response_or_error = await response_stream_reader.receive() + try: + with anyio.fail_after( + None if self._read_timeout_seconds is None + else self._read_timeout_seconds.total_seconds() + ): + response_or_error = await response_stream_reader.receive() + except TimeoutError: + raise McpError( + ErrorData( + code=httpx.codes.REQUEST_TIMEOUT, + message=( + f"Timed out while waiting for response to " + f"{request.__class__.__name__}. Waited " + f"{self._read_timeout_seconds} seconds." + ), + ) + + ) + if isinstance(response_or_error, JSONRPCError): raise McpError(response_or_error.error) else: diff --git a/mcp_python/types.py b/mcp_python/types.py index 012122ed7..b3ab4dd98 100644 --- a/mcp_python/types.py +++ b/mcp_python/types.py @@ -141,16 +141,19 @@ class ErrorData(BaseModel): code: int """The error type that occurred.""" + message: str """ A short description of the error. The message SHOULD be limited to a concise single sentence. """ + data: Any | None = None """ Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.). """ + model_config = ConfigDict(extra="allow") diff --git a/pyproject.toml b/pyproject.toml index bebdd827b..96f24bbf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "mcp-python" -version = "0.4.0.dev" +version = "0.5.0dev" description = "Model Context Protocol implementation for Python" readme = "README.md" requires-python = ">=3.10" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..37ff5a4ec --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,28 @@ +import pytest +from pydantic import AnyUrl + +from mcp_python.server import Server +from mcp_python.server.types import InitializationOptions +from mcp_python.types import Resource, ServerCapabilities + +TEST_INITIALIZATION_OPTIONS = InitializationOptions( + server_name="my_mcp_server", + server_version="0.1.0", + capabilities=ServerCapabilities(), +) + +@pytest.fixture +def mcp_server() -> Server: + server = Server(name="test_server") + + @server.list_resources() + async def handle_list_resources(): + return [ + Resource( + uri=AnyUrl("memory://test"), + name="Test Resource", + description="A test resource" + ) + ] + + return server diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py new file mode 100644 index 000000000..60a90de3d --- /dev/null +++ b/tests/shared/test_memory.py @@ -0,0 +1,28 @@ +import pytest +from typing_extensions import AsyncGenerator + +from mcp_python.client.session import ClientSession +from mcp_python.server import Server +from mcp_python.shared.memory import ( + create_connected_server_and_client_session, +) +from mcp_python.types import ( + EmptyResult, +) + + +@pytest.fixture +async def client_connected_to_server( + mcp_server: Server, +) -> AsyncGenerator[ClientSession, None]: + async with create_connected_server_and_client_session(mcp_server) as client_session: + yield client_session + + +@pytest.mark.anyio +async def test_memory_server_and_client_connection( + client_connected_to_server: ClientSession, +): + """Shows how a client and server can communicate over memory streams.""" + response = await client_connected_to_server.send_ping() + assert isinstance(response, EmptyResult) diff --git a/uv.lock b/uv.lock index e085cced4..f0fac56f8 100644 --- a/uv.lock +++ b/uv.lock @@ -163,7 +163,7 @@ wheels = [ [[package]] name = "mcp-python" -version = "0.3.0.dev0" +version = "0.5.0.dev0" source = { editable = "." } dependencies = [ { name = "anyio" },