Skip to content

Commit f739155

Browse files
committed
Expose MCP message handler configuration
1 parent e87552a commit f739155

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

src/agents/mcp/server.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from mcp import ClientSession, StdioServerParameters, Tool as MCPTool, stdio_client
1414
from mcp.client.sse import sse_client
1515
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
16+
from mcp.client.session import MessageHandlerFnT
1617
from mcp.shared.message import SessionMessage
1718
from mcp.types import CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult
1819
from typing_extensions import NotRequired, TypedDict
@@ -103,6 +104,7 @@ def __init__(
103104
use_structured_content: bool = False,
104105
max_retry_attempts: int = 0,
105106
retry_backoff_seconds_base: float = 1.0,
107+
message_handler: MessageHandlerFnT | None = None,
106108
):
107109
"""
108110
Args:
@@ -124,6 +126,8 @@ def __init__(
124126
Defaults to no retries.
125127
retry_backoff_seconds_base: The base delay, in seconds, used for exponential
126128
backoff between retries.
129+
message_handler: Optional handler invoked for session messages as delivered by the
130+
ClientSession.
127131
"""
128132
super().__init__(use_structured_content=use_structured_content)
129133
self.session: ClientSession | None = None
@@ -135,6 +139,7 @@ def __init__(
135139
self.client_session_timeout_seconds = client_session_timeout_seconds
136140
self.max_retry_attempts = max_retry_attempts
137141
self.retry_backoff_seconds_base = retry_backoff_seconds_base
142+
self.message_handler = message_handler
138143

139144
# The cache is always dirty at startup, so that we fetch tools at least once
140145
self._cache_dirty = True
@@ -272,6 +277,7 @@ async def connect(self):
272277
timedelta(seconds=self.client_session_timeout_seconds)
273278
if self.client_session_timeout_seconds
274279
else None,
280+
message_handler=self.message_handler,
275281
)
276282
)
277283
server_result = await session.initialize()
@@ -394,6 +400,7 @@ def __init__(
394400
use_structured_content: bool = False,
395401
max_retry_attempts: int = 0,
396402
retry_backoff_seconds_base: float = 1.0,
403+
message_handler: MessageHandlerFnT | None = None,
397404
):
398405
"""Create a new MCP server based on the stdio transport.
399406
@@ -421,6 +428,8 @@ def __init__(
421428
Defaults to no retries.
422429
retry_backoff_seconds_base: The base delay, in seconds, for exponential
423430
backoff between retries.
431+
message_handler: Optional handler invoked for session messages as delivered by the
432+
ClientSession.
424433
"""
425434
super().__init__(
426435
cache_tools_list,
@@ -429,6 +438,7 @@ def __init__(
429438
use_structured_content,
430439
max_retry_attempts,
431440
retry_backoff_seconds_base,
441+
message_handler=message_handler,
432442
)
433443

434444
self.params = StdioServerParameters(
@@ -492,6 +502,7 @@ def __init__(
492502
use_structured_content: bool = False,
493503
max_retry_attempts: int = 0,
494504
retry_backoff_seconds_base: float = 1.0,
505+
message_handler: MessageHandlerFnT | None = None,
495506
):
496507
"""Create a new MCP server based on the HTTP with SSE transport.
497508
@@ -521,6 +532,8 @@ def __init__(
521532
Defaults to no retries.
522533
retry_backoff_seconds_base: The base delay, in seconds, for exponential
523534
backoff between retries.
535+
message_handler: Optional handler invoked for session messages as delivered by the
536+
ClientSession.
524537
"""
525538
super().__init__(
526539
cache_tools_list,
@@ -529,6 +542,7 @@ def __init__(
529542
use_structured_content,
530543
max_retry_attempts,
531544
retry_backoff_seconds_base,
545+
message_handler=message_handler,
532546
)
533547

534548
self.params = params
@@ -592,6 +606,7 @@ def __init__(
592606
use_structured_content: bool = False,
593607
max_retry_attempts: int = 0,
594608
retry_backoff_seconds_base: float = 1.0,
609+
message_handler: MessageHandlerFnT | None = None,
595610
):
596611
"""Create a new MCP server based on the Streamable HTTP transport.
597612
@@ -622,6 +637,8 @@ def __init__(
622637
Defaults to no retries.
623638
retry_backoff_seconds_base: The base delay, in seconds, for exponential
624639
backoff between retries.
640+
message_handler: Optional handler invoked for session messages as delivered by the
641+
ClientSession.
625642
"""
626643
super().__init__(
627644
cache_tools_list,
@@ -630,6 +647,7 @@ def __init__(
630647
use_structured_content,
631648
max_retry_attempts,
632649
retry_backoff_seconds_base,
650+
message_handler=message_handler,
633651
)
634652

635653
self.params = params

tests/mcp/test_message_handler.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import contextlib
2+
3+
import anyio
4+
import pytest
5+
from mcp.shared.message import SessionMessage
6+
from mcp.types import InitializeResult
7+
8+
from agents.mcp.server import (
9+
MCPServerSse,
10+
MCPServerStreamableHttp,
11+
MCPServerStdio,
12+
_MCPServerWithClientSession,
13+
)
14+
from mcp.client.session import MessageHandlerFnT
15+
16+
17+
class _StubClientSession:
18+
"""Stub ClientSession that records the configured message handler."""
19+
20+
def __init__(
21+
self,
22+
read_stream,
23+
write_stream,
24+
read_timeout_seconds,
25+
*,
26+
message_handler=None,
27+
**_: object,
28+
) -> None:
29+
self.message_handler = message_handler
30+
31+
async def __aenter__(self):
32+
return self
33+
34+
async def __aexit__(self, exc_type, exc, tb):
35+
return False
36+
37+
async def initialize(self) -> InitializeResult:
38+
return InitializeResult(
39+
protocolVersion="2024-11-05",
40+
capabilities={},
41+
serverInfo={"name": "stub", "version": "1.0"},
42+
)
43+
44+
45+
class _MessageHandlerTestServer(_MCPServerWithClientSession):
46+
def __init__(self, handler: MessageHandlerFnT | None):
47+
super().__init__(
48+
cache_tools_list=False,
49+
client_session_timeout_seconds=None,
50+
message_handler=handler,
51+
)
52+
53+
def create_streams(self):
54+
@contextlib.asynccontextmanager
55+
async def _streams():
56+
send_stream, recv_stream = anyio.create_memory_object_stream[SessionMessage | Exception](
57+
1
58+
)
59+
try:
60+
yield recv_stream, send_stream, None
61+
finally:
62+
await recv_stream.aclose()
63+
await send_stream.aclose()
64+
65+
return _streams()
66+
67+
@property
68+
def name(self) -> str:
69+
return "test-server"
70+
71+
72+
@pytest.mark.asyncio
73+
async def test_client_session_receives_message_handler(monkeypatch):
74+
captured: dict[str, object] = {}
75+
76+
def _recording_client_session(*args, **kwargs):
77+
session = _StubClientSession(*args, **kwargs)
78+
captured["message_handler"] = session.message_handler
79+
return session
80+
81+
monkeypatch.setattr("agents.mcp.server.ClientSession", _recording_client_session)
82+
83+
async def handler(message: SessionMessage) -> None:
84+
del message
85+
86+
server = _MessageHandlerTestServer(handler)
87+
88+
try:
89+
await server.connect()
90+
finally:
91+
await server.cleanup()
92+
93+
assert captured["message_handler"] is handler
94+
95+
96+
@pytest.mark.parametrize(
97+
"server_cls, params",
98+
[
99+
(MCPServerSse, {"url": "https://example.com"}),
100+
(MCPServerStreamableHttp, {"url": "https://example.com"}),
101+
(MCPServerStdio, {"command": "python"}),
102+
],
103+
)
104+
def test_message_handler_propagates_to_server_base(server_cls, params):
105+
def handler(message: SessionMessage) -> None:
106+
del message
107+
108+
server = server_cls(params, message_handler=handler)
109+
110+
assert server.message_handler is handler

0 commit comments

Comments
 (0)