Skip to content
Closed
207 changes: 107 additions & 100 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,107 +333,114 @@ async def _receive_loop(self) -> None:
self._read_stream,
self._write_stream,
):
try:
async for message in self._read_stream:
if isinstance(message, Exception):
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
try:
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)

if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
except Exception as e:
# For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server
logging.warning(f"Failed to validate request: {e}")
logging.debug(f"Message that failed validation: {message.message.root}")
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.root.id,
error=ErrorData(
code=INVALID_PARAMS,
message="Invalid request parameters",
data="",
),
)
session_message = SessionMessage(message=JSONRPCMessage(error_response))
await self._write_stream.send(session_message)

elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
async with anyio.create_task_group() as tg:
try:
async for message in self._read_stream:
if isinstance(message, Exception):
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
try:
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)

async def _handle_received_request() -> None:
await self._received_request(responder)
if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is duplicating line 366 below so we might be calling self._handle_incoming twice for every request? Should we be removing 365-366?


self._in_flight[responder.request_id] = responder
tg.start_soon(_handle_received_request)

if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
except Exception as e:
# For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server
logging.warning(f"Failed to validate request: {e}")
logging.debug(f"Message that failed validation: {message.message.root}")
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.root.id,
error=ErrorData(
code=INVALID_PARAMS,
message="Invalid request parameters",
data="",
),
)
session_message = SessionMessage(message=JSONRPCMessage(error_response))
await self._write_stream.send(session_message)

elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = notification.root.params.progressToken
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.message.root.id, None)
if stream:
await stream.send(message.message.root)
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = notification.root.params.progressToken
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.message.root.id, None)
if stream:
await stream.send(message.message.root)
else:
await self._handle_incoming(
RuntimeError("Received response with an unknown " f"request ID: {message}")
)

except anyio.ClosedResourceError:
# This is expected when the client disconnects abruptly.
# Without this handler, the exception would propagate up and
# crash the server's task group.
logging.debug("Read stream closed by client")
except Exception as e:
# Other exceptions are not expected and should be logged. We purposefully
# catch all exceptions here to avoid crashing the server.
logging.exception(f"Unhandled exception in receive loop: {e}")
finally:
# after the read stream is closed, we need to send errors
# to any pending requests
for id, stream in self._response_streams.items():
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
try:
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
except Exception:
# Stream might already be closed
pass
self._response_streams.clear()
await self._handle_incoming(
RuntimeError("Received response with an unknown " f"request ID: {message}")
)

except anyio.ClosedResourceError:
# This is expected when the client disconnects abruptly.
# Without this handler, the exception would propagate up and
# crash the server's task group.
logging.debug("Read stream closed by client")
except Exception as e:
# Other exceptions are not expected and should be logged. We purposefully
# catch all exceptions here to avoid crashing the server.
logging.exception(f"Unhandled exception in receive loop: {e}")
finally:
# after the read stream is closed, we need to send errors
# to any pending requests
for id, stream in self._response_streams.items():
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
try:
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
except Exception:
# Stream might already be closed
pass
self._response_streams.clear()

async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
"""
Expand Down
97 changes: 97 additions & 0 deletions tests/client/test_sampling_callback.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import anyio
import pytest

from mcp.client.session import ClientSession
Expand Down Expand Up @@ -56,3 +57,99 @@ async def test_sampling_tool(message: str):
assert result.isError is True
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"


@pytest.mark.anyio
async def test_concurrent_sampling_callback():
    """Test multiple concurrent sampling calls using time-sort verification.

    Strategy: each sampling request encodes its own artificial latency in the
    message text ("delay_<seconds>").  Three requests are started out of order
    (0.6s, 0.2s, 0.4s); if the session processes them concurrently they must
    complete fastest-first (0.2, 0.4, 0.6).  Serial processing would instead
    complete them in start order, failing the final assertion.
    """
    # Imported locally so the module import stays cheap for other tests.
    from mcp.server.fastmcp import FastMCP

    server = FastMCP("test")

    # Track completion order using time-sort approach
    # Appended to by sampling_callback as each simulated LLM call finishes.
    completion_order: list[float] = []

    async def sampling_callback(
        context: RequestContext[ClientSession, None],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        """Client-side sampling handler that sleeps for the encoded delay."""
        # Extract delay from the message content (e.g., "delay_0.3")
        assert isinstance(params.messages[0].content, TextContent)
        message_text = params.messages[0].content.text
        if message_text.startswith("delay_"):
            delay = float(message_text.split("_")[1])
            # Simulate different LLM response times
            await anyio.sleep(delay)
            completion_order.append(delay)
            return CreateMessageResult(
                role="assistant",
                content=TextContent(type="text", text=f"Response after {delay}s"),
                model="test-model",
                stopReason="endTurn",
            )

        # Fallback for messages without an encoded delay (not expected in
        # this test, but keeps the callback total).
        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text="Default response"),
            model="test-model",
            stopReason="endTurn",
        )

    @server.tool("concurrent_sampling_tool")
    async def concurrent_sampling_tool() -> str:
        """Tool that makes multiple concurrent sampling calls."""
        # Use TaskGroup to make multiple concurrent sampling calls
        # Using out-of-order durations: 0.6s, 0.2s, 0.4s
        # If concurrent, should complete in order: 0.2s, 0.4s, 0.6s
        async with anyio.create_task_group() as tg:
            # call_id -> CreateMessageResult, filled in by the child tasks.
            results = {}

            async def make_sampling_call(call_id: str, delay: float):
                """Issue one create_message round-trip and record its result."""
                result = await server.get_context().session.create_message(
                    messages=[
                        SamplingMessage(
                            role="user",
                            content=TextContent(type="text", text=f"delay_{delay}"),
                        )
                    ],
                    max_tokens=100,
                )
                results[call_id] = result

            # Start operations with out-of-order timing
            tg.start_soon(make_sampling_call, "slow_call", 0.6)  # Should finish last
            tg.start_soon(make_sampling_call, "fast_call", 0.2)  # Should finish first
            tg.start_soon(make_sampling_call, "medium_call", 0.4)  # Should finish middle

        # Task group has exited, so all three results are present.
        # Combine results to show all completed
        combined_response = " | ".join(
            [
                results["slow_call"].content.text,
                results["fast_call"].content.text,
                results["medium_call"].content.text,
            ]
        )

        return combined_response

    # Test concurrent sampling calls with time-sort verification
    async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session:
        # Make a request that triggers multiple concurrent sampling calls
        result = await client_session.call_tool("concurrent_sampling_tool", {})

        assert result.isError is False
        assert isinstance(result.content[0], TextContent)

        # Verify all sampling calls completed with expected responses
        # (joined in start order, not completion order).
        expected_result = "Response after 0.6s | Response after 0.2s | Response after 0.4s"
        assert result.content[0].text == expected_result

        # Key test: verify concurrent execution using time-sort
        # Started in order: 0.6s, 0.2s, 0.4s
        # Should complete in order: 0.2s, 0.4s, 0.6s (fastest first)
        assert len(completion_order) == 3
        assert completion_order == [
            0.2,
            0.4,
            0.6,
        ], f"Expected [0.2, 0.4, 0.6] but got {completion_order}"
Loading
Loading