Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 16 additions & 36 deletions src/mcp/client/stdio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import logging
import os
import sys
from contextlib import asynccontextmanager
from contextlib import AsyncExitStack, asynccontextmanager
from pathlib import Path
from typing import Literal, TextIO

import anyio
import anyio.lowlevel
from anyio.abc import Process
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -107,33 +106,19 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout.
"""
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
command = _get_executable_command(server.command)

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

try:
command = _get_executable_command(server.command)

# Open process with stderr piped for capture
process = await _create_platform_compatible_process(
command=command,
args=server.args,
env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
errlog=errlog,
cwd=server.cwd,
)
except OSError:
# Clean up streams if process creation fails
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
raise
Comment on lines -130 to -136
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It actually never triggers OSError here. anyio.open_process doesn't trigger.

Maybe the Windows logic does, but even if it does... The streams don't need to be closed because they are not even open yet, so just removing the except is fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the test suite I'm wrong... I'm not sure how.

# Open process with stderr piped for capture
process = await _create_platform_compatible_process(
command=command,
args=server.args,
env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
errlog=errlog,
cwd=server.cwd,
)

async def stdout_reader():
assert process.stdout, "Opened process is missing stdout"
Expand Down Expand Up @@ -177,14 +162,13 @@ async def stdin_writer():
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

async with (
anyio.create_task_group() as tg,
process,
):
async with anyio.create_task_group() as tg, process:
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)

try:
yield read_stream, write_stream
async with read_stream, write_stream:
yield read_stream, write_stream
finally:
# MCP spec: stdio shutdown sequence
# 1. Close input stream to server
Expand All @@ -208,10 +192,6 @@ async def stdin_writer():
except ProcessLookupError:
# Process already exited, which is fine
pass
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to close read_stream and write_stream, just add the async context manager block above, and no need for read_stream_writer and write_stream_reader because they are open and closed within the tasks above.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it seems I was wrong... 👀



def _get_executable_command(command: str) -> str:
Expand Down
16 changes: 4 additions & 12 deletions src/mcp/shared/session.py
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Readability changes - I can revert if wanted.

Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,7 @@ async def send_request(
self._progress_callbacks[request_id] = progress_callback

try:
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request_data,
)

jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data)
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))

# request read timeout takes precedence over session read timeout
Expand Down Expand Up @@ -329,10 +324,7 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er
await self._write_stream.send(session_message)

async def _receive_loop(self) -> None:
async with (
self._read_stream,
self._write_stream,
):
async with self._read_stream, self._write_stream:
try:
async for message in self._read_stream:
if isinstance(message, Exception):
Expand Down Expand Up @@ -418,10 +410,10 @@ async def _receive_loop(self) -> None:
# Without this handler, the exception would propagate up and
# crash the server's task group.
logging.debug("Read stream closed by client")
except Exception as e:
except Exception:
# Other exceptions are not expected and should be logged. We purposefully
# catch all exceptions here to avoid crashing the server.
logging.exception(f"Unhandled exception in receive loop: {e}")
logging.exception("Unhandled exception in receive loop")
finally:
# after the read stream is closed, we need to send errors
# to any pending requests
Expand Down
Loading