Skip to content

Commit 8e57a92

Browse files
committed
fmt
1 parent e3dbcc3 commit 8e57a92

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

tests/mcp/test_message_handler.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,15 @@
44
import pytest
55
from mcp.client.session import MessageHandlerFnT
66
from mcp.shared.message import SessionMessage
7-
from mcp.types import Implementation, InitializeResult, ServerCapabilities
7+
from mcp.shared.session import RequestResponder
8+
from mcp.types import (
9+
ClientResult,
10+
Implementation,
11+
InitializeResult,
12+
ServerCapabilities,
13+
ServerNotification,
14+
ServerRequest,
15+
)
816

917
from agents.mcp.server import (
1018
MCPServerSse,
@@ -14,6 +22,13 @@
1422
)
1523

1624

25+
HandlerMessage = (
26+
RequestResponder[ServerRequest, ClientResult]
27+
| ServerNotification
28+
| Exception
29+
)
30+
31+
1732
class _StubClientSession:
1833
"""Stub ClientSession that records the configured message handler."""
1934

@@ -35,10 +50,12 @@ async def __aexit__(self, exc_type, exc, tb):
3550
return False
3651

3752
async def initialize(self) -> InitializeResult:
53+
capabilities = ServerCapabilities.model_construct()
54+
server_info = Implementation.model_construct(name="stub", version="1.0")
3855
return InitializeResult(
3956
protocolVersion="2024-11-05",
40-
capabilities=ServerCapabilities(),
41-
serverInfo=Implementation(name="stub", version="1.0"),
57+
capabilities=capabilities,
58+
serverInfo=server_info,
4259
)
4360

4461

@@ -81,7 +98,7 @@ def _recording_client_session(*args, **kwargs):
8198
monkeypatch.setattr("agents.mcp.server.ClientSession", _recording_client_session)
8299

83100
class _AsyncHandler:
84-
async def __call__(self, message):
101+
async def __call__(self, message: HandlerMessage) -> None:
85102
del message
86103

87104
handler: MessageHandlerFnT = _AsyncHandler()
@@ -106,7 +123,7 @@ async def __call__(self, message):
106123
)
107124
def test_message_handler_propagates_to_server_base(server_cls, params):
108125
class _AsyncHandler:
109-
async def __call__(self, message):
126+
async def __call__(self, message: HandlerMessage) -> None:
110127
del message
111128

112129
handler: MessageHandlerFnT = _AsyncHandler()

0 commit comments

Comments
 (0)