2
2
3
3
import anyio
4
4
import pytest
5
+ from mcp .client .session import MessageHandlerFnT
5
6
from mcp .shared .message import SessionMessage
6
- from mcp .types import InitializeResult
7
+ from mcp .types import Implementation , InitializeResult , ServerCapabilities
7
8
8
9
from agents .mcp .server import (
9
10
MCPServerSse ,
10
- MCPServerStreamableHttp ,
11
11
MCPServerStdio ,
12
+ MCPServerStreamableHttp ,
12
13
_MCPServerWithClientSession ,
13
14
)
14
- from mcp .client .session import MessageHandlerFnT
15
15
16
16
17
17
class _StubClientSession :
@@ -37,8 +37,8 @@ async def __aexit__(self, exc_type, exc, tb):
37
37
async def initialize (self ) -> InitializeResult :
38
38
return InitializeResult (
39
39
protocolVersion = "2024-11-05" ,
40
- capabilities = {} ,
41
- serverInfo = { " name" : " stub" , " version" : " 1.0"} ,
40
+ capabilities = ServerCapabilities () ,
41
+ serverInfo = Implementation ( name = " stub" , version = " 1.0") ,
42
42
)
43
43
44
44
@@ -53,9 +53,9 @@ def __init__(self, handler: MessageHandlerFnT | None):
53
53
def create_streams (self ):
54
54
@contextlib .asynccontextmanager
55
55
async def _streams ():
56
- send_stream , recv_stream = anyio .create_memory_object_stream [SessionMessage | Exception ](
57
- 1
58
- )
56
+ send_stream , recv_stream = anyio .create_memory_object_stream [
57
+ SessionMessage | Exception
58
+ ]( 1 )
59
59
try :
60
60
yield recv_stream , send_stream , None
61
61
finally :
@@ -80,8 +80,11 @@ def _recording_client_session(*args, **kwargs):
80
80
81
81
monkeypatch .setattr ("agents.mcp.server.ClientSession" , _recording_client_session )
82
82
83
- async def handler (message : SessionMessage ) -> None :
84
- del message
83
+ class _AsyncHandler :
84
+ async def __call__ (self , message ):
85
+ del message
86
+
87
+ handler : MessageHandlerFnT = _AsyncHandler ()
85
88
86
89
server = _MessageHandlerTestServer (handler )
87
90
@@ -102,8 +105,11 @@ async def handler(message: SessionMessage) -> None:
102
105
],
103
106
)
104
107
def test_message_handler_propagates_to_server_base (server_cls , params ):
105
- def handler (message : SessionMessage ) -> None :
106
- del message
108
+ class _AsyncHandler :
109
+ async def __call__ (self , message ):
110
+ del message
111
+
112
+ handler : MessageHandlerFnT = _AsyncHandler ()
107
113
108
114
server = server_cls (params , message_handler = handler )
109
115
0 commit comments