Skip to content

Commit 1f146f6

Browse files
authored
Merge branch 'main' into main
2 parents b6e4576 + 775f879 commit 1f146f6

15 files changed

+323
-46
lines changed

README.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,9 +476,21 @@ server_params = StdioServerParameters(
476476
env=None # Optional environment variables
477477
)
478478

479+
# Optional: create a sampling callback
480+
async def handle_sampling_message(message: types.CreateMessageRequestParams) -> types.CreateMessageResult:
481+
return types.CreateMessageResult(
482+
role="assistant",
483+
content=types.TextContent(
484+
type="text",
485+
text="Hello, world! from model",
486+
),
487+
model="gpt-3.5-turbo",
488+
stopReason="endTurn",
489+
)
490+
479491
async def run():
480492
async with stdio_client(server_params) as (read, write):
481-
async with ClientSession(read, write) as session:
493+
async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session:
482494
# Initialize the connection
483495
await session.initialize()
484496

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "mcp"
3-
version = "1.3.0.dev0"
3+
version = "1.4.0.dev0"
44
description = "Model Context Protocol SDK"
55
readme = "README.md"
66
requires-python = ">=3.10"

src/mcp/client/session.py

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,51 @@
11
from datetime import timedelta
2+
from typing import Any, Protocol
23

34
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4-
from pydantic import AnyUrl
5+
from pydantic import AnyUrl, TypeAdapter
56

67
import mcp.types as types
7-
from mcp.shared.session import BaseSession
8+
from mcp.shared.context import RequestContext
9+
from mcp.shared.session import BaseSession, RequestResponder
810
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
911

1012

13+
class SamplingFnT(Protocol):
14+
async def __call__(
15+
self,
16+
context: RequestContext["ClientSession", Any],
17+
params: types.CreateMessageRequestParams,
18+
) -> types.CreateMessageResult | types.ErrorData: ...
19+
20+
21+
class ListRootsFnT(Protocol):
22+
async def __call__(
23+
self, context: RequestContext["ClientSession", Any]
24+
) -> types.ListRootsResult | types.ErrorData: ...
25+
26+
27+
async def _default_sampling_callback(
28+
context: RequestContext["ClientSession", Any],
29+
params: types.CreateMessageRequestParams,
30+
) -> types.CreateMessageResult | types.ErrorData:
31+
return types.ErrorData(
32+
code=types.INVALID_REQUEST,
33+
message="Sampling not supported",
34+
)
35+
36+
37+
async def _default_list_roots_callback(
38+
context: RequestContext["ClientSession", Any],
39+
) -> types.ListRootsResult | types.ErrorData:
40+
return types.ErrorData(
41+
code=types.INVALID_REQUEST,
42+
message="List roots not supported",
43+
)
44+
45+
46+
ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
47+
48+
1149
class ClientSession(
1250
BaseSession[
1351
types.ClientRequest,
@@ -22,6 +60,8 @@ def __init__(
2260
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
2361
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
2462
read_timeout_seconds: timedelta | None = None,
63+
sampling_callback: SamplingFnT | None = None,
64+
list_roots_callback: ListRootsFnT | None = None,
2565
) -> None:
2666
super().__init__(
2767
read_stream,
@@ -30,23 +70,34 @@ def __init__(
3070
types.ServerNotification,
3171
read_timeout_seconds=read_timeout_seconds,
3272
)
73+
self._sampling_callback = sampling_callback or _default_sampling_callback
74+
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
3375

3476
async def initialize(self) -> types.InitializeResult:
77+
sampling = (
78+
types.SamplingCapability() if self._sampling_callback is not None else None
79+
)
80+
roots = (
81+
types.RootsCapability(
82+
# TODO: Should this be based on whether we
83+
# _will_ send notifications, or only whether
84+
# they're supported?
85+
listChanged=True,
86+
)
87+
if self._list_roots_callback is not None
88+
else None
89+
)
90+
3591
result = await self.send_request(
3692
types.ClientRequest(
3793
types.InitializeRequest(
3894
method="initialize",
3995
params=types.InitializeRequestParams(
4096
protocolVersion=types.LATEST_PROTOCOL_VERSION,
4197
capabilities=types.ClientCapabilities(
42-
sampling=None,
98+
sampling=sampling,
4399
experimental=None,
44-
roots=types.RootsCapability(
45-
# TODO: Should this be based on whether we
46-
# _will_ send notifications, or only whether
47-
# they're supported?
48-
listChanged=True
49-
),
100+
roots=roots,
50101
),
51102
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
52103
),
@@ -243,3 +294,32 @@ async def send_roots_list_changed(self) -> None:
243294
)
244295
)
245296
)
297+
298+
async def _received_request(
299+
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
300+
) -> None:
301+
ctx = RequestContext[ClientSession, Any](
302+
request_id=responder.request_id,
303+
meta=responder.request_meta,
304+
session=self,
305+
lifespan_context=None,
306+
)
307+
308+
match responder.request.root:
309+
case types.CreateMessageRequest(params=params):
310+
with responder:
311+
response = await self._sampling_callback(ctx, params)
312+
client_response = ClientResponse.validate_python(response)
313+
await responder.respond(client_response)
314+
315+
case types.ListRootsRequest():
316+
with responder:
317+
response = await self._list_roots_callback(ctx)
318+
client_response = ClientResponse.validate_python(response)
319+
await responder.respond(client_response)
320+
321+
case types.PingRequest():
322+
with responder:
323+
return await responder.respond(
324+
types.ClientResult(root=types.EmptyResult())
325+
)

src/mcp/server/fastmcp/server.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
import json
55
import re
6-
from collections.abc import AsyncIterator
6+
from collections.abc import AsyncIterator, Iterable
77
from contextlib import (
88
AbstractAsyncContextManager,
99
asynccontextmanager,
@@ -34,6 +34,7 @@
3434
from mcp.server.lowlevel.server import (
3535
lifespan as default_lifespan,
3636
)
37+
from mcp.server.session import ServerSession
3738
from mcp.server.sse import SseServerTransport
3839
from mcp.server.stdio import stdio_server
3940
from mcp.shared.context import RequestContext
@@ -235,7 +236,7 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]:
235236
for template in templates
236237
]
237238

238-
async def read_resource(self, uri: AnyUrl | str) -> ReadResourceContents:
239+
async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]:
239240
"""Read a resource by URI."""
240241

241242
resource = await self._resource_manager.get_resource(uri)
@@ -244,7 +245,7 @@ async def read_resource(self, uri: AnyUrl | str) -> ReadResourceContents:
244245

245246
try:
246247
content = await resource.read()
247-
return ReadResourceContents(content=content, mime_type=resource.mime_type)
248+
return [ReadResourceContents(content=content, mime_type=resource.mime_type)]
248249
except Exception as e:
249250
logger.error(f"Error reading resource {uri}: {e}")
250251
raise ResourceError(str(e))
@@ -597,7 +598,7 @@ def my_tool(x: int, ctx: Context) -> str:
597598
The context is optional - tools that don't need it can omit the parameter.
598599
"""
599600

600-
_request_context: RequestContext | None
601+
_request_context: RequestContext[ServerSession, Any] | None
601602
_fastmcp: FastMCP | None
602603

603604
def __init__(
@@ -648,7 +649,7 @@ async def report_progress(
648649
progress_token=progress_token, progress=progress, total=total
649650
)
650651

651-
async def read_resource(self, uri: str | AnyUrl) -> ReadResourceContents:
652+
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
652653
"""Read a resource by URI.
653654
654655
Args:

src/mcp/server/lowlevel/server.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ async def main():
6767
import contextvars
6868
import logging
6969
import warnings
70-
from collections.abc import Awaitable, Callable
70+
from collections.abc import Awaitable, Callable, Iterable
7171
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
72-
from typing import Any, AsyncIterator, Generic, Sequence, TypeVar
72+
from typing import Any, AsyncIterator, Generic, TypeVar
7373

7474
import anyio
7575
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -279,7 +279,9 @@ async def handler(_: Any):
279279

280280
def read_resource(self):
281281
def decorator(
282-
func: Callable[[AnyUrl], Awaitable[str | bytes | ReadResourceContents]],
282+
func: Callable[
283+
[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]
284+
],
283285
):
284286
logger.debug("Registering handler for ReadResourceRequest")
285287

@@ -307,13 +309,22 @@ def create_content(data: str | bytes, mime_type: str | None):
307309
case str() | bytes() as data:
308310
warnings.warn(
309311
"Returning str or bytes from read_resource is deprecated. "
310-
"Use ReadResourceContents instead.",
312+
"Use Iterable[ReadResourceContents] instead.",
311313
DeprecationWarning,
312314
stacklevel=2,
313315
)
314316
content = create_content(data, None)
315-
case ReadResourceContents() as contents:
316-
content = create_content(contents.content, contents.mime_type)
317+
case Iterable() as contents:
318+
contents_list = [
319+
create_content(content_item.content, content_item.mime_type)
320+
for content_item in contents
321+
if isinstance(content_item, ReadResourceContents)
322+
]
323+
return types.ServerResult(
324+
types.ReadResourceResult(
325+
contents=contents_list,
326+
)
327+
)
317328
case _:
318329
raise ValueError(
319330
f"Unexpected return type from read_resource: {type(result)}"
@@ -387,7 +398,7 @@ def decorator(
387398
func: Callable[
388399
...,
389400
Awaitable[
390-
Sequence[
401+
Iterable[
391402
types.TextContent | types.ImageContent | types.EmbeddedResource
392403
]
393404
],

src/mcp/shared/memory.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import anyio
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111

12-
from mcp.client.session import ClientSession
12+
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
1313
from mcp.server import Server
1414
from mcp.types import JSONRPCMessage
1515

@@ -54,6 +54,8 @@ async def create_client_server_memory_streams() -> (
5454
async def create_connected_server_and_client_session(
5555
server: Server,
5656
read_timeout_seconds: timedelta | None = None,
57+
sampling_callback: SamplingFnT | None = None,
58+
list_roots_callback: ListRootsFnT | None = None,
5759
raise_exceptions: bool = False,
5860
) -> AsyncGenerator[ClientSession, None]:
5961
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -80,6 +82,8 @@ async def create_connected_server_and_client_session(
8082
read_stream=client_read,
8183
write_stream=client_write,
8284
read_timeout_seconds=read_timeout_seconds,
85+
sampling_callback=sampling_callback,
86+
list_roots_callback=list_roots_callback,
8387
) as client_session:
8488
await client_session.initialize()
8589
yield client_session
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
from pydantic import FileUrl
3+
4+
from mcp.client.session import ClientSession
5+
from mcp.server.fastmcp.server import Context
6+
from mcp.shared.context import RequestContext
7+
from mcp.shared.memory import (
8+
create_connected_server_and_client_session as create_session,
9+
)
10+
from mcp.types import (
11+
ListRootsResult,
12+
Root,
13+
TextContent,
14+
)
15+
16+
17+
@pytest.mark.anyio
18+
async def test_list_roots_callback():
19+
from mcp.server.fastmcp import FastMCP
20+
21+
server = FastMCP("test")
22+
23+
callback_return = ListRootsResult(
24+
roots=[
25+
Root(
26+
uri=FileUrl("file://users/fake/test"),
27+
name="Test Root 1",
28+
),
29+
Root(
30+
uri=FileUrl("file://users/fake/test/2"),
31+
name="Test Root 2",
32+
),
33+
]
34+
)
35+
36+
async def list_roots_callback(
37+
context: RequestContext[ClientSession, None],
38+
) -> ListRootsResult:
39+
return callback_return
40+
41+
@server.tool("test_list_roots")
42+
async def test_list_roots(context: Context, message: str):
43+
roots = await context.session.list_roots()
44+
assert roots == callback_return
45+
return True
46+
47+
# Test with list_roots callback
48+
async with create_session(
49+
server._mcp_server, list_roots_callback=list_roots_callback
50+
) as client_session:
51+
# Make a request to trigger sampling callback
52+
result = await client_session.call_tool(
53+
"test_list_roots", {"message": "test message"}
54+
)
55+
assert result.isError is False
56+
assert isinstance(result.content[0], TextContent)
57+
assert result.content[0].text == "true"
58+
59+
# Test without list_roots callback
60+
async with create_session(server._mcp_server) as client_session:
61+
# Make a request to trigger sampling callback
62+
result = await client_session.call_tool(
63+
"test_list_roots", {"message": "test message"}
64+
)
65+
assert result.isError is True
66+
assert isinstance(result.content[0], TextContent)
67+
assert (
68+
result.content[0].text
69+
== "Error executing tool test_list_roots: List roots not supported"
70+
)

0 commit comments

Comments
 (0)