1
1
from datetime import timedelta
2
+ from inspect import iscoroutinefunction
3
+ from typing import Awaitable , Callable
2
4
3
5
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
4
6
from pydantic import AnyUrl
7
9
from mcp .shared .session import BaseSession
8
10
from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
9
11
12
+ sampling_function_signature = Callable [
13
+ [types .CreateMessageRequestParams ], Awaitable [types .CreateMessageResult ]
14
+ ]
15
+
10
16
11
17
class ClientSession (
12
18
BaseSession [
@@ -17,11 +23,14 @@ class ClientSession(
17
23
types .ServerNotification ,
18
24
]
19
25
):
26
+ sampling_callback : sampling_function_signature | None = None
27
+
20
28
def __init__ (
21
29
self ,
22
30
read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ],
23
31
write_stream : MemoryObjectSendStream [types .JSONRPCMessage ],
24
32
read_timeout_seconds : timedelta | None = None ,
33
+ sampling_callback : sampling_function_signature | None = None ,
25
34
) -> None :
26
35
super ().__init__ (
27
36
read_stream ,
@@ -31,15 +40,29 @@ def __init__(
31
40
read_timeout_seconds = read_timeout_seconds ,
32
41
)
33
42
43
+ # validate sampling_callback
44
+ # use asserts here because this should be known at compile time
45
+ if sampling_callback is not None :
46
+ assert callable (sampling_callback ), "sampling_callback must be callable"
47
+ assert iscoroutinefunction (
48
+ sampling_callback
49
+ ), "sampling_callback must be an async function"
50
+
51
+ self .sampling_callback = sampling_callback
52
+
34
53
async def initialize (self ) -> types .InitializeResult :
54
+ sampling = None
55
+ if self .sampling_callback is not None :
56
+ sampling = types .SamplingCapability ()
57
+
35
58
result = await self .send_request (
36
59
types .ClientRequest (
37
60
types .InitializeRequest (
38
61
method = "initialize" ,
39
62
params = types .InitializeRequestParams (
40
63
protocolVersion = types .LATEST_PROTOCOL_VERSION ,
41
64
capabilities = types .ClientCapabilities (
42
- sampling = None ,
65
+ sampling = sampling ,
43
66
experimental = None ,
44
67
roots = types .RootsCapability (
45
68
# TODO: Should this be based on whether we
0 commit comments