Skip to content

Commit a48cb68

Browse files
Added roots callback also
1 parent 6e4c8d4 commit a48cb68

File tree

4 files changed

+112
-21
lines changed

4 files changed

+112
-21
lines changed

src/mcp/client/session.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
11
from datetime import timedelta
2-
from typing import Awaitable, Callable
2+
from typing import Protocol, Any
33

44
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
55
from pydantic import AnyUrl
66

7+
from mcp.shared.context import RequestContext
78
import mcp.types as types
89
from mcp.shared.session import BaseSession, RequestResponder
910
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1011

11-
SamplingFnT = Callable[
12-
[types.CreateMessageRequestParams], Awaitable[types.CreateMessageResult]
13-
]
12+
13+
class SamplingFnT(Protocol):
14+
async def __call__(
15+
self, context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams
16+
) -> types.CreateMessageResult:
17+
...
18+
19+
20+
class ListRootsFnT(Protocol):
21+
async def __call__(
22+
self, context: RequestContext["ClientSession", Any]
23+
) -> types.ListRootsResult:
24+
...
1425

1526

1627
class ClientSession(
@@ -22,14 +33,15 @@ class ClientSession(
2233
types.ServerNotification,
2334
]
2435
):
25-
sampling_callback: SamplingFnT | None = None
36+
_sampling_callback: SamplingFnT | None = None
2637

2738
def __init__(
2839
self,
2940
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
3041
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
3142
read_timeout_seconds: timedelta | None = None,
3243
sampling_callback: SamplingFnT | None = None,
44+
list_roots_callback: ListRootsFnT | None = None,
3345
) -> None:
3446
super().__init__(
3547
read_stream,
@@ -38,11 +50,22 @@ def __init__(
3850
types.ServerNotification,
3951
read_timeout_seconds=read_timeout_seconds,
4052
)
41-
self.sampling_callback = sampling_callback
53+
self._sampling_callback = sampling_callback
54+
self._list_roots_callback = list_roots_callback
4255

4356
async def initialize(self) -> types.InitializeResult:
4457
sampling = (
45-
types.SamplingCapability() if self.sampling_callback is not None else None
58+
types.SamplingCapability() if self._sampling_callback is not None else None
59+
)
60+
roots = (
61+
types.RootsCapability(
62+
# TODO: Should this be based on whether we
63+
# _will_ send notifications, or only whether
64+
# they're supported?
65+
listChanged=True,
66+
)
67+
if self._list_roots_callback is not None
68+
else None
4669
)
4770

4871
result = await self.send_request(
@@ -54,12 +77,7 @@ async def initialize(self) -> types.InitializeResult:
5477
capabilities=types.ClientCapabilities(
5578
sampling=sampling,
5679
experimental=None,
57-
roots=types.RootsCapability(
58-
# TODO: Should this be based on whether we
59-
# _will_ send notifications, or only whether
60-
# they're supported?
61-
listChanged=True
62-
),
80+
roots=roots,
6381
),
6482
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
6583
),
@@ -258,11 +276,29 @@ async def send_roots_list_changed(self) -> None:
258276
)
259277

260278
async def _received_request(
261-
self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"]
279+
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
262280
) -> None:
263-
if isinstance(responder.request.root, types.CreateMessageRequest):
264-
if self.sampling_callback is not None:
265-
response = await self.sampling_callback(responder.request.root.params)
266-
client_response = types.ClientResult(root=response)
281+
282+
ctx = RequestContext[ClientSession, Any](
283+
request_id=responder.request_id,
284+
meta=responder.request_meta,
285+
session=self,
286+
lifespan_context=None,
287+
)
288+
289+
match responder.request.root:
290+
case types.CreateMessageRequest:
291+
if self._sampling_callback is not None:
292+
response = await self._sampling_callback(ctx, responder.request.root.params)
293+
client_response = types.ClientResult(root=response)
294+
with responder:
295+
await responder.respond(client_response)
296+
case types.ListRootsRequest:
297+
if self._list_roots_callback is not None:
298+
response = await self._list_roots_callback(ctx)
299+
client_response = types.ClientResult(root=response)
300+
with responder:
301+
await responder.respond(client_response)
302+
case types.PingRequest:
267303
with responder:
268-
await responder.respond(client_response)
304+
await responder.respond(types.ClientResult(root=types.EmptyResult()))

src/mcp/shared/memory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import anyio
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111

12-
from mcp.client.session import ClientSession, SamplingFnT
12+
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
1313
from mcp.server import Server
1414
from mcp.types import JSONRPCMessage
1515

@@ -55,6 +55,7 @@ async def create_connected_server_and_client_session(
5555
server: Server,
5656
read_timeout_seconds: timedelta | None = None,
5757
sampling_callback: SamplingFnT | None = None,
58+
list_roots_callback: ListRootsFnT | None = None,
5859
raise_exceptions: bool = False,
5960
) -> AsyncGenerator[ClientSession, None]:
6061
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -82,6 +83,7 @@ async def create_connected_server_and_client_session(
8283
write_stream=client_write,
8384
read_timeout_seconds=read_timeout_seconds,
8485
sampling_callback=sampling_callback,
86+
list_roots_callback=list_roots_callback,
8587
) as client_session:
8688
await client_session.initialize()
8789
yield client_session
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from pydantic import FileUrl
2+
import pytest
3+
4+
from mcp.client.session import ClientSession
5+
from mcp.server.fastmcp.server import Context
6+
from mcp.shared.context import RequestContext
7+
from mcp.shared.memory import (
8+
create_connected_server_and_client_session as create_session,
9+
)
10+
from mcp.types import (
11+
ListRootsResult,
12+
Root,
13+
)
14+
15+
16+
@pytest.mark.anyio
17+
async def test_list_roots_callback():
18+
from mcp.server.fastmcp import FastMCP
19+
20+
server = FastMCP("test")
21+
22+
callback_return = ListRootsResult(roots=[
23+
Root(
24+
uri=FileUrl("test://users/fake/test"),
25+
name="Test Root 1",
26+
),
27+
Root(
28+
uri=FileUrl("test://users/fake/test/2"),
29+
name="Test Root 2",
30+
)
31+
])
32+
33+
async def list_roots_callback(
34+
context: RequestContext[ClientSession, None]
35+
) -> ListRootsResult:
36+
return callback_return
37+
38+
@server.tool("test_list_roots")
39+
async def test_list_roots(context: Context, message: str):
40+
roots = context.session.list_roots()
41+
assert roots == callback_return
42+
return True
43+
44+
async with create_session(
45+
server._mcp_server, list_roots_callback=list_roots_callback
46+
) as client_session:
47+
# Make a request to trigger sampling callback
48+
assert await client_session.call_tool(
49+
"test_list_roots", {"message": "test message"}
50+
)

tests/client/test_sampling_callback.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
22

3+
from mcp.client.session import ClientSession
4+
from mcp.shared.context import RequestContext
35
from mcp.shared.memory import (
46
create_connected_server_and_client_session as create_session,
57
)
@@ -27,7 +29,8 @@ async def test_sampling_callback():
2729
)
2830

2931
async def sampling_callback(
30-
message: CreateMessageRequestParams,
32+
context: RequestContext[ClientSession, None],
33+
params: CreateMessageRequestParams,
3134
) -> CreateMessageResult:
3235
return callback_return
3336

0 commit comments

Comments
 (0)