File tree Expand file tree Collapse file tree 3 files changed +61
-5
lines changed Expand file tree Collapse file tree 3 files changed +61
-5
lines changed Original file line number Diff line number Diff line change 8
8
from mcp .shared .session import BaseSession , RequestResponder
9
9
from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
10
10
11
- sampling_function_signature = Callable [
11
+ SamplingFnT = Callable [
12
12
[types .CreateMessageRequestParams ], Awaitable [types .CreateMessageResult ]
13
13
]
14
14
@@ -22,14 +22,14 @@ class ClientSession(
22
22
types .ServerNotification ,
23
23
]
24
24
):
25
- sampling_callback : sampling_function_signature | None = None
25
+ sampling_callback : SamplingFnT | None = None
26
26
27
27
def __init__ (
28
28
self ,
29
29
read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ],
30
30
write_stream : MemoryObjectSendStream [types .JSONRPCMessage ],
31
31
read_timeout_seconds : timedelta | None = None ,
32
- sampling_callback : sampling_function_signature | None = None ,
32
+ sampling_callback : SamplingFnT | None = None ,
33
33
) -> None :
34
34
super ().__init__ (
35
35
read_stream ,
@@ -253,4 +253,5 @@ async def _received_request(
253
253
if self .sampling_callback is not None :
254
254
response = await self .sampling_callback (responder .request .root .params )
255
255
client_response = types .ClientResult (root = response )
256
- await responder .respond (client_response )
256
+ with responder :
257
+ await responder .respond (client_response )
Original file line number Diff line number Diff line change 9
9
import anyio
10
10
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
11
11
12
- from mcp .client .session import ClientSession
12
+ from mcp .client .session import ClientSession , SamplingFnT
13
13
from mcp .server import Server
14
14
from mcp .types import JSONRPCMessage
15
15
@@ -54,6 +54,7 @@ async def create_client_server_memory_streams() -> (
54
54
async def create_connected_server_and_client_session (
55
55
server : Server ,
56
56
read_timeout_seconds : timedelta | None = None ,
57
+ sampling_callback : SamplingFnT | None = None ,
57
58
raise_exceptions : bool = False ,
58
59
) -> AsyncGenerator [ClientSession , None ]:
59
60
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -80,6 +81,7 @@ async def create_connected_server_and_client_session(
80
81
read_stream = client_read ,
81
82
write_stream = client_write ,
82
83
read_timeout_seconds = read_timeout_seconds ,
84
+ sampling_callback = sampling_callback ,
83
85
) as client_session :
84
86
await client_session .initialize ()
85
87
yield client_session
Original file line number Diff line number Diff line change
1
+ import pytest
2
+
3
+ from mcp .shared .memory import (
4
+ create_connected_server_and_client_session as create_session ,
5
+ )
6
+ from mcp .types import (
7
+ CreateMessageRequestParams ,
8
+ CreateMessageResult ,
9
+ SamplingMessage ,
10
+ TextContent ,
11
+ )
12
+
13
+
14
+ @pytest .mark .anyio
15
+ async def test_sampling_callback ():
16
+ from mcp .server .fastmcp import FastMCP
17
+
18
+ server = FastMCP ("test" )
19
+
20
+ callback_return = CreateMessageResult (
21
+ role = "assistant" ,
22
+ content = TextContent (
23
+ type = "text" , text = "This is a response from the sampling callback"
24
+ ),
25
+ model = "test-model" ,
26
+ stopReason = "endTurn" ,
27
+ )
28
+
29
+ async def sampling_callback (
30
+ message : CreateMessageRequestParams ,
31
+ ) -> CreateMessageResult :
32
+ return callback_return
33
+
34
+ @server .tool ("test_sampling" )
35
+ async def test_sampling_tool (message : str ):
36
+ value = await server .get_context ().session .create_message (
37
+ messages = [
38
+ SamplingMessage (
39
+ role = "user" , content = TextContent (type = "text" , text = message )
40
+ )
41
+ ],
42
+ max_tokens = 100 ,
43
+ )
44
+ assert value == callback_return
45
+ return True
46
+
47
+ async with create_session (
48
+ server ._mcp_server , sampling_callback = sampling_callback
49
+ ) as client_session :
50
+ # Make a request to trigger sampling callback
51
+ assert await client_session .call_tool (
52
+ "test_sampling" , {"message" : "Test message for sampling" }
53
+ )
You can’t perform that action at this time.
0 commit comments