Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from mcp.shared.exceptions import McpError
from mcp.types import (
CONNECTION_CLOSED,
CancelledNotification,
ClientNotification,
ClientRequest,
Expand Down Expand Up @@ -359,6 +360,13 @@ async def _receive_loop(self) -> None:
)
)

# after the read stream is closed, we need to send errors
# to any pending requests
for id, stream in self._response_streams.items():
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()

async def _received_request(
self, responder: RequestResponder[ReceiveRequestT, SendResultT]
) -> None:
Expand Down
4 changes: 4 additions & 0 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ class JSONRPCResponse(BaseModel):
model_config = ConfigDict(extra="allow")


# SDK error codes
CONNECTION_CLOSED = -32000
# REQUEST_TIMEOUT = -32001 # the typescript sdk uses this

# Standard JSON-RPC error codes
PARSE_ERROR = -32700
INVALID_REQUEST = -32600
Expand Down
59 changes: 58 additions & 1 deletion tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from mcp.client.session import ClientSession
from mcp.server.lowlevel.server import Server
from mcp.shared.exceptions import McpError
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.shared.memory import (
create_client_server_memory_streams,
create_connected_server_and_client_session,
)
from mcp.types import (
CancelledNotification,
CancelledNotificationParams,
Expand Down Expand Up @@ -124,3 +127,57 @@ async def make_request(client_session):
# Give cancellation time to process
with anyio.fail_after(1):
await ev_cancelled.wait()


@pytest.mark.anyio
async def test_connection_closed():
"""
Test that pending requests are cancelled when the connection is closed remotely.
"""

ev_closed = anyio.Event()
ev_response = anyio.Event()

async with create_client_server_memory_streams() as (
client_streams,
server_streams,
):
client_read, client_write = client_streams
server_read, server_write = server_streams

async def make_request(client_session):
"""Send a request in a separate task"""
nonlocal ev_response
try:
# any request will do
await client_session.initialize()
pytest.fail("Request should have errored")
except McpError as e:
# Expected - request errored
assert "Connection closed" in str(e)
ev_response.set()

async def mock_server():
"""Wait for a request, then close the connection"""
nonlocal ev_closed
# Wait for a request
await server_read.receive()
# Close the connection, as if the server exited
server_write.close()
server_read.close()
ev_closed.set()

async with (
anyio.create_task_group() as tg,
ClientSession(
read_stream=client_read,
write_stream=client_write,
) as client_session,
):
tg.start_soon(make_request, client_session)
tg.start_soon(mock_server)

with anyio.fail_after(1):
await ev_closed.wait()
with anyio.fail_after(1):
await ev_response.wait()