Skip to content

Commit e2e2f43

Browse files
committed
simplify the implementation
1 parent 38f639c commit e2e2f43

File tree

1 file changed

+7
-21
lines changed

1 file changed

+7
-21
lines changed

src/mcp/client/session.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from datetime import timedelta
2-
from inspect import iscoroutinefunction
32
from typing import Awaitable, Callable
43

54
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
@@ -39,21 +38,12 @@ def __init__(
3938
types.ServerNotification,
4039
read_timeout_seconds=read_timeout_seconds,
4140
)
42-
43-
# validate sampling_callback
44-
# use asserts here because this should be known at compile time
45-
if sampling_callback is not None:
46-
assert callable(sampling_callback), "sampling_callback must be callable"
47-
assert iscoroutinefunction(
48-
sampling_callback
49-
), "sampling_callback must be an async function"
50-
5141
self.sampling_callback = sampling_callback
5242

5343
async def initialize(self) -> types.InitializeResult:
54-
sampling = None
55-
if self.sampling_callback is not None:
56-
sampling = types.SamplingCapability()
44+
sampling = (
45+
types.SamplingCapability() if self.sampling_callback is not None else None
46+
)
5747

5848
result = await self.send_request(
5949
types.ClientRequest(
@@ -260,11 +250,7 @@ async def _received_request(
260250
self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"]
261251
) -> None:
262252
if isinstance(responder.request.root, types.CreateMessageRequest):
263-
# handle create message request (sampling)
264-
265-
if self.sampling_callback is None:
266-
raise RuntimeError("Sampling callback is not set")
267-
268-
response = await self.sampling_callback(responder.request.root.params)
269-
client_response = types.ClientResult(**response.model_dump())
270-
await responder.respond(client_response)
253+
if self.sampling_callback is not None:
254+
response = await self.sampling_callback(responder.request.root.params)
255+
client_response = types.ClientResult(root=response)
256+
await responder.respond(client_response)

0 commit comments

Comments
 (0)