11from datetime import timedelta
2+ from inspect import iscoroutinefunction
3+ from typing import Awaitable , Callable
24
35from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
46from pydantic import AnyUrl
79from mcp .shared .session import BaseSession
810from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
911
12+ sampling_function_signature = Callable [
13+ [types .CreateMessageRequestParams ], Awaitable [types .CreateMessageResult ]
14+ ]
15+
1016
1117class ClientSession (
1218 BaseSession [
@@ -17,11 +23,14 @@ class ClientSession(
1723 types .ServerNotification ,
1824 ]
1925):
26+ sampling_callback : sampling_function_signature | None = None
27+
2028 def __init__ (
2129 self ,
2230 read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ],
2331 write_stream : MemoryObjectSendStream [types .JSONRPCMessage ],
2432 read_timeout_seconds : timedelta | None = None ,
33+ sampling_callback : sampling_function_signature | None = None ,
2534 ) -> None :
2635 super ().__init__ (
2736 read_stream ,
@@ -31,15 +40,29 @@ def __init__(
3140 read_timeout_seconds = read_timeout_seconds ,
3241 )
3342
43+ # validate sampling_callback
44+ # use asserts here because this should be known at compile time
45+ if sampling_callback is not None :
46+ assert callable (sampling_callback ), "sampling_callback must be callable"
47+ assert iscoroutinefunction (
48+ sampling_callback
49+ ), "sampling_callback must be an async function"
50+
51+ self .sampling_callback = sampling_callback
52+
3453 async def initialize (self ) -> types .InitializeResult :
54+ sampling = None
55+ if self .sampling_callback is not None :
56+ sampling = types .SamplingCapability ()
57+
3558 result = await self .send_request (
3659 types .ClientRequest (
3760 types .InitializeRequest (
3861 method = "initialize" ,
3962 params = types .InitializeRequestParams (
4063 protocolVersion = types .LATEST_PROTOCOL_VERSION ,
4164 capabilities = types .ClientCapabilities (
42- sampling = None ,
65+ sampling = sampling ,
4366 experimental = None ,
4467 roots = types .RootsCapability (
4568 # TODO: Should this be based on whether we
0 commit comments