Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion mcp_python/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,15 @@ def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
read_timeout_seconds: int | float | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIOLI: Should this be a timedelta to avoid unit errors entirely?

) -> None:
super().__init__(read_stream, write_stream, ServerRequest, ServerNotification)
super().__init__(
read_stream,
write_stream,
ServerRequest,
ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)

async def initialize(self) -> InitializeResult:
from mcp_python.types import (
Expand Down
34 changes: 28 additions & 6 deletions mcp_python/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ClientNotification,
ClientRequest,
CompleteRequest,
EmptyResult,
ErrorData,
JSONRPCMessage,
ListPromptsRequest,
Expand All @@ -27,6 +28,7 @@
ListToolsRequest,
ListToolsResult,
LoggingLevel,
PingRequest,
ProgressNotification,
Prompt,
PromptReference,
Expand All @@ -52,9 +54,11 @@
class Server:
def __init__(self, name: str):
self.name = name
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {}
self.request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]] = {
PingRequest: _ping_handler,
}
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
logger.info(f"Initializing server '{name}'")
logger.debug(f"Initializing server '{name}'")

def create_initialization_options(self) -> types.InitializationOptions:
"""Create initialization options from this server instance."""
Expand All @@ -63,9 +67,13 @@ def pkg_version(package: str) -> str:
try:
from importlib.metadata import version

return version(package)
v = version(package)
if v is not None:
return v
except Exception:
return "unknown"
pass

return "unknown"

return types.InitializationOptions(
server_name=self.name,
Expand Down Expand Up @@ -330,6 +338,11 @@ async def run(
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions,
# When True, exceptions are returned as messages to the client.
# When False, exceptions are raised, which will cause the server to shut down
# but also make tracing exceptions much easier during testing and when using
# in-process servers.
raise_exceptions: bool = False,
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(
Expand All @@ -349,6 +362,7 @@ async def run(
f"Dispatching request of type {type(req).__name__}"
)

token = None
try:
# Set our global state that can be retrieved via
# app.get_request_context()
Expand All @@ -360,12 +374,16 @@ async def run(
)
)
response = await handler(req)
# Reset the global state after we are done
request_ctx.reset(token)
except Exception as err:
if raise_exceptions:
raise err
response = ErrorData(
code=0, message=str(err), data=None
)
finally:
# Reset the global state after we are done
if token is not None:
request_ctx.reset(token)

await message.respond(response)
else:
Expand Down Expand Up @@ -399,3 +417,7 @@ async def run(
logger.info(
f"Warning: {warning.category.__name__}: {warning.message}"
)


async def _ping_handler(request: PingRequest) -> ServerResult:
return ServerResult(EmptyResult())
83 changes: 83 additions & 0 deletions mcp_python/server/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
In-memory transports
"""

from contextlib import asynccontextmanager
from typing import AsyncGenerator

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp_python.client.session import ClientSession
from mcp_python.server import Server
from mcp_python.types import JSONRPCMessage

MessageStream = tuple[
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
MemoryObjectSendStream[JSONRPCMessage]
]

@asynccontextmanager
async def create_client_server_memory_streams() -> AsyncGenerator[
tuple[MessageStream, MessageStream],
None
]:
"""
Creates a pair of bidirectional memory streams for client-server communication.
Returns:
A tuple of (client_streams, server_streams) where each is a tuple of
(read_stream, write_stream)
"""
# Create streams for both directions
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
](1)
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
JSONRPCMessage | Exception
](1)

client_streams = (server_to_client_receive, client_to_server_send)
server_streams = (client_to_server_receive, server_to_client_send)

async with (
server_to_client_receive,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
):
yield client_streams, server_streams


@asynccontextmanager
async def create_connected_server_and_client_session(
server: Server,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ServerSession that is connected to the `server`."""
async with create_client_server_memory_streams() as (
client_streams,
server_streams,
):
client_read, client_write = client_streams
server_read, server_write = server_streams

# Create a cancel scope for the server task
async with anyio.create_task_group() as tg:
tg.start_soon(
server.run,
server_read,
server_write,
server.create_initialization_options(),
)

try:
# Client session could be created here using client_read and
# client_write This would allow testing the server with a client
# in the same process
async with ClientSession(
read_stream=client_read, write_stream=client_write
) as client_session:
await client_session.initialize()
yield client_session
finally:
tg.cancel_scope.cancel()
20 changes: 19 additions & 1 deletion mcp_python/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,16 @@ def __init__(
write_stream: MemoryObjectSendStream[JSONRPCMessage],
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: int | float | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
self._response_streams = {}
self._request_id = 0
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds

self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
Expand Down Expand Up @@ -147,7 +150,22 @@ async def send_request(

await self._write_stream.send(JSONRPCMessage(jsonrpc_request))

response_or_error = await response_stream_reader.receive()
try:
with anyio.fail_after(self._read_timeout_seconds):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
raise McpError(
ErrorData(
code=408,
message=(
f"Timed out while waiting for response to "
f"{request.__class__.__name__}. Waited "
f"{self._read_timeout_seconds} seconds."
),
)

)

if isinstance(response_or_error, JSONRPCError):
raise McpError(response_or_error.error)
else:
Expand Down
3 changes: 3 additions & 0 deletions mcp_python/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,19 @@ class ErrorData(BaseModel):

code: int
"""The error type that occurred."""

message: str
"""
A short description of the error. The message SHOULD be limited to a concise single
sentence.
"""

data: Any | None = None
"""
Additional information about the error. The value of this member is defined by the
sender (e.g. detailed error information, nested errors etc.).
"""

model_config = ConfigDict(extra="allow")


Expand Down
28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
from pydantic import AnyUrl

from mcp_python.server import Server
from mcp_python.server.types import InitializationOptions
from mcp_python.types import Resource, ServerCapabilities

TEST_INITIALIZATION_OPTIONS = InitializationOptions(
server_name="my_mcp_server",
server_version="0.1.0",
capabilities=ServerCapabilities(),
)

@pytest.fixture
def mcp_server() -> Server:
server = Server(name="test_server")

@server.list_resources()
async def handle_list_resources():
return [
Resource(
uri=AnyUrl("memory://test"),
name="Test Resource",
description="A test resource"
)
]

return server
28 changes: 28 additions & 0 deletions tests/test_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
from typing_extensions import AsyncGenerator

from mcp_python.client.session import ClientSession
from mcp_python.server import Server
from mcp_python.server.memory import (
create_connected_server_and_client_session,
)
from mcp_python.types import (
EmptyResult,
)


@pytest.fixture
async def client_connected_to_server(
mcp_server: Server,
) -> AsyncGenerator[ClientSession, None]:
async with create_connected_server_and_client_session(mcp_server) as client_session:
yield client_session


@pytest.mark.anyio
async def test_memory_server_and_client_connection(
client_connected_to_server: ClientSession,
):
"""Shows how a client and server can communicate over memory streams."""
response = await client_connected_to_server.send_ping()
assert isinstance(response, EmptyResult)
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.