diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 1853ce7c1..049fa982d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -9,8 +9,8 @@ import mcp.types as types from mcp.shared.context import RequestContext -from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, RequestStateManager from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -118,6 +118,7 @@ def __init__( logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, + request_state_manager: RequestStateManager[types.ClientRequest, types.ClientResult] | None = None, ) -> None: super().__init__( read_stream, @@ -125,6 +126,7 @@ def __init__( types.ServerRequest, types.ServerNotification, read_timeout_seconds=read_timeout_seconds, + request_state_manager=request_state_manager, ) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback @@ -133,6 +135,7 @@ def __init__( self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} + self._resumable = False async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None @@ -170,6 +173,8 @@ async def initialize(self) -> types.InitializeResult: if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}") + self._resumable = result.capabilities.resume and result.capabilities.resume.resumable + await self.send_notification( types.ClientNotification(types.InitializedNotification(method="notifications/initialized")) ) @@ -281,6 +286,78 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: types.EmptyResult, ) + async def request_call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + progress_callback: ProgressFnT | None = None, + ) -> types.RequestId: + if self._resumable: + captured_token = None + captured = anyio.Event() + + async def capture_token(token: str): + nonlocal captured_token + captured_token = token + captured.set() + + metadata = ClientMessageMetadata(on_resumption_token_update=capture_token) + + request_id = await self.start_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + ), + ) + ), + progress_callback=progress_callback, + metadata=metadata, + ) + + while captured_token is None: + await captured.wait() + + await self._request_state_manager.update_resume_token(request_id, captured_token) + + return request_id + else: + return await self.start_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + ), + ) + ), + progress_callback=progress_callback, + ) + + async def join_call_tool( + self, + request_id: types.RequestId, + progress_callback: ProgressFnT | None = None, + request_read_timeout_seconds: timedelta | None = 
None,
+        done_on_timeout: bool = True,
+    ) -> types.CallToolResult | None:
+        return await self.join_request(
+            request_id,
+            types.CallToolResult,
+            request_read_timeout_seconds=request_read_timeout_seconds,
+            progress_callback=progress_callback,
+            done_on_timeout=done_on_timeout,
+        )
+
+    async def cancel_call_tool(
+        self,
+        request_id: types.RequestId,
+    ) -> bool:
+        return await self.cancel_request(request_id)
+
     async def call_tool(
         self,
         name: str,
diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py
index 63b09133f..99ee611f5 100644
--- a/src/mcp/client/streamable_http.py
+++ b/src/mcp/client/streamable_http.py
@@ -29,6 +29,7 @@
     JSONRPCRequest,
     JSONRPCResponse,
     RequestId,
+    ResumeCapability,
 )
 
 logger = logging.getLogger(__name__)
@@ -136,7 +137,7 @@ def _maybe_extract_session_id_from_response(
     def _maybe_extract_protocol_version_from_message(
         self,
         message: JSONRPCMessage,
-    ) -> None:
+    ) -> JSONRPCMessage:
         """Extract protocol version from initialization response message."""
         if isinstance(message.root, JSONRPCResponse) and message.root.result:
             try:
