11from datetime import timedelta
2- from typing import Protocol , Any
2+ from typing import Any , Protocol
33
44from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
5- from pydantic import AnyUrl
5+ from pydantic import AnyUrl , TypeAdapter
66
7- from mcp .shared .context import RequestContext
87import mcp .types as types
8+ from mcp .shared .context import RequestContext
99from mcp .shared .session import BaseSession , RequestResponder
1010from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
1111
1212
1313class SamplingFnT (Protocol ):
1414 async def __call__ (
15- self , context : RequestContext ["ClientSession" , Any ], params : types .CreateMessageRequestParams
16- ) -> types .CreateMessageResult :
17- ...
15+ self ,
16+ context : RequestContext ["ClientSession" , Any ],
17+ params : types .CreateMessageRequestParams ,
18+ ) -> types .CreateMessageResult | types .ErrorData : ...
1819
1920
2021class ListRootsFnT (Protocol ):
2122 async def __call__ (
2223 self , context : RequestContext ["ClientSession" , Any ]
23- ) -> types .ListRootsResult :
24- ...
24+ ) -> types .ListRootsResult | types .ErrorData : ...
25+
26+
27+ async def _default_sampling_callback (
28+ context : RequestContext ["ClientSession" , Any ],
29+ params : types .CreateMessageRequestParams ,
30+ ) -> types .CreateMessageResult | types .ErrorData :
31+ return types .ErrorData (
32+ code = types .INVALID_REQUEST ,
33+ message = "Sampling not supported" ,
34+ )
35+
36+
37+ async def _default_list_roots_callback (
38+ context : RequestContext ["ClientSession" , Any ],
39+ ) -> types .ListRootsResult | types .ErrorData :
40+ return types .ErrorData (
41+ code = types .INVALID_REQUEST ,
42+ message = "List roots not supported" ,
43+ )
44+
45+
46+ ClientResponse = TypeAdapter (types .ClientResult | types .ErrorData )
2547
2648
2749class ClientSession (
@@ -33,8 +55,6 @@ class ClientSession(
3355 types .ServerNotification ,
3456 ]
3557):
36- _sampling_callback : SamplingFnT | None = None
37-
3858 def __init__ (
3959 self ,
4060 read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ],
@@ -50,8 +70,8 @@ def __init__(
5070 types .ServerNotification ,
5171 read_timeout_seconds = read_timeout_seconds ,
5272 )
53- self ._sampling_callback = sampling_callback
54- self ._list_roots_callback = list_roots_callback
73+ self ._sampling_callback = sampling_callback or _default_sampling_callback
74+ self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
5575
5676 async def initialize (self ) -> types .InitializeResult :
5777 sampling = (
@@ -278,27 +298,28 @@ async def send_roots_list_changed(self) -> None:
278298 async def _received_request (
279299 self , responder : RequestResponder [types .ServerRequest , types .ClientResult ]
280300 ) -> None :
281-
282301 ctx = RequestContext [ClientSession , Any ](
283302 request_id = responder .request_id ,
284303 meta = responder .request_meta ,
285304 session = self ,
286305 lifespan_context = None ,
287306 )
288-
307+
289308 match responder .request .root :
290309 case types .CreateMessageRequest (params = params ):
291- if self . _sampling_callback is not None :
310+ with responder :
292311 response = await self ._sampling_callback (ctx , params )
293- client_response = types . ClientResult ( root = response )
294- with responder :
295- await responder . respond ( client_response )
312+ client_response = ClientResponse . validate_python ( response )
313+ await responder . respond ( client_response )
314+
296315 case types .ListRootsRequest ():
297- if self . _list_roots_callback is not None :
316+ with responder :
298317 response = await self ._list_roots_callback (ctx )
299- client_response = types . ClientResult ( root = response )
300- with responder :
301- await responder . respond ( client_response )
318+ client_response = ClientResponse . validate_python ( response )
319+ await responder . respond ( client_response )
320+
302321 case types .PingRequest ():
303322 with responder :
304- await responder .respond (types .ClientResult (root = types .EmptyResult ()))
323+ return await responder .respond (
324+ types .ClientResult (root = types .EmptyResult ())
325+ )
0 commit comments