11from datetime import timedelta
2- from typing import Awaitable , Callable
2+ from typing import Protocol , Any
33
44from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
55from pydantic import AnyUrl
66
7+ from mcp .shared .context import RequestContext
78import mcp .types as types
89from mcp .shared .session import BaseSession , RequestResponder
910from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
1011
11- SamplingFnT = Callable [
12- [types .CreateMessageRequestParams ], Awaitable [types .CreateMessageResult ]
13- ]
12+
13+ class SamplingFnT (Protocol ):
14+ async def __call__ (
15+ self , context : RequestContext ["ClientSession" , Any ], params : types .CreateMessageRequestParams
16+ ) -> types .CreateMessageResult :
17+ ...
18+
19+
20+ class ListRootsFnT (Protocol ):
21+ async def __call__ (
22+ self , context : RequestContext ["ClientSession" , Any ]
23+ ) -> types .ListRootsResult :
24+ ...
1425
1526
1627class ClientSession (
@@ -22,14 +33,15 @@ class ClientSession(
2233 types .ServerNotification ,
2334 ]
2435):
25- sampling_callback : SamplingFnT | None = None
36+ _sampling_callback : SamplingFnT | None = None
2637
2738 def __init__ (
2839 self ,
2940 read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ],
3041 write_stream : MemoryObjectSendStream [types .JSONRPCMessage ],
3142 read_timeout_seconds : timedelta | None = None ,
3243 sampling_callback : SamplingFnT | None = None ,
44+ list_roots_callback : ListRootsFnT | None = None ,
3345 ) -> None :
3446 super ().__init__ (
3547 read_stream ,
@@ -38,11 +50,22 @@ def __init__(
3850 types .ServerNotification ,
3951 read_timeout_seconds = read_timeout_seconds ,
4052 )
41- self .sampling_callback = sampling_callback
53+ self ._sampling_callback = sampling_callback
54+ self ._list_roots_callback = list_roots_callback
4255
4356 async def initialize (self ) -> types .InitializeResult :
4457 sampling = (
45- types .SamplingCapability () if self .sampling_callback is not None else None
58+ types .SamplingCapability () if self ._sampling_callback is not None else None
59+ )
60+ roots = (
61+ types .RootsCapability (
62+ # TODO: Should this be based on whether we
63+ # _will_ send notifications, or only whether
64+ # they're supported?
65+ listChanged = True ,
66+ )
67+ if self ._list_roots_callback is not None
68+ else None
4669 )
4770
4871 result = await self .send_request (
@@ -54,12 +77,7 @@ async def initialize(self) -> types.InitializeResult:
5477 capabilities = types .ClientCapabilities (
5578 sampling = sampling ,
5679 experimental = None ,
57- roots = types .RootsCapability (
58- # TODO: Should this be based on whether we
59- # _will_ send notifications, or only whether
60- # they're supported?
61- listChanged = True
62- ),
80+ roots = roots ,
6381 ),
6482 clientInfo = types .Implementation (name = "mcp" , version = "0.1.0" ),
6583 ),
@@ -258,11 +276,29 @@ async def send_roots_list_changed(self) -> None:
258276 )
259277
260278 async def _received_request (
261- self , responder : RequestResponder [" types.ServerRequest" , " types.ClientResult" ]
279+ self , responder : RequestResponder [types .ServerRequest , types .ClientResult ]
262280 ) -> None :
263- if isinstance (responder .request .root , types .CreateMessageRequest ):
264- if self .sampling_callback is not None :
265- response = await self .sampling_callback (responder .request .root .params )
266- client_response = types .ClientResult (root = response )
281+
282+ ctx = RequestContext [ClientSession , Any ](
283+ request_id = responder .request_id ,
284+ meta = responder .request_meta ,
285+ session = self ,
286+ lifespan_context = None ,
287+ )
288+
289+ match responder .request .root :
290+ case types .CreateMessageRequest :
291+ if self ._sampling_callback is not None :
292+ response = await self ._sampling_callback (ctx , responder .request .root .params )
293+ client_response = types .ClientResult (root = response )
294+ with responder :
295+ await responder .respond (client_response )
296+ case types .ListRootsRequest :
297+ if self ._list_roots_callback is not None :
298+ response = await self ._list_roots_callback (ctx )
299+ client_response = types .ClientResult (root = response )
300+ with responder :
301+ await responder .respond (client_response )
302+ case types .PingRequest :
267303 with responder :
268- await responder .respond (client_response )
304+ await responder .respond (types . ClientResult ( root = types . EmptyResult ()) )
0 commit comments