diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py
index 6536272d9..c2402f37e 100644
--- a/src/mcp/shared/session.py
+++ b/src/mcp/shared/session.py
@@ -333,107 +333,114 @@ async def _receive_loop(self) -> None:
             self._read_stream,
             self._write_stream,
         ):
-            try:
-                async for message in self._read_stream:
-                    if isinstance(message, Exception):
-                        await self._handle_incoming(message)
-                    elif isinstance(message.message.root, JSONRPCRequest):
-                        try:
-                            validated_request = self._receive_request_type.model_validate(
-                                message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
-                            )
-                            responder = RequestResponder(
-                                request_id=message.message.root.id,
-                                request_meta=validated_request.root.params.meta
-                                if validated_request.root.params
-                                else None,
-                                request=validated_request,
-                                session=self,
-                                on_complete=lambda r: self._in_flight.pop(r.request_id, None),
-                                message_metadata=message.metadata,
-                            )
-                            self._in_flight[responder.request_id] = responder
-                            await self._received_request(responder)
-
-                            if not responder._completed:  # type: ignore[reportPrivateUsage]
-                                await self._handle_incoming(responder)
-                        except Exception as e:
-                            # For request validation errors, send a proper JSON-RPC error
-                            # response instead of crashing the server
-                            logging.warning(f"Failed to validate request: {e}")
-                            logging.debug(f"Message that failed validation: {message.message.root}")
-                            error_response = JSONRPCError(
-                                jsonrpc="2.0",
-                                id=message.message.root.id,
-                                error=ErrorData(
-                                    code=INVALID_PARAMS,
-                                    message="Invalid request parameters",
-                                    data="",
-                                ),
-                            )
-                            session_message = SessionMessage(message=JSONRPCMessage(error_response))
-                            await self._write_stream.send(session_message)
-
-                    elif isinstance(message.message.root, JSONRPCNotification):
-                        try:
-                            notification = self._receive_notification_type.model_validate(
-                                message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
-                            )
-                            # Handle cancellation notifications
-                            if isinstance(notification.root, CancelledNotification):
-                                cancelled_id = notification.root.params.requestId
-                                if cancelled_id in self._in_flight:
-                                    await self._in_flight[cancelled_id].cancel()
-                            else:
-                                # Handle progress notifications callback
-                                if isinstance(notification.root, ProgressNotification):
-                                    progress_token = notification.root.params.progressToken
-                                    # If there is a progress callback for this token,
-                                    # call it with the progress information
-                                    if progress_token in self._progress_callbacks:
-                                        callback = self._progress_callbacks[progress_token]
-                                        await callback(
-                                            notification.root.params.progress,
-                                            notification.root.params.total,
-                                            notification.root.params.message,
-                                        )
-                                await self._received_notification(notification)
-                                await self._handle_incoming(notification)
-                        except Exception as e:
-                            # For other validation errors, log and continue
-                            logging.warning(
-                                f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
-                            )
-                    else:  # Response or error
-                        stream = self._response_streams.pop(message.message.root.id, None)
-                        if stream:
-                            await stream.send(message.message.root)
-                        else:
-                            await self._handle_incoming(
-                                RuntimeError("Received response with an unknown " f"request ID: {message}")
-                            )
-
-            except anyio.ClosedResourceError:
-                # This is expected when the client disconnects abruptly.
-                # Without this handler, the exception would propagate up and
-                # crash the server's task group.
-                logging.debug("Read stream closed by client")
-            except Exception as e:
-                # Other exceptions are not expected and should be logged. We purposefully
-                # catch all exceptions here to avoid crashing the server.
-                logging.exception(f"Unhandled exception in receive loop: {e}")
-            finally:
-                # after the read stream is closed, we need to send errors
-                # to any pending requests
-                for id, stream in self._response_streams.items():
-                    error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
-                    try:
-                        await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
-                        await stream.aclose()
-                    except Exception:
-                        # Stream might already be closed
-                        pass
-                self._response_streams.clear()
+            async with anyio.create_task_group() as tg:
+                try:
+                    async for message in self._read_stream:
+                        if isinstance(message, Exception):
+                            await self._handle_incoming(message)
+                        elif isinstance(message.message.root, JSONRPCRequest):
+                            try:
+                                validated_request = self._receive_request_type.model_validate(
+                                    message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+                                )
+                                responder = RequestResponder(
+                                    request_id=message.message.root.id,
+                                    request_meta=validated_request.root.params.meta
+                                    if validated_request.root.params
+                                    else None,
+                                    request=validated_request,
+                                    session=self,
+                                    on_complete=lambda r: self._in_flight.pop(r.request_id, None),
+                                    message_metadata=message.metadata,
+                                )
+
+                                # Bind the responder as a default argument: the loop rebinds
+                                # `responder` on the next message, so a plain closure started
+                                # via start_soon could run against the wrong request.
+                                async def _handle_received_request(responder=responder) -> None:
+                                    await self._received_request(responder)
+                                    if not responder._completed:  # type: ignore[reportPrivateUsage]
+                                        await self._handle_incoming(responder)
+
+                                self._in_flight[responder.request_id] = responder
+                                tg.start_soon(_handle_received_request)
+                            except Exception as e:
+                                # For request validation errors, send a proper JSON-RPC error
+                                # response instead of crashing the server
+                                logging.warning(f"Failed to validate request: {e}")
+                                logging.debug(f"Message that failed validation: {message.message.root}")
+                                error_response = JSONRPCError(
+                                    jsonrpc="2.0",
+                                    id=message.message.root.id,
+                                    error=ErrorData(
+                                        code=INVALID_PARAMS,
+                                        message="Invalid request parameters",
+                                        data="",
+                                    ),
+                                )
+                                session_message = SessionMessage(message=JSONRPCMessage(error_response))
+                                await self._write_stream.send(session_message)
+
+                        elif isinstance(message.message.root, JSONRPCNotification):
+                            try:
+                                notification = self._receive_notification_type.model_validate(
+                                    message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+                                )
+                                # Handle cancellation notifications
+                                if isinstance(notification.root, CancelledNotification):
+                                    cancelled_id = notification.root.params.requestId
+                                    if cancelled_id in self._in_flight:
+                                        await self._in_flight[cancelled_id].cancel()
+                                else:
+                                    # Handle progress notifications callback
+                                    if isinstance(notification.root, ProgressNotification):
+                                        progress_token = notification.root.params.progressToken
+                                        # If there is a progress callback for this token,
+                                        # call it with the progress information
+                                        if progress_token in self._progress_callbacks:
+                                            callback = self._progress_callbacks[progress_token]
+                                            await callback(
+                                                notification.root.params.progress,
+                                                notification.root.params.total,
+                                                notification.root.params.message,
+                                            )
+                                    await self._received_notification(notification)
+                                    await self._handle_incoming(notification)
+                            except Exception as e:
+                                # For other validation errors, log and continue
+                                logging.warning(
+                                    f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
+                                )
+                        else:  # Response or error
+                            stream = self._response_streams.pop(message.message.root.id, None)
+                            if stream:
+                                await stream.send(message.message.root)
+                            else:
+                                await self._handle_incoming(
+                                    RuntimeError("Received response with an unknown " f"request ID: {message}")
+                                )
+
+                except anyio.ClosedResourceError:
+                    # This is expected when the client disconnects abruptly.
+                    # Without this handler, the exception would propagate up and
+                    # crash the server's task group.
+                    logging.debug("Read stream closed by client")
+                except Exception as e:
+                    # Other exceptions are not expected and should be logged. We purposefully
+                    # catch all exceptions here to avoid crashing the server.
+                    logging.exception(f"Unhandled exception in receive loop: {e}")
+                finally:
+                    # after the read stream is closed, we need to send errors
+                    # to any pending requests
+                    for id, stream in self._response_streams.items():
+                        error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
+                        try:
+                            await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
+                            await stream.aclose()
+                        except Exception:
+                            # Stream might already be closed
+                            pass
+                    self._response_streams.clear()
 
     async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
         """
diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py
index a3f6affda..f54c8dac4 100644
--- a/tests/client/test_sampling_callback.py
+++ b/tests/client/test_sampling_callback.py
@@ -1,3 +1,4 @@
+import anyio
 import pytest
 
 from mcp.client.session import ClientSession
@@ -56,3 +57,99 @@ async def test_sampling_tool(message: str):
         assert result.isError is True
         assert isinstance(result.content[0], TextContent)
         assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"
+
+
+@pytest.mark.anyio
+async def test_concurrent_sampling_callback():
+    """Test multiple concurrent sampling calls using time-sort verification."""
+    from mcp.server.fastmcp import FastMCP
+
+    server = FastMCP("test")
+
+    # Track completion order using time-sort approach
+    completion_order = []
+
+    async def sampling_callback(
+        context: RequestContext[ClientSession, None],
+        params: CreateMessageRequestParams,
+    ) -> CreateMessageResult:
+        # Extract delay from the message content (e.g., "delay_0.3")
+        assert isinstance(params.messages[0].content, TextContent)
+        message_text = params.messages[0].content.text
+        if message_text.startswith("delay_"):
+            delay = float(message_text.split("_")[1])
+            # Simulate different LLM response times
+            await anyio.sleep(delay)
+            completion_order.append(delay)
+            return CreateMessageResult(
+                role="assistant",
+                content=TextContent(type="text", text=f"Response after {delay}s"),
+                model="test-model",
+                stopReason="endTurn",
+            )
+
+        return CreateMessageResult(
+            role="assistant",
+            content=TextContent(type="text", text="Default response"),
+            model="test-model",
+            stopReason="endTurn",
+        )
+
+    @server.tool("concurrent_sampling_tool")
+    async def concurrent_sampling_tool():
+        """Tool that makes multiple concurrent sampling calls."""
+        # Use TaskGroup to make multiple concurrent sampling calls
+        # Using out-of-order durations: 0.6s, 0.2s, 0.4s
+        # If concurrent, should complete in order: 0.2s, 0.4s, 0.6s
+        async with anyio.create_task_group() as tg:
+            results = {}
+
+            async def make_sampling_call(call_id: str, delay: float):
+                result = await server.get_context().session.create_message(
+                    messages=[
+                        SamplingMessage(
+                            role="user",
+                            content=TextContent(type="text", text=f"delay_{delay}"),
+                        )
+                    ],
+                    max_tokens=100,
+                )
+                results[call_id] = result
+
+            # Start operations with out-of-order timing
+            tg.start_soon(make_sampling_call, "slow_call", 0.6)  # Should finish last
+            tg.start_soon(make_sampling_call, "fast_call", 0.2)  # Should finish first
+            tg.start_soon(make_sampling_call, "medium_call", 0.4)  # Should finish middle
+
+        # Combine results to show all completed
+        combined_response = " | ".join(
+            [
+                results["slow_call"].content.text,
+                results["fast_call"].content.text,
+                results["medium_call"].content.text,
+            ]
+        )
+
+        return combined_response
+
+    # Test concurrent sampling calls with time-sort verification
+    async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session:
+        # Make a request that triggers multiple concurrent sampling calls
+        result = await client_session.call_tool("concurrent_sampling_tool", {})
+
+        assert result.isError is False
+        assert isinstance(result.content[0], TextContent)
+
+        # Verify all sampling calls completed with expected responses
+        expected_result = "Response after 0.6s | Response after 0.2s | Response after 0.4s"
+        assert result.content[0].text == expected_result
+
+        # Key test: verify concurrent execution using time-sort
+        # Started in order: 0.6s, 0.2s, 0.4s
+        # Should complete in order: 0.2s, 0.4s, 0.6s (fastest first)
+        assert len(completion_order) == 3
+        assert completion_order == [
+            0.2,
+            0.4,
+            0.6,
+        ], f"Expected [0.2, 0.4, 0.6] but got {completion_order}"
diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py
index 864e0d1b4..b7cd51406 100644
--- a/tests/shared/test_session.py
+++ b/tests/shared/test_session.py
@@ -17,6 +17,7 @@
     ClientNotification,
     ClientRequest,
     EmptyResult,
+    TextContent,
 )
 
 
@@ -177,3 +178,74 @@ async def mock_server():
         await ev_closed.wait()
     with anyio.fail_after(1):
         await ev_response.wait()
+
+
+@pytest.mark.anyio
+async def test_async_request_handling_with_taskgroup():
+    """Test that multiple sampling requests are handled asynchronously."""
+    # Track completion order
+    completion_order = []
+
+    def make_server() -> Server:
+        server = Server(name="AsyncTestServer")
+
+        @server.call_tool()
+        async def handle_call_tool(name: str, arguments: dict | None) -> list:
+            nonlocal completion_order
+
+            if name.startswith("timed_tool"):
+                # Extract wait time from tool name (e.g., "timed_tool_0.2")
+                wait_time = float(name.split("_")[-1])
+
+                # Wait for the specified time
+                await anyio.sleep(wait_time)
+
+                # Record completion
+                completion_order.append(wait_time)
+
+                return [TextContent(type="text", text=f"Waited {wait_time}s")]
+
+            raise ValueError(f"Unknown tool: {name}")
+
+        @server.list_tools()
+        async def handle_list_tools() -> list[types.Tool]:
+            return [
+                types.Tool(
+                    name="timed_tool_0.1",
+                    description="Tool that waits 0.1s",
+                    inputSchema={},
+                ),
+                types.Tool(
+                    name="timed_tool_0.2",
+                    description="Tool that waits 0.2s",
+                    inputSchema={},
+                ),
+                types.Tool(
+                    name="timed_tool_0.05",
+                    description="Tool that waits 0.05s",
+                    inputSchema={},
+                ),
+            ]
+
+        return server
+
+    async with create_connected_server_and_client_session(make_server()) as client_session:
+        # Test basic async handling with a single request
+        result = await client_session.send_request(
+            ClientRequest(
+                types.CallToolRequest(
+                    method="tools/call",
+                    params=types.CallToolRequestParams(name="timed_tool_0.1", arguments={}),
+                )
+            ),
+            types.CallToolResult,
+        )
+
+        # Verify the request completed successfully
+        assert isinstance(result.content[0], TextContent)
+        assert result.content[0].text == "Waited 0.1s"
+        assert len(completion_order) == 1
+        assert completion_order[0] == 0.1
+
+        # Verify no pending requests remain
+        assert len(client_session._in_flight) == 0
diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py
index 1ffcc13b0..f153e3a62 100644
--- a/tests/shared/test_streamable_http.py
+++ b/tests/shared/test_streamable_http.py
@@ -181,7 +181,9 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
                 related_request_id=ctx.request_id,  # need for stream association
             )
 
-            await anyio.sleep(0.1)
+            # need to wait for long enough that the client can
+            # reliably stop the tool before this finishes
+            await anyio.sleep(0.3)
 
             await ctx.session.send_log_message(
                 level="info",
@@ -190,6 +192,18 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
                 related_request_id=ctx.request_id,
             )
 
+            # Adding another message just to make it even less
+            # likely that this tool will exit before the client
+            # can stop it
+            await anyio.sleep(0.3)
+
+            await ctx.session.send_log_message(
+                level="info",
+                data="Tool is done",
+                logger="tool",
+                related_request_id=ctx.request_id,
+            )
+
             return [TextContent(type="text", text="Completed!")]
 
         elif name == "test_sampling_tool":
@@ -1114,6 +1128,11 @@ async def run_tool():
             await anyio.sleep(0.1)
             tg.cancel_scope.cancel()
 
+        # Make sure we only have one notification.. otherwise the test is flaky
+        # More than one notification means the tool likely could have finished
+        # already and will not call the message handler again upon resumption
+        assert len(captured_notifications) == 1
+
         # Store pre notifications and clear the captured notifications
         # for the post-resumption check
         captured_notifications_pre = captured_notifications.copy()
@@ -1140,6 +1159,10 @@ async def run_tool():
             metadata = ClientMessageMetadata(
                 resumption_token=captured_resumption_token,
             )
+            # We need to wait for the tool to send another message so this doesn't
+            # deadlock. Fixing is out of scope for this PR. More details in
+            # https://github.com/modelcontextprotocol/python-sdk/issues/860
+            await anyio.sleep(0.2)
             result = await session.send_request(
                 types.ClientRequest(
                     types.CallToolRequest(
@@ -1157,7 +1180,7 @@ async def run_tool():
             assert "Completed" in result.content[0].text
 
             # We should have received the remaining notifications
-            assert len(captured_notifications) > 0
+            assert len(captured_notifications) == 2
 
             # Should not have the first notification
             # Check that "Tool started" notification isn't repeated when resuming