@@ -144,10 +145,18 @@ def _maybe_extract_protocol_version_from_message(
                 init_result = InitializeResult.model_validate(message.root.result)
                 self.protocol_version = str(init_result.protocolVersion)
                 logger.info(f"Negotiated protocol version: {self.protocol_version}")
+                if init_result.capabilities.resume is None:
+                    # Resumability depends on both the server and the transport.
+                    # If the server hasn't explicitly configured it, assume that
+                    # streamable HTTP transports are resumable.
+                    init_result.capabilities.resume = ResumeCapability(resumable=True)
+                    message.root.result = init_result.model_dump()
             except Exception as exc:
                 logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
                 logger.warning(f"Raw result: {message.root.result}")
 
+        return message
+
     async def _handle_sse_event(
         self,
         sse: ServerSentEvent,
@@ -164,7 +173,7 @@ async def _handle_sse_event(
 
         # Extract protocol version from initialization response
         if is_initialization:
-            self._maybe_extract_protocol_version_from_message(message)
+            message = self._maybe_extract_protocol_version_from_message(message)
 
         # If this is a response and we have original_request_id, replace it
         if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
@@ -303,7 +312,7 @@ async def _handle_json_response(
 
             # Extract protocol version from initialization response
             if is_initialization:
-                self._maybe_extract_protocol_version_from_message(message)
+                message = self._maybe_extract_protocol_version_from_message(message)
 
             session_message = SessionMessage(message)
             await read_stream_writer.send(session_message)
@@ -333,7 +342,10 @@ async def _handle_sse_response(
                     break
         except Exception as e:
             logger.exception("Error reading SSE stream:")
-            await ctx.read_stream_writer.send(e)
+            try:
+                await ctx.read_stream_writer.send(e)
+            except anyio.ClosedResourceError:
+                pass
 
     async def _handle_unexpected_content_type(
         self,
@@ -471,8 +483,8 @@ async def streamablehttp_client(
     read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
     write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
 
-    async with anyio.create_task_group() as tg:
-        try:
+    try:
+        async with anyio.create_task_group() as tg:
             logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
 
             async with httpx_client_factory(
@@ -504,6 +516,6 @@ def start_get_stream() -> None:
                    if transport.session_id and
terminate_on_close: await transport.terminate_session(client) tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 802cb8680..9978e4211 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -767,7 +767,6 @@ async def send_event(event_message: EventMessage) -> None: async with msg_reader: async for event_message in msg_reader: event_data = self._create_event_data(event_message) - await sse_stream_writer.send(event_data) except Exception: logger.exception("Error in replay sender") diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b2f49fc8b..60ed1387e 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -12,11 +12,12 @@ from typing_extensions import Self from mcp.shared.exceptions import McpError -from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.message import ClientMessageMetadata, MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.types import ( CONNECTION_CLOSED, INVALID_PARAMS, CancelledNotification, + CancelledNotificationParams, ClientNotification, ClientRequest, ClientResult, @@ -26,6 +27,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + PingRequest, ProgressNotification, RequestParams, ServerNotification, @@ -156,6 +158,190 @@ def cancelled(self) -> bool: return self._cancel_scope.cancel_called +class RequestStateManager( + Generic[ + SendRequestT, + SendResultT, + ], +): + def new_request(self, request: SendRequestT) -> RequestId: ... + + def resume(self, request_id: RequestId) -> bool: ... + + async def update_resume_token(self, request_id: RequestId, token: str) -> None: ... + + async def get_resume_token(self, request_id: RequestId) -> str | None: ... + + def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): ... + + async def send_progress( + self, + request_id: RequestId, + progress: float, + total: float | None, + message: str | None, + ): ... + + async def receive_response( + self, + request_id: RequestId, + timeout: float | None = None, + ) -> JSONRPCResponse | JSONRPCError | None: ... + + async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool: ... + + async def close_request(self, request_id: RequestId) -> bool: ... + + async def close(self) -> None: ... 
+ + +class InMemoryRequestStateManager( + RequestStateManager[ + SendRequestT, + SendResultT, + ], +): + _request_id: int + _requests: dict[ + RequestId, + SendRequestT, + ] + _response_streams: dict[ + RequestId, + tuple[ + MemoryObjectSendStream[JSONRPCResponse | JSONRPCError], + MemoryObjectReceiveStream[JSONRPCResponse | JSONRPCError], + ], + ] + _progress_callbacks: dict[RequestId, list[ProgressFnT]] + _resume_tokens: dict[RequestId, str] + + def __init__(self): + self._request_id = 0 + self._requests = {} + self._response_streams = {} + self._progress_callbacks = {} + self._resume_tokens = {} + + def new_request(self, request: SendRequestT) -> RequestId: + request_id = self._request_id + self._request_id = request_id + 1 + + send_stream, receive_stream = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) + self._response_streams[request_id] = send_stream, receive_stream + self._requests[request_id] = request + + return request_id + + def resume(self, request_id: RequestId) -> bool: + if self._requests.get(request_id) is None: + raise RuntimeError(f"Unknown request {request_id}") + + if request_id in self._response_streams: + return False + else: + send_stream, receive_stream = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) + self._response_streams[request_id] = send_stream, receive_stream + return True + + async def update_resume_token(self, request_id: RequestId, token: str) -> None: + self._resume_tokens[request_id] = token + + async def get_resume_token(self, request_id: RequestId) -> str | None: + return self._resume_tokens.get(request_id) + + def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): + progress_list = self._progress_callbacks.get(request_id) + if progress_list is None: + progress_list = [] + self._progress_callbacks[request_id] = progress_list + + progress_list.append(progress_callback) + + async def send_progress( + self, + request_id: RequestId, + progress: float, + total: float | None, + message: str | None, + ): + if request_id in self._progress_callbacks: + callbacks = self._progress_callbacks[request_id] + for callback in callbacks: + await callback( + progress, + total, + message, + ) + + async def receive_response( + self, + request_id: RequestId, + timeout: float | None = None, + ) -> JSONRPCResponse | JSONRPCError | None: + _, receive_stream = self._response_streams.get(request_id, [None, None]) + if receive_stream is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message=(f"Unknown request {request_id}"), + ) + ) + + request = self._requests.get(request_id, None) + assert request is not None + + try: + with anyio.fail_after(timeout): + return await receive_stream.receive() + except anyio.EndOfStream: + raise McpError( + ErrorData( + code=CONNECTION_CLOSED, + message=("Connection closed"), + ) + ) + except TimeoutError: + return None + + async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool: + send_stream, _ = self._response_streams.get(message.id, [None, None]) + if send_stream: + await send_stream.send(message) + return True + else: + return False + + async def close_request(self, request_id: RequestId) -> bool: + send_stream, receive_stream = self._response_streams.pop(request_id, [None, None]) + if send_stream is not None: + await send_stream.aclose() + if receive_stream is not None: + await receive_stream.aclose() + + self._requests.pop(request_id, None) + self._resume_tokens.pop(request_id, None) + 
self._progress_callbacks.pop(request_id, None)
+
+        return send_stream is not None
+
+    async def close(self):
+        for id, [send_stream, receive_stream] in self._response_streams.copy().items():
+            await receive_stream.aclose()
+            try:
+                error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
+                await send_stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
+            except anyio.BrokenResourceError:
+                # Stream might already be closed
+                pass
+            except anyio.ClosedResourceError:
+                # Stream might already be closed
+                pass
+            finally:
+                await send_stream.aclose()
+                self._response_streams.pop(id)
+
+
 class BaseSession(
     Generic[
         SendRequestT,
@@ -173,10 +359,7 @@ class BaseSession(
     messages when entered.
     """
 
-    _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]]
-    _request_id: int
     _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
-    _progress_callbacks: dict[RequestId, ProgressFnT]
 
     def __init__(
         self,
@@ -186,17 +369,16 @@ def __init__(
         receive_notification_type: type[ReceiveNotificationT],
         # If none, reading will never time out
         read_timeout_seconds: timedelta | None = None,
+        request_state_manager: RequestStateManager[SendRequestT, SendResultT] | None = None,
     ) -> None:
         self._read_stream = read_stream
         self._write_stream = write_stream
-        self._response_streams = {}
-        self._request_id = 0
         self._receive_request_type = receive_request_type
         self._receive_notification_type = receive_notification_type
         self._session_read_timeout_seconds = read_timeout_seconds
-        self._in_flight = {}
-        self._progress_callbacks = {}
         self._exit_stack = AsyncExitStack()
+        self._in_flight = {}
+        self._request_state_manager = request_state_manager or InMemoryRequestStateManager()
 
     async def __aenter__(self) -> Self:
         self._task_group = anyio.create_task_group()
@@ -217,28 +399,19 @@ async def __aexit__(
         self._task_group.cancel_scope.cancel()
         return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
 
-    async def send_request(
+    async def start_request(
         self,
         request: SendRequestT,
-        result_type: type[ReceiveResultT],
-        request_read_timeout_seconds: timedelta | None = None,
         metadata: MessageMetadata = None,
         progress_callback: ProgressFnT | None = None,
-    ) -> ReceiveResultT:
+    ) -> RequestId:
         """
-        Sends a request and wait for a response. Raises an McpError if the
-        response contains an error. If a request read timeout is provided, it
-        will take precedence over the session read timeout.
+        Starts a request.
 
         Do not use this method to emit notifications! Use send_notification()
         instead.
""" - request_id = self._request_id - self._request_id = request_id + 1 - - response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) - self._response_streams[request_id] = response_stream - + request_id = self._request_state_manager.new_request(request) # Set up progress token if progress callback is provided request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) if progress_callback is not None: @@ -249,49 +422,131 @@ async def send_request( request_data["params"]["_meta"] = {} request_data["params"]["_meta"]["progressToken"] = request_id # Store the callback for this request - self._progress_callbacks[request_id] = progress_callback + self._request_state_manager.add_progress_callback(request_id, progress_callback) - try: - jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - **request_data, - ) + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request_data, + ) + try: await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) + return request_id + except Exception as e: + await self._request_state_manager.close_request(request_id) + raise e - # request read timeout takes precedence over session read timeout - timeout = None - if request_read_timeout_seconds is not None: - timeout = request_read_timeout_seconds.total_seconds() - elif self._session_read_timeout_seconds is not None: - timeout = self._session_read_timeout_seconds.total_seconds() + async def join_request( + self, + request_id: RequestId, + result_type: type[ReceiveResultT], + request_read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + done_on_timeout: bool = True, + ) -> ReceiveResultT | None: + """ + Joins a request previously started via start_request. - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: + Returns the result or None if timeout is reached. 
+ """ + resume = self._request_state_manager.resume(request_id) + + if progress_callback is not None: + self._request_state_manager.add_progress_callback(request_id, progress_callback) + + # request read timeout takes precedence over session read timeout + timeout = None + if request_read_timeout_seconds is not None: + timeout = request_read_timeout_seconds.total_seconds() + elif self._session_read_timeout_seconds is not None: + timeout = self._session_read_timeout_seconds.total_seconds() + + if resume: + resume_token = await self._request_state_manager.get_resume_token(request_id) + if resume_token is not None: + metadata = ClientMessageMetadata(resumption_token=resume_token) + + request_data = PingRequest(method="ping").model_dump(by_alias=True, mode="json", exclude_none=True) + + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request_data, + ) + + await self._write_stream.send( + SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata) + ) + + response_or_error = await self._request_state_manager.receive_response(request_id, timeout) + + if response_or_error is None: + if done_on_timeout: + await self._request_state_manager.close_request(request_id) + return None + elif isinstance(response_or_error, JSONRPCError): + if response_or_error.error.code == httpx.codes.REQUEST_TIMEOUT.value: + if done_on_timeout: + await self._request_state_manager.close_request(request_id) + return None + else: + await self._request_state_manager.close_request(request_id) + raise McpError(response_or_error.error) + else: + await self._request_state_manager.close_request(request_id) + return result_type.model_validate(response_or_error.result) + + async def cancel_request(self, request_id: RequestId) -> bool: + """ + Cancels a request previously started via start_request + """ + closed = await self._request_state_manager.close_request(request_id) + + if closed: + notification = CancelledNotification( + method="notifications/cancelled", + params=CancelledNotificationParams(requestId=request_id, reason="cancelled"), + ) + await self.send_notification(notification, request_id) # type: ignore + return True + else: + return False + + async def send_request( + self, + request: SendRequestT, + result_type: type[ReceiveResultT], + request_read_timeout_seconds: timedelta | None = None, + metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, + ) -> ReceiveResultT: + """ + Sends a request and wait for a response. Raises an McpError if the + response contains an error. If a request read timeout is provided, it + will take precedence over the session read timeout. + + Do not use this method to emit notifications! Use send_notification() + instead. + """ + request_id = await self.start_request(request, metadata, progress_callback) + try: + result = await self.join_request(request_id, result_type, request_read_timeout_seconds) + if result is None: raise McpError( ErrorData( code=httpx.codes.REQUEST_TIMEOUT, message=( f"Timed out while waiting for response to " f"{request.__class__.__name__}. Waited " - f"{timeout} seconds." + f"{request_read_timeout_seconds} seconds." 
), ) ) - - if isinstance(response_or_error, JSONRPCError): - raise McpError(response_or_error.error) else: - return result_type.model_validate(response_or_error.result) - + return result finally: - self._response_streams.pop(request_id, None) - self._progress_callbacks.pop(request_id, None) - await response_stream.aclose() - await response_stream_reader.aclose() + await self._request_state_manager.close_request(request_id) async def send_notification( self, @@ -390,13 +645,12 @@ async def _receive_loop(self) -> None: progress_token = notification.root.params.progressToken # If there is a progress callback for this token, # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - await callback( - notification.root.params.progress, - notification.root.params.total, - notification.root.params.message, - ) + await self._request_state_manager.send_progress( + progress_token, + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + ) await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: @@ -405,10 +659,8 @@ async def _receive_loop(self) -> None: f"Failed to validate notification: {e}. Message was: {message.message.root}" ) else: # Response or error - stream = self._response_streams.pop(message.message.root.id, None) - if stream: - await stream.send(message.message.root) - else: + handled = await self._request_state_manager.handle_response(message.message.root) + if not handled: await self._handle_incoming( RuntimeError(f"Received response with an unknown request ID: {message}") ) @@ -425,15 +677,7 @@ async def _receive_loop(self) -> None: finally: # after the read stream is closed, we need to send errors # to any pending requests - for id, stream in self._response_streams.items(): - error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") - try: - await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) - await stream.aclose() - except Exception: - # Stream might already be closed - pass - self._response_streams.clear() + await self._request_state_manager.close() async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: """ diff --git a/src/mcp/types.py b/src/mcp/types.py index 98fefa080..4ca06a20a 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -262,6 +262,14 @@ class PromptsCapability(BaseModel): model_config = ConfigDict(extra="allow") +class ResumeCapability(BaseModel): + """Capability for resume operations.""" + + resumable: bool | None = None + """Whether this server supports resume operations.""" + model_config = ConfigDict(extra="allow") + + class ResourcesCapability(BaseModel): """Capability for resources operations.""" @@ -303,6 +311,8 @@ class ServerCapabilities(BaseModel): """Present if the server offers any prompt templates.""" resources: ResourcesCapability | None = None """Present if the server offers any resources to read.""" + resume: ResumeCapability | None = None + """Present if the server offers resume capability.""" tools: ToolsCapability | None = None """Present if the server offers any tools to call.""" completions: CompletionsCapability | None = None diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py index 527884219..2d4c18343 100644 --- a/tests/client/test_resource_cleanup.py +++ b/tests/client/test_resource_cleanup.py @@ -3,7 +3,7 @@ import anyio import pytest 
-from mcp.shared.session import BaseSession
+from mcp.shared.session import BaseSession, InMemoryRequestStateManager
 from mcp.types import (
     ClientRequest,
     EmptyResult,
@@ -28,12 +28,14 @@ async def _send_response(self, request_id, response):
     write_stream_send, write_stream_receive = anyio.create_memory_object_stream(1)
     read_stream_send, read_stream_receive = anyio.create_memory_object_stream(1)
+    request_io_manager = InMemoryRequestStateManager()
 
     # Create the session
     session = TestSession(
         read_stream_receive,
         write_stream_send,
         object,  # Request type doesn't matter for this test
-        object,  # Notification type doesn't matter for this test
+        object,  # Notification type doesn't matter for this test,
+        request_state_manager=request_io_manager,
     )
 
     # Create a test request
@@ -48,7 +50,7 @@ async def mock_send(*args, **kwargs):
         raise RuntimeError("Simulated network error")
 
     # Record the response streams before the test
-    initial_stream_count = len(session._response_streams)
+    initial_stream_count = len(request_io_manager._response_streams)
 
     # Run the test with the patched method
     with patch.object(session._write_stream, "send", mock_send):
@@ -56,8 +58,9 @@ async def mock_send(*args, **kwargs):
             await session.send_request(request, EmptyResult)
 
     # Verify that no response streams were leaked
-    assert len(session._response_streams) == initial_stream_count, (
-        f"Expected {initial_stream_count} response streams after request, but found {len(session._response_streams)}"
+    assert len(request_io_manager._response_streams) == initial_stream_count, (
+        f"Expected {initial_stream_count} response streams after request, "
+        f"but found {len(request_io_manager._response_streams)}"
     )
 
     # Clean up
diff --git a/tests/client/test_session.py b/tests/client/test_session.py
index 327d1a9e4..39bed193c 100644
--- a/tests/client/test_session.py
+++ b/tests/client/test_session.py
@@ -1,3 +1,4 @@
+from datetime import timedelta
 from typing import Any
 
 import anyio
@@ -7,10 +8,13 @@
 from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
 from mcp.shared.context import RequestContext
 from mcp.shared.message import SessionMessage
-from mcp.shared.session import RequestResponder
+from mcp.shared.session import InMemoryRequestStateManager, RequestResponder
 from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
 from mcp.types import (
     LATEST_PROTOCOL_VERSION,
+    CallToolRequest,
+    CallToolResult,
+    CancelledNotification,
     ClientNotification,
     ClientRequest,
     Implementation,
@@ -23,6 +27,7 @@
     JSONRPCResponse,
     ServerCapabilities,
     ServerResult,
+    TextContent,
 )
 
 
@@ -495,3 +500,491 @@ async def mock_server():
     assert received_capabilities.roots is not None  # Custom list_roots callback provided
     assert isinstance(received_capabilities.roots, types.RootsCapability)
     assert received_capabilities.roots.listChanged is True  # Should be True for custom callback
+
+
+@pytest.mark.anyio
+async def test_client_session_request_call_tool():
+    client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
+    server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
+
+    async def mock_server():
+        session_message = await client_to_server_receive.receive()
+        jsonrpc_request = session_message.message
+        assert isinstance(jsonrpc_request.root, JSONRPCRequest)
+        call_id = jsonrpc_request.root.id
+        request = ClientRequest.model_validate(
+            jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
+        )
+        assert isinstance(request.root, CallToolRequest)
+
request = request.root + assert "hello" == request.params.name + assert request.params.arguments is not None + assert "name" in request.params.arguments + name = request.params.arguments["name"] + + async with server_to_client_send: + result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=call_id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + request_id = await session.request_call_tool("hello", {"name": "world"}) + with anyio.fail_after(1): + result = await session.join_call_tool(request_id) + + # Assert the result + assert isinstance(result, CallToolResult) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "hello world" + + +@pytest.mark.anyio +async def test_client_session_request_call_tool_join_timeout(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + send_result = anyio.Event() + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + call_id = jsonrpc_request.root.id + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, CallToolRequest) + request = request.root + assert "hello" == request.params.name + assert request.params.arguments is not None + assert "name" in request.params.arguments + name = request.params.arguments["name"] + + await send_result.wait() + + async with server_to_client_send: + result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=call_id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + request_id = await session.request_call_tool("hello", {"name": "world"}) + + with anyio.fail_after(3): + result = await session.join_call_tool( + request_id, 
request_read_timeout_seconds=timedelta(seconds=0.5), done_on_timeout=False + ) + assert result is None + send_result.set() + result = await session.join_call_tool( + request_id, request_read_timeout_seconds=timedelta(seconds=1), done_on_timeout=False + ) + + # Assert the result + assert isinstance(result, CallToolResult) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "hello world" + + # Assert resources tidied up + assert len(session._request_state_manager._response_streams) == 0 # type: ignore + + +@pytest.mark.anyio +async def test_client_session_request_call_tool_with_progress(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + send_progress_2 = anyio.Event() + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + call_id = jsonrpc_request.root.id + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, CallToolRequest) + request = request.root + assert "hello" == request.params.name + assert request.params.arguments is not None + assert "name" in request.params.arguments + name = request.params.arguments["name"] + assert request.params.meta is not None + assert request.params.meta.progressToken is not None + progrss_token = request.params.meta.progressToken + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken=progrss_token, + progress=1, + total=2, + message="event 1", + ).model_dump(), + ) + ) + ) + ) + + # await send_progress_2.wait() + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken=progrss_token, + progress=2, + total=2, + message="event 2", + ).model_dump(), + ) + ) + ) + ) + + result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=call_id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + progress_1 = anyio.Event() + progress_2 = anyio.Event() + + async def progress_callback1(progress: float, total: float | None, message: str | None) -> None: + if progress == 1: + progress_1.set() + elif progress == 2: + progress_2.set() + else: + raise RuntimeError("Unexpected progress value") + + request_id = await 
session.request_call_tool("hello", {"name": "world"}, progress_callback1) + + with anyio.fail_after(3): + await progress_1.wait() + result = await session.join_call_tool(request_id) + send_progress_2.set() + await progress_2.wait() + + # Assert the result + assert isinstance(result, CallToolResult) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "hello world" + + +@pytest.mark.anyio +async def test_client_session_request_call_tool_with_rejoin(): + client_1_to_server_send, client_1_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_1_send, server_to_client_1_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_2_to_server_send, client_2_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_2_send, server_to_client_2_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async def mock_server(): + session_message = await client_1_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + call_tool_id = jsonrpc_request.root.id + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, CallToolRequest) + + request = request.root + + assert "hello" == request.params.name + assert request.params.arguments is not None + assert "name" in request.params.arguments + name = request.params.arguments["name"] + assert request.params.meta is not None + assert request.params.meta.progressToken is not None + progress_token = request.params.meta.progressToken + + async with server_to_client_1_send, server_to_client_2_send: + await server_to_client_1_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken=progress_token, + progress=1, + total=2, + message="event 1", + ).model_dump(), + ) + ) + ) + ) + + await server_to_client_2_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken=progress_token, + progress=2, + total=2, + message="event 2", + ).model_dump(), + ) + ) + ) + ) + + result = ServerResult(CallToolResult(content=[TextContent(type="text", text=f"hello {name}")])) + + await server_to_client_2_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=call_tool_id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + request_state_manager_1 = InMemoryRequestStateManager() + request_state_manager_2 = InMemoryRequestStateManager() + + async with ( + ClientSession( + server_to_client_1_receive, + client_1_to_server_send, + message_handler=message_handler, + request_state_manager=request_state_manager_1, + ) as session1, + ClientSession( + server_to_client_2_receive, + client_2_to_server_send, + message_handler=message_handler, + request_state_manager=request_state_manager_2, + ) as session2, + anyio.create_task_group() as tg, + client_1_to_server_send, + client_1_to_server_receive, + 
server_to_client_1_send, + server_to_client_1_receive, + client_2_to_server_send, + client_2_to_server_receive, + server_to_client_2_send, + server_to_client_2_receive, + ): + tg.start_soon(mock_server) + + progress_1_1 = anyio.Event() + progress_1_2 = anyio.Event() + progress_2_1 = anyio.Event() + progress_2_2 = anyio.Event() + + async def progress_callback1(progress: float, total: float | None, message: str | None) -> None: + if progress == 1: + progress_1_1.set() + elif progress == 2: + progress_1_2.set() + else: + raise RuntimeError("Unexpected progress value") + + async def progress_callback2(progress: float, total: float | None, message: str | None) -> None: + if progress == 1: + progress_2_1.set() + elif progress == 2: + progress_2_2.set() + else: + raise RuntimeError("Unexpected progress value") + + request_id = await session1.request_call_tool("hello", {"name": "world"}, progress_callback1) + with anyio.fail_after(1): + await progress_1_1.wait() + + # initialise io manager 2 to state of io manager 1 + request_state_manager_2._requests = request_state_manager_1._requests.copy() + + # simulate network disconnect and rejoin + await request_state_manager_1.close_request(request_id) + result = await session2.join_call_tool(request_id, progress_callback2) + + await progress_2_2.wait() + + assert not progress_1_2.is_set() + assert not progress_2_1.is_set() + # Assert the result + assert isinstance(result, CallToolResult) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "hello world" + + +@pytest.mark.anyio +async def test_client_session_cancel_call_tool(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + cancelled = anyio.Event() + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, CallToolRequest) + + request = request.root + assert "hello" == request.params.name + assert request.params.arguments is not None + assert "name" in request.params.arguments + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCNotification) + notification = ClientNotification.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(notification.root, CancelledNotification) + cancelled.set() + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + async def progress_callback(progress: float, total: float | None, message: str | None) -> None: + pass + + request_id = await session.request_call_tool("hello", {"name": "world"}, 
progress_callback) + assert await session.cancel_call_tool(request_id) + with anyio.fail_after(1): + await cancelled.wait() + assert not await session.cancel_call_tool(request_id) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3fea54f0b..59c4388ec 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -9,6 +9,7 @@ import socket import time from collections.abc import Generator +from datetime import timedelta from typing import Any import anyio @@ -43,7 +44,7 @@ from mcp.shared.message import ( ClientMessageMetadata, ) -from mcp.shared.session import RequestResponder +from mcp.shared.session import InMemoryRequestStateManager, RequestResponder from mcp.types import ( InitializeResult, TextContent, @@ -1214,6 +1215,275 @@ async def run_tool(): assert captured_notifications[0].root.params.data == "Second notification after lock" +@pytest.mark.anyio +async def test_streamablehttp_client_resumption_non_blocking(event_server): + """Test client session to resume a long running tool via non blocking api.""" + _, server_url = event_server + + with anyio.fail_after(10): + # Variables to track the state + captured_session_id = None + captured_notifications = [] + tool_started = False + captured_protocol_version = None + captured_request_id = None + request_state_manager_1 = InMemoryRequestStateManager() + request_state_manager_2 = InMemoryRequestStateManager() + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + captured_notifications.append(message) + # Look for our special notification that indicates the tool is running + if isinstance(message.root, types.LoggingMessageNotification): + if message.root.params.data == "Tool started": + nonlocal tool_started + tool_started = True + + # First, start the client session and begin the long-running tool + async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( + read_stream, + write_stream, + get_session_id, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + request_state_manager=request_state_manager_1, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + captured_session_id = get_session_id() + assert captured_session_id is not None + # Capture the negotiated protocol version + captured_protocol_version = result.protocolVersion + + # Start a long-running tool in a task + async with anyio.create_task_group() as tg: + + async def run_tool(): + nonlocal captured_request_id + captured_request_id = await session.request_call_tool( + "long_running_with_checkpoints", arguments={} + ) + + tg.start_soon(run_tool) + + # Wait for the tool to start and at least one notification + # and then kill the task group + while ( + not tool_started or not captured_request_id or len(request_state_manager_1._resume_tokens) == 0 + ): + await anyio.sleep(0.1) + + tg.cancel_scope.cancel() + + # Store pre notifications and clear the captured notifications + # for the post-resumption check + captured_notifications_pre = captured_notifications.copy() + captured_notifications = [] + + # Now resume the session with the same mcp-session-id and protocol version + headers = {} + if captured_session_id: + headers[MCP_SESSION_ID_HEADER] = captured_session_id + if captured_protocol_version: + 
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version + + assert len(request_state_manager_1._requests) == 1, str(request_state_manager_1._requests) + assert len(request_state_manager_1._resume_tokens) == 1 + + request_state_manager_2._requests = request_state_manager_1._requests.copy() + request_state_manager_2._resume_tokens = request_state_manager_1._resume_tokens.copy() + + async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + request_state_manager=request_state_manager_2, + ) as session: + # Don't initialize - just use the existing session + + # Resume the tool with the resumption token + assert captured_request_id is not None + + result = await session.join_call_tool(captured_request_id) + assert result is not None + + # We should get a complete result + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert "Completed" in result.content[0].text + + # We should have received the remaining notifications + assert len(captured_notifications) > 0 + + # Should not have the first notification + # Check that "Tool started" notification isn't repeated when resuming + assert not any( + isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started" + for n in captured_notifications + ) + # there is no intersection between pre and post notifications + assert not any(n in captured_notifications_pre for n in captured_notifications) + + assert len(request_state_manager_1._progress_callbacks) == 0 + assert len(request_state_manager_1._response_streams) == 0 + assert len(request_state_manager_2._progress_callbacks) == 0 + assert len(request_state_manager_2._resume_tokens) == 0 + assert len(request_state_manager_2._response_streams) == 0 + + +@pytest.mark.anyio +async def test_streamablehttp_client_resumption_timeout(event_server): + """Test client session to resume a long running tool via non blocking api with timeout.""" + _, server_url = event_server + + with anyio.fail_after(10): + # Variables to track the state + captured_session_id = None + captured_notifications = [] + tool_started = False + captured_protocol_version = None + captured_request_id = None + request_state_manager_1 = InMemoryRequestStateManager() + request_state_manager_2 = InMemoryRequestStateManager() + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + captured_notifications.append(message) + # Look for our special notification that indicates the tool is running + if isinstance(message.root, types.LoggingMessageNotification): + if message.root.params.data == "Tool started": + nonlocal tool_started + tool_started = True + + # First, start the client session and begin the long-running tool + async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( + read_stream, + write_stream, + get_session_id, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + request_state_manager=request_state_manager_1, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + captured_session_id = get_session_id() + assert captured_session_id is not None + # Capture the negotiated protocol version + captured_protocol_version = 
result.protocolVersion + + # Start a long-running tool in a task + async with anyio.create_task_group() as tg: + timed_out = anyio.Event() + + async def run_tool(): + nonlocal captured_request_id + captured_request_id = await session.request_call_tool( + "long_running_with_checkpoints", arguments={} + ) + + result = await session.join_call_tool( + captured_request_id, + request_read_timeout_seconds=timedelta(seconds=0.01), + done_on_timeout=False, + ) + + assert result is None + + timed_out.set() + + tg.start_soon(run_tool) + + # Wait for the tool to start and at least one notification + # and then kill the task group + while ( + not tool_started or not captured_request_id or len(request_state_manager_1._resume_tokens) == 0 + ): + await anyio.sleep(0.1) + + await timed_out.wait() + + tg.cancel_scope.cancel() + + # Store pre notifications and clear the captured notifications + # for the post-resumption check + captured_notifications_pre = captured_notifications.copy() + captured_notifications = [] + + # Now resume the session with the same mcp-session-id and protocol version + headers = {} + if captured_session_id: + headers[MCP_SESSION_ID_HEADER] = captured_session_id + if captured_protocol_version: + headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version + + assert len(request_state_manager_1._requests) == 1, str(request_state_manager_1._requests) + assert len(request_state_manager_1._resume_tokens) == 1 + + request_state_manager_2._requests = request_state_manager_1._requests.copy() + request_state_manager_2._resume_tokens = request_state_manager_1._resume_tokens.copy() + + async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + message_handler=message_handler, + request_state_manager=request_state_manager_2, + ) as session: + # Don't initialize - just use the existing session + + # Resume the tool with the resumption token + assert captured_request_id is not None + + result = await session.join_call_tool(captured_request_id) + assert result is not None + + # We should get a complete result + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert "Completed" in result.content[0].text + + # We should have received the remaining notifications + assert len(captured_notifications) > 0 + + # Should not have the first notification + # Check that "Tool started" notification isn't repeated when resuming + assert not any( + isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started" + for n in captured_notifications + ) + # there is no intersection between pre and post notifications + assert not any(n in captured_notifications_pre for n in captured_notifications), ( + f"{captured_notifications_pre} -> {captured_notifications}" + ) + + assert len(request_state_manager_1._progress_callbacks) == 0 + assert len(request_state_manager_1._response_streams) == 0 + assert len(request_state_manager_2._progress_callbacks) == 0 + assert len(request_state_manager_2._resume_tokens) == 0 + assert len(request_state_manager_2._response_streams) == 0 + + @pytest.mark.anyio async def test_streamablehttp_server_sampling(basic_server, basic_server_url): """Test server-initiated sampling request through streamable HTTP transport."""