11from datetime import timedelta
2+ from typing import Any , Protocol
23
34from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
4- from pydantic import AnyUrl
5+ from pydantic import AnyUrl , TypeAdapter
56
67import mcp .types as types
7- from mcp .shared .session import BaseSession
8+ from mcp .shared .context import RequestContext
9+ from mcp .shared .session import BaseSession , RequestResponder
810from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
911
1012
13+ class SamplingFnT (Protocol ):
14+ async def __call__ (
15+ self ,
16+ context : RequestContext ["ClientSession" , Any ],
17+ params : types .CreateMessageRequestParams ,
18+ ) -> types .CreateMessageResult | types .ErrorData : ...
19+
20+
21+ class ListRootsFnT (Protocol ):
22+ async def __call__ (
23+ self , context : RequestContext ["ClientSession" , Any ]
24+ ) -> types .ListRootsResult | types .ErrorData : ...
25+
26+
27+ async def _default_sampling_callback (
28+ context : RequestContext ["ClientSession" , Any ],
29+ params : types .CreateMessageRequestParams ,
30+ ) -> types .CreateMessageResult | types .ErrorData :
31+ return types .ErrorData (
32+ code = types .INVALID_REQUEST ,
33+ message = "Sampling not supported" ,
34+ )
35+
36+
37+ async def _default_list_roots_callback (
38+ context : RequestContext ["ClientSession" , Any ],
39+ ) -> types .ListRootsResult | types .ErrorData :
40+ return types .ErrorData (
41+ code = types .INVALID_REQUEST ,
42+ message = "List roots not supported" ,
43+ )
44+
45+
46+ ClientResponse = TypeAdapter (types .ClientResult | types .ErrorData )
47+
48+
1149class ClientSession (
1250 BaseSession [
1351 types .ClientRequest ,
@@ -22,6 +60,8 @@ def __init__(
2260 read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ],
2361 write_stream : MemoryObjectSendStream [types .JSONRPCMessage ],
2462 read_timeout_seconds : timedelta | None = None ,
63+ sampling_callback : SamplingFnT | None = None ,
64+ list_roots_callback : ListRootsFnT | None = None ,
2565 ) -> None :
2666 super ().__init__ (
2767 read_stream ,
@@ -30,23 +70,34 @@ def __init__(
3070 types .ServerNotification ,
3171 read_timeout_seconds = read_timeout_seconds ,
3272 )
73+ self ._sampling_callback = sampling_callback or _default_sampling_callback
74+ self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
3375
3476 async def initialize (self ) -> types .InitializeResult :
77+ sampling = (
78+ types .SamplingCapability () if self ._sampling_callback is not None else None
79+ )
80+ roots = (
81+ types .RootsCapability (
82+ # TODO: Should this be based on whether we
83+ # _will_ send notifications, or only whether
84+ # they're supported?
85+ listChanged = True ,
86+ )
87+ if self ._list_roots_callback is not None
88+ else None
89+ )
90+
3591 result = await self .send_request (
3692 types .ClientRequest (
3793 types .InitializeRequest (
3894 method = "initialize" ,
3995 params = types .InitializeRequestParams (
4096 protocolVersion = types .LATEST_PROTOCOL_VERSION ,
4197 capabilities = types .ClientCapabilities (
42- sampling = None ,
98+ sampling = sampling ,
4399 experimental = None ,
44- roots = types .RootsCapability (
45- # TODO: Should this be based on whether we
46- # _will_ send notifications, or only whether
47- # they're supported?
48- listChanged = True
49- ),
100+ roots = roots ,
50101 ),
51102 clientInfo = types .Implementation (name = "mcp" , version = "0.1.0" ),
52103 ),
@@ -243,3 +294,32 @@ async def send_roots_list_changed(self) -> None:
243294 )
244295 )
245296 )
297+
298+ async def _received_request (
299+ self , responder : RequestResponder [types .ServerRequest , types .ClientResult ]
300+ ) -> None :
301+ ctx = RequestContext [ClientSession , Any ](
302+ request_id = responder .request_id ,
303+ meta = responder .request_meta ,
304+ session = self ,
305+ lifespan_context = None ,
306+ )
307+
308+ match responder .request .root :
309+ case types .CreateMessageRequest (params = params ):
310+ with responder :
311+ response = await self ._sampling_callback (ctx , params )
312+ client_response = ClientResponse .validate_python (response )
313+ await responder .respond (client_response )
314+
315+ case types .ListRootsRequest ():
316+ with responder :
317+ response = await self ._list_roots_callback (ctx )
318+ client_response = ClientResponse .validate_python (response )
319+ await responder .respond (client_response )
320+
321+ case types .PingRequest ():
322+ with responder :
323+ return await responder .respond (
324+ types .ClientResult (root = types .EmptyResult ())
325+ )
0 commit comments