Skip to content

Commit cce8519

Browse files
add sampling callback paramater
1 parent 2efa525 commit cce8519

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

src/mcp/client/session.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from datetime import timedelta
2+
from inspect import iscoroutinefunction
3+
from typing import Awaitable, Callable
24

35
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
46
from pydantic import AnyUrl
@@ -7,6 +9,10 @@
79
from mcp.shared.session import BaseSession
810
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
911

12+
sampling_function_signature = Callable[
13+
[types.CreateMessageRequestParams], Awaitable[types.CreateMessageResult]
14+
]
15+
1016

1117
class ClientSession(
1218
BaseSession[
@@ -17,11 +23,14 @@ class ClientSession(
1723
types.ServerNotification,
1824
]
1925
):
26+
sampling_callback: sampling_function_signature | None = None
27+
2028
def __init__(
2129
self,
2230
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
2331
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
2432
read_timeout_seconds: timedelta | None = None,
33+
sampling_callback: sampling_function_signature | None = None,
2534
) -> None:
2635
super().__init__(
2736
read_stream,
@@ -31,15 +40,29 @@ def __init__(
3140
read_timeout_seconds=read_timeout_seconds,
3241
)
3342

43+
# validate sampling_callback
44+
# use asserts here because this should be known at compile time
45+
if sampling_callback is not None:
46+
assert callable(sampling_callback), "sampling_callback must be callable"
47+
assert iscoroutinefunction(
48+
sampling_callback
49+
), "sampling_callback must be an async function"
50+
51+
self.sampling_callback = sampling_callback
52+
3453
async def initialize(self) -> types.InitializeResult:
54+
sampling = None
55+
if self.sampling_callback is not None:
56+
sampling = types.SamplingCapability()
57+
3558
result = await self.send_request(
3659
types.ClientRequest(
3760
types.InitializeRequest(
3861
method="initialize",
3962
params=types.InitializeRequestParams(
4063
protocolVersion=types.LATEST_PROTOCOL_VERSION,
4164
capabilities=types.ClientCapabilities(
42-
sampling=None,
65+
sampling=sampling,
4366
experimental=None,
4467
roots=types.RootsCapability(
4568
# TODO: Should this be based on whether we

0 commit comments

Comments
 (0)