File tree Expand file tree Collapse file tree 3 files changed +61
-5
lines changed
Expand file tree Collapse file tree 3 files changed +61
-5
lines changed Original file line number Diff line number Diff line change 88from mcp .shared .session import BaseSession , RequestResponder
99from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
1010
11- sampling_function_signature = Callable [
11+ SamplingFnT = Callable [
1212 [types .CreateMessageRequestParams ], Awaitable [types .CreateMessageResult ]
1313]
1414
@@ -22,14 +22,14 @@ class ClientSession(
2222 types .ServerNotification ,
2323 ]
2424):
25- sampling_callback : sampling_function_signature | None = None
25+ sampling_callback : SamplingFnT | None = None
2626
2727 def __init__ (
2828 self ,
2929 read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ],
3030 write_stream : MemoryObjectSendStream [types .JSONRPCMessage ],
3131 read_timeout_seconds : timedelta | None = None ,
32- sampling_callback : sampling_function_signature | None = None ,
32+ sampling_callback : SamplingFnT | None = None ,
3333 ) -> None :
3434 super ().__init__ (
3535 read_stream ,
@@ -253,4 +253,5 @@ async def _received_request(
253253 if self .sampling_callback is not None :
254254 response = await self .sampling_callback (responder .request .root .params )
255255 client_response = types .ClientResult (root = response )
256- await responder .respond (client_response )
256+ with responder :
257+ await responder .respond (client_response )
Original file line number Diff line number Diff line change 99import anyio
1010from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
1111
12- from mcp .client .session import ClientSession
12+ from mcp .client .session import ClientSession , SamplingFnT
1313from mcp .server import Server
1414from mcp .types import JSONRPCMessage
1515
@@ -54,6 +54,7 @@ async def create_client_server_memory_streams() -> (
5454async def create_connected_server_and_client_session (
5555 server : Server ,
5656 read_timeout_seconds : timedelta | None = None ,
57+ sampling_callback : SamplingFnT | None = None ,
5758 raise_exceptions : bool = False ,
5859) -> AsyncGenerator [ClientSession , None ]:
5960 """Creates a ClientSession that is connected to a running MCP server."""
@@ -80,6 +81,7 @@ async def create_connected_server_and_client_session(
8081 read_stream = client_read ,
8182 write_stream = client_write ,
8283 read_timeout_seconds = read_timeout_seconds ,
84+ sampling_callback = sampling_callback ,
8385 ) as client_session :
8486 await client_session .initialize ()
8587 yield client_session
Original file line number Diff line number Diff line change 1+ import pytest
2+
3+ from mcp .shared .memory import (
4+ create_connected_server_and_client_session as create_session ,
5+ )
6+ from mcp .types import (
7+ CreateMessageRequestParams ,
8+ CreateMessageResult ,
9+ SamplingMessage ,
10+ TextContent ,
11+ )
12+
13+
14+ @pytest .mark .anyio
15+ async def test_sampling_callback ():
16+ from mcp .server .fastmcp import FastMCP
17+
18+ server = FastMCP ("test" )
19+
20+ callback_return = CreateMessageResult (
21+ role = "assistant" ,
22+ content = TextContent (
23+ type = "text" , text = "This is a response from the sampling callback"
24+ ),
25+ model = "test-model" ,
26+ stopReason = "endTurn" ,
27+ )
28+
29+ async def sampling_callback (
30+ message : CreateMessageRequestParams ,
31+ ) -> CreateMessageResult :
32+ return callback_return
33+
34+ @server .tool ("test_sampling" )
35+ async def test_sampling_tool (message : str ):
36+ value = await server .get_context ().session .create_message (
37+ messages = [
38+ SamplingMessage (
39+ role = "user" , content = TextContent (type = "text" , text = message )
40+ )
41+ ],
42+ max_tokens = 100 ,
43+ )
44+ assert value == callback_return
45+ return True
46+
47+ async with create_session (
48+ server ._mcp_server , sampling_callback = sampling_callback
49+ ) as client_session :
50+ # Make a request to trigger sampling callback
51+ assert await client_session .call_tool (
52+ "test_sampling" , {"message" : "Test message for sampling" }
53+ )
You can’t perform that action at this time.
0 commit comments