Skip to content

Commit 8f0f7c5

Browse files
committed
fix: simplify implementation and add test
1 parent e2e2f43 commit 8f0f7c5

File tree

3 files changed

+61
-5
lines changed

3 files changed

+61
-5
lines changed

src/mcp/client/session.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mcp.shared.session import BaseSession, RequestResponder
99
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1010

11-
sampling_function_signature = Callable[
11+
SamplingFnT = Callable[
1212
[types.CreateMessageRequestParams], Awaitable[types.CreateMessageResult]
1313
]
1414

@@ -22,14 +22,14 @@ class ClientSession(
2222
types.ServerNotification,
2323
]
2424
):
25-
sampling_callback: sampling_function_signature | None = None
25+
sampling_callback: SamplingFnT | None = None
2626

2727
def __init__(
2828
self,
2929
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
3030
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
3131
read_timeout_seconds: timedelta | None = None,
32-
sampling_callback: sampling_function_signature | None = None,
32+
sampling_callback: SamplingFnT | None = None,
3333
) -> None:
3434
super().__init__(
3535
read_stream,
@@ -253,4 +253,5 @@ async def _received_request(
253253
if self.sampling_callback is not None:
254254
response = await self.sampling_callback(responder.request.root.params)
255255
client_response = types.ClientResult(root=response)
256-
await responder.respond(client_response)
256+
with responder:
257+
await responder.respond(client_response)

src/mcp/shared/memory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import anyio
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111

12-
from mcp.client.session import ClientSession
12+
from mcp.client.session import ClientSession, SamplingFnT
1313
from mcp.server import Server
1414
from mcp.types import JSONRPCMessage
1515

@@ -54,6 +54,7 @@ async def create_client_server_memory_streams() -> (
5454
async def create_connected_server_and_client_session(
5555
server: Server,
5656
read_timeout_seconds: timedelta | None = None,
57+
sampling_callback: SamplingFnT | None = None,
5758
raise_exceptions: bool = False,
5859
) -> AsyncGenerator[ClientSession, None]:
5960
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -80,6 +81,7 @@ async def create_connected_server_and_client_session(
8081
read_stream=client_read,
8182
write_stream=client_write,
8283
read_timeout_seconds=read_timeout_seconds,
84+
sampling_callback=sampling_callback,
8385
) as client_session:
8486
await client_session.initialize()
8587
yield client_session
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
3+
from mcp.shared.memory import (
4+
create_connected_server_and_client_session as create_session,
5+
)
6+
from mcp.types import (
7+
CreateMessageRequestParams,
8+
CreateMessageResult,
9+
SamplingMessage,
10+
TextContent,
11+
)
12+
13+
14+
@pytest.mark.anyio
15+
async def test_sampling_callback():
16+
from mcp.server.fastmcp import FastMCP
17+
18+
server = FastMCP("test")
19+
20+
callback_return = CreateMessageResult(
21+
role="assistant",
22+
content=TextContent(
23+
type="text", text="This is a response from the sampling callback"
24+
),
25+
model="test-model",
26+
stopReason="endTurn",
27+
)
28+
29+
async def sampling_callback(
30+
message: CreateMessageRequestParams,
31+
) -> CreateMessageResult:
32+
return callback_return
33+
34+
@server.tool("test_sampling")
35+
async def test_sampling_tool(message: str):
36+
value = await server.get_context().session.create_message(
37+
messages=[
38+
SamplingMessage(
39+
role="user", content=TextContent(type="text", text=message)
40+
)
41+
],
42+
max_tokens=100,
43+
)
44+
assert value == callback_return
45+
return True
46+
47+
async with create_session(
48+
server._mcp_server, sampling_callback=sampling_callback
49+
) as client_session:
50+
# Make a request to trigger sampling callback
51+
assert await client_session.call_tool(
52+
"test_sampling", {"message": "Test message for sampling"}
53+
)

0 commit comments

Comments
 (0)