Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
from mcp.client.session import MessageHandlerFnT
from mcp.client.sse import sse_client
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
from mcp.shared.message import SessionMessage
Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
use_structured_content: bool = False,
max_retry_attempts: int = 0,
retry_backoff_seconds_base: float = 1.0,
message_handler: MessageHandlerFnT | None = None,
):
"""
Args:
Expand All @@ -124,6 +126,8 @@ def __init__(
Defaults to no retries.
retry_backoff_seconds_base: The base delay, in seconds, used for exponential
backoff between retries.
message_handler: Optional handler invoked for session messages as delivered by the
ClientSession.
"""
super().__init__(use_structured_content=use_structured_content)
self.session: ClientSession | None = None
Expand All @@ -135,6 +139,7 @@ def __init__(
self.client_session_timeout_seconds = client_session_timeout_seconds
self.max_retry_attempts = max_retry_attempts
self.retry_backoff_seconds_base = retry_backoff_seconds_base
self.message_handler = message_handler

# The cache is always dirty at startup, so that we fetch tools at least once
self._cache_dirty = True
Expand Down Expand Up @@ -272,6 +277,7 @@ async def connect(self):
timedelta(seconds=self.client_session_timeout_seconds)
if self.client_session_timeout_seconds
else None,
message_handler=self.message_handler,
)
)
server_result = await session.initialize()
Expand Down Expand Up @@ -394,6 +400,7 @@ def __init__(
use_structured_content: bool = False,
max_retry_attempts: int = 0,
retry_backoff_seconds_base: float = 1.0,
message_handler: MessageHandlerFnT | None = None,
):
"""Create a new MCP server based on the stdio transport.

Expand Down Expand Up @@ -421,6 +428,8 @@ def __init__(
Defaults to no retries.
retry_backoff_seconds_base: The base delay, in seconds, for exponential
backoff between retries.
message_handler: Optional handler invoked for session messages as delivered by the
ClientSession.
"""
super().__init__(
cache_tools_list,
Expand All @@ -429,6 +438,7 @@ def __init__(
use_structured_content,
max_retry_attempts,
retry_backoff_seconds_base,
message_handler=message_handler,
)

self.params = StdioServerParameters(
Expand Down Expand Up @@ -492,6 +502,7 @@ def __init__(
use_structured_content: bool = False,
max_retry_attempts: int = 0,
retry_backoff_seconds_base: float = 1.0,
message_handler: MessageHandlerFnT | None = None,
):
"""Create a new MCP server based on the HTTP with SSE transport.

Expand Down Expand Up @@ -521,6 +532,8 @@ def __init__(
Defaults to no retries.
retry_backoff_seconds_base: The base delay, in seconds, for exponential
backoff between retries.
message_handler: Optional handler invoked for session messages as delivered by the
ClientSession.
"""
super().__init__(
cache_tools_list,
Expand All @@ -529,6 +542,7 @@ def __init__(
use_structured_content,
max_retry_attempts,
retry_backoff_seconds_base,
message_handler=message_handler,
)

self.params = params
Expand Down Expand Up @@ -595,6 +609,7 @@ def __init__(
use_structured_content: bool = False,
max_retry_attempts: int = 0,
retry_backoff_seconds_base: float = 1.0,
message_handler: MessageHandlerFnT | None = None,
):
"""Create a new MCP server based on the Streamable HTTP transport.

Expand Down Expand Up @@ -625,6 +640,8 @@ def __init__(
Defaults to no retries.
retry_backoff_seconds_base: The base delay, in seconds, for exponential
backoff between retries.
message_handler: Optional handler invoked for session messages as delivered by the
ClientSession.
"""
super().__init__(
cache_tools_list,
Expand All @@ -633,6 +650,7 @@ def __init__(
use_structured_content,
max_retry_attempts,
retry_backoff_seconds_base,
message_handler=message_handler,
)

self.params = params
Expand Down
131 changes: 131 additions & 0 deletions tests/mcp/test_message_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import contextlib

import anyio
import pytest
from mcp.client.session import MessageHandlerFnT
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
from mcp.types import (
ClientResult,
Implementation,
InitializeResult,
ServerCapabilities,
ServerNotification,
ServerRequest,
)

from agents.mcp.server import (
MCPServerSse,
MCPServerStdio,
MCPServerStreamableHttp,
_MCPServerWithClientSession,
)

HandlerMessage = (
RequestResponder[ServerRequest, ClientResult]
| ServerNotification
| Exception
)


class _StubClientSession:
"""Stub ClientSession that records the configured message handler."""

def __init__(
self,
read_stream,
write_stream,
read_timeout_seconds,
*,
message_handler=None,
**_: object,
) -> None:
self.message_handler = message_handler

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
return False

async def initialize(self) -> InitializeResult:
capabilities = ServerCapabilities.model_construct()
server_info = Implementation.model_construct(name="stub", version="1.0")
return InitializeResult(
protocolVersion="2024-11-05",
capabilities=capabilities,
serverInfo=server_info,
)


class _MessageHandlerTestServer(_MCPServerWithClientSession):
def __init__(self, handler: MessageHandlerFnT | None):
super().__init__(
cache_tools_list=False,
client_session_timeout_seconds=None,
message_handler=handler,
)

def create_streams(self):
@contextlib.asynccontextmanager
async def _streams():
send_stream, recv_stream = (
anyio.create_memory_object_stream[SessionMessage | Exception](1))
try:
yield recv_stream, send_stream, None
finally:
await recv_stream.aclose()
await send_stream.aclose()

return _streams()

@property
def name(self) -> str:
return "test-server"


@pytest.mark.asyncio
async def test_client_session_receives_message_handler(monkeypatch):
captured: dict[str, object] = {}

def _recording_client_session(*args, **kwargs):
session = _StubClientSession(*args, **kwargs)
captured["message_handler"] = session.message_handler
return session

monkeypatch.setattr("agents.mcp.server.ClientSession", _recording_client_session)

class _AsyncHandler:
async def __call__(self, message: HandlerMessage) -> None:
del message

handler: MessageHandlerFnT = _AsyncHandler()

server = _MessageHandlerTestServer(handler)

try:
await server.connect()
finally:
await server.cleanup()

assert captured["message_handler"] is handler


@pytest.mark.parametrize(
"server_cls, params",
[
(MCPServerSse, {"url": "https://example.com"}),
(MCPServerStreamableHttp, {"url": "https://example.com"}),
(MCPServerStdio, {"command": "python"}),
],
)
def test_message_handler_propagates_to_server_base(server_cls, params):
class _AsyncHandler:
async def __call__(self, message: HandlerMessage) -> None:
del message

handler: MessageHandlerFnT = _AsyncHandler()

server = server_cls(params, message_handler=handler)

assert server.message_handler is handler