Skip to content

Commit de6dd61

Browse files
authored
Merge pull request #1 from EiffL/streamable_http_disconnect
fix: handle transport exceptions and ensure proper cleanup of in-flig…
2 parents 07ae8c0 + 671d077 commit de6dd61

File tree

3 files changed

+83
-1
lines changed

3 files changed

+83
-1
lines changed

src/mcp/client/streamable_http.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,13 @@ async def _handle_sse_response(
337337
break
338338
except Exception as e:
339339
logger.exception("Error reading SSE stream:")
340-
await ctx.read_stream_writer.send(e)
340+
# Forward the transport exception to the session read stream so it can
341+
# fail any in-flight requests, then close the stream to signal that
342+
# no further messages will arrive.
343+
try:
344+
await ctx.read_stream_writer.send(e)
345+
finally:
346+
await ctx.read_stream_writer.aclose()
341347

342348
async def _handle_unexpected_content_type(
343349
self,

src/mcp/shared/session.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,23 @@ async def _receive_loop(self) -> None:
336336
try:
337337
async for message in self._read_stream:
338338
if isinstance(message, Exception):
339+
# Transport-level exception. Forward it to the incoming
340+
# handler for logging/observation, then fail all
341+
# in-flight requests so callers don't hang forever.
339342
await self._handle_incoming(message)
343+
error = ErrorData(code=CONNECTION_CLOSED, message=str(message))
344+
# Send error to any pending request response streams immediately
345+
for id, stream in list(self._response_streams.items()):
346+
try:
347+
await stream.send(
348+
JSONRPCError(jsonrpc="2.0", id=id, error=error)
349+
)
350+
await stream.aclose()
351+
except Exception:
352+
pass
353+
self._response_streams.clear()
354+
# Break out of the receive loop; connection is no longer usable.
355+
break
340356
elif isinstance(message.message.root, JSONRPCRequest):
341357
try:
342358
validated_request = self._receive_request_type.model_validate(

tests/shared/test_streamable_http.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,66 @@ def test_streamable_http_transport_init_validation():
603603
StreamableHTTPServerTransport(mcp_session_id="test\n")
604604

605605

606+
@pytest.mark.anyio
607+
async def test_streamable_http_mid_call_disconnect(basic_server_port: int):
608+
"""A long-running tool call should error quickly if the server dies mid-call.
609+
610+
We use the existing server implementation's `wait_for_lock_with_notification` tool which
611+
blocks waiting on an internal lock. We kill the server process after the tool starts and
612+
assert the client surfaces a CONNECTION_CLOSED McpError instead of hanging.
613+
"""
614+
615+
# Launch a dedicated server process (don't reuse basic_server fixture so we can kill it)
616+
proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True)
617+
proc.start()
618+
619+
# Wait for server readiness (reuse pattern from fixtures)
620+
start_time = time.time()
621+
while time.time() - start_time < 10:
622+
try:
623+
with socket.create_connection(("127.0.0.1", basic_server_port), timeout=0.2):
624+
break
625+
except OSError:
626+
time.sleep(0.1)
627+
else: # pragma: no cover
628+
proc.kill(); proc.join(timeout=2)
629+
pytest.fail("Server failed to start in time")
630+
631+
server_url = f"http://127.0.0.1:{basic_server_port}/mcp"
632+
633+
try:
634+
async with streamablehttp_client(server_url) as (read_stream, write_stream, _):
635+
async with ClientSession(read_stream, write_stream) as session:
636+
await session.initialize()
637+
await session.list_tools()
638+
639+
result: dict[str, object] = {}
640+
641+
async def invoke():
642+
try:
643+
await session.call_tool("wait_for_lock_with_notification", {}) # pragma: no cover
644+
result["ok"] = True
645+
except McpError as e:
646+
result["err"] = e
647+
648+
async with anyio.create_task_group() as tg:
649+
tg.start_soon(invoke)
650+
# Give the request a moment to reach the server & tool to start (it sends a log notification)
651+
await anyio.sleep(0.6)
652+
proc.kill(); proc.join(timeout=2)
653+
# Wait for propagated disconnect
654+
with anyio.fail_after(10):
655+
while "err" not in result:
656+
await anyio.sleep(0.2)
657+
658+
err = result.get("err")
659+
assert isinstance(err, McpError), "Expected McpError from mid-call disconnect"
660+
assert err.error.code == types.CONNECTION_CLOSED
661+
finally:
662+
if proc.is_alive(): # Safety cleanup
663+
proc.kill(); proc.join(timeout=2)
664+
665+
606666
def test_session_termination(basic_server: None, basic_server_url: str):
607667
"""Test session termination via DELETE and subsequent request handling."""
608668
response = requests.post(

0 commit comments

Comments
 (0)