|
1 | 1 | from datetime import timedelta |
2 | | -from inspect import iscoroutinefunction |
3 | 2 | from typing import Awaitable, Callable |
4 | 3 |
|
5 | 4 | from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
@@ -39,21 +38,12 @@ def __init__( |
39 | 38 | types.ServerNotification, |
40 | 39 | read_timeout_seconds=read_timeout_seconds, |
41 | 40 | ) |
42 | | - |
43 | | - # validate sampling_callback |
44 | | - # use asserts here because this should be known at compile time |
45 | | - if sampling_callback is not None: |
46 | | - assert callable(sampling_callback), "sampling_callback must be callable" |
47 | | - assert iscoroutinefunction( |
48 | | - sampling_callback |
49 | | - ), "sampling_callback must be an async function" |
50 | | - |
51 | 41 | self.sampling_callback = sampling_callback |
52 | 42 |
|
53 | 43 | async def initialize(self) -> types.InitializeResult: |
54 | | - sampling = None |
55 | | - if self.sampling_callback is not None: |
56 | | - sampling = types.SamplingCapability() |
| 44 | + sampling = ( |
| 45 | + types.SamplingCapability() if self.sampling_callback is not None else None |
| 46 | + ) |
57 | 47 |
|
58 | 48 | result = await self.send_request( |
59 | 49 | types.ClientRequest( |
@@ -260,11 +250,7 @@ async def _received_request( |
260 | 250 | self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] |
261 | 251 | ) -> None: |
262 | 252 | if isinstance(responder.request.root, types.CreateMessageRequest): |
263 | | - # handle create message request (sampling) |
264 | | - |
265 | | - if self.sampling_callback is None: |
266 | | - raise RuntimeError("Sampling callback is not set") |
267 | | - |
268 | | - response = await self.sampling_callback(responder.request.root.params) |
269 | | - client_response = types.ClientResult(**response.model_dump()) |
270 | | - await responder.respond(client_response) |
| 253 | + if self.sampling_callback is not None: |
| 254 | + response = await self.sampling_callback(responder.request.root.params) |
| 255 | + client_response = types.ClientResult(root=response) |
| 256 | + await responder.respond(client_response) |
0 commit comments