Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 35 additions & 37 deletions src/claude_agent_sdk/_internal/transport/subprocess_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
else _DEFAULT_MAX_BUFFER_SIZE
)
self._temp_files: list[str] = [] # Track temporary files for cleanup
self._write_lock: anyio.Lock = anyio.Lock()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1. Write Lock Declaration [Core logic]

Initializes the anyio.Lock that serializes all transport writes. When parallel subagents handle concurrent control_request messages, each calls transport.write() to send responses. trio's underlying FdStream doesn't allow concurrent sends—this lock prevents the BusyResourceError crash.

Related: See comment 2 for where the lock is acquired.


🤖 Generated with Claude Code


def _find_cli(self) -> str:
"""Find Claude Code CLI binary."""
Expand Down Expand Up @@ -456,15 +457,14 @@ async def _handle_stderr(self) -> None:

async def close(self) -> None:
"""Close the transport and clean up resources."""
self._ready = False

# Clean up temporary files first (before early return)
for temp_file in self._temp_files:
with suppress(Exception):
Path(temp_file).unlink(missing_ok=True)
self._temp_files.clear()

if not self._process:
self._ready = False
return

# Close stderr task group if active
Expand All @@ -474,21 +474,19 @@ async def close(self) -> None:
await self._stderr_task_group.__aexit__(None, None, None)
self._stderr_task_group = None

# Close streams
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
# Close stdin stream (acquire lock to prevent race with concurrent writes)
async with self._write_lock:
self._ready = False # Set inside lock to prevent TOCTOU with write()
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
Comment on lines +477 to +483
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3. close() TOCTOU Prevention [Core logic]

Acquires _write_lock before closing stdin and setting _ready = False. This coordinates with write():

  1. If write() holds the lock, close() waits until the write finishes
  2. Once close() acquires the lock, it sets _ready = False inside the critical section
  3. Any subsequent write() will see _ready = False when it acquires the lock

Without this coordination, close() could clear the stream while write() was mid-send.


🤖 Generated with Claude Code


if self._stderr_stream:
with suppress(Exception):
await self._stderr_stream.aclose()
self._stderr_stream = None

if self._process.stdin:
with suppress(Exception):
await self._process.stdin.aclose()

# Terminate and wait for process
if self._process.returncode is None:
with suppress(ProcessLookupError):
Expand All @@ -506,37 +504,37 @@ async def close(self) -> None:

async def write(self, data: str) -> None:
"""Write raw data to the transport."""
# Check if ready (like TypeScript)
if not self._ready or not self._stdin_stream:
raise CLIConnectionError("ProcessTransport is not ready for writing")

# Check if process is still alive (like TypeScript)
if self._process and self._process.returncode is not None:
raise CLIConnectionError(
f"Cannot write to terminated process (exit code: {self._process.returncode})"
)
async with self._write_lock:
# All checks inside lock to prevent TOCTOU races with close()/end_input()
if not self._ready or not self._stdin_stream:
raise CLIConnectionError("ProcessTransport is not ready for writing")

if self._process and self._process.returncode is not None:
raise CLIConnectionError(
f"Cannot write to terminated process (exit code: {self._process.returncode})"
)

# Check for exit errors (like TypeScript)
if self._exit_error:
raise CLIConnectionError(
f"Cannot write to process that exited with error: {self._exit_error}"
) from self._exit_error
if self._exit_error:
raise CLIConnectionError(
f"Cannot write to process that exited with error: {self._exit_error}"
) from self._exit_error

try:
await self._stdin_stream.send(data)
except Exception as e:
self._ready = False # Mark as not ready (like TypeScript)
self._exit_error = CLIConnectionError(
f"Failed to write to process stdin: {e}"
)
raise self._exit_error from e
try:
await self._stdin_stream.send(data)
except Exception as e:
self._ready = False
self._exit_error = CLIConnectionError(
f"Failed to write to process stdin: {e}"
)
raise self._exit_error from e
Comment on lines +507 to +529
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2. write() Serialization [Core logic]

The core fix: all write logic runs inside async with self._write_lock. This includes:

  • State checks (_ready, _stdin_stream, returncode, _exit_error)
  • The actual stdin_stream.send() call

Moving checks inside the lock prevents TOCTOU races—without this, a write could pass all checks, then close() could clear _stdin_stream, and the write would fail on a None stream.


🤖 Generated with Claude Code


async def end_input(self) -> None:
"""End the input stream (close stdin)."""
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None
async with self._write_lock:
if self._stdin_stream:
with suppress(Exception):
await self._stdin_stream.aclose()
self._stdin_stream = None

def read_messages(self) -> AsyncIterator[dict[str, Any]]:
"""Read and parse messages from the transport."""
Expand Down
133 changes: 133 additions & 0 deletions tests/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,3 +647,136 @@ def test_sandbox_network_config(self):
assert network["allowLocalBinding"] is True
assert network["httpProxyPort"] == 8080
assert network["socksProxyPort"] == 8081

def test_concurrent_writes_are_serialized(self) -> None:
    """Test that concurrent write() calls are serialized by the lock.

    When parallel subagents invoke MCP tools, they trigger concurrent write()
    calls. Without the _write_lock, trio raises BusyResourceError.

    Uses a real subprocess with the same stream setup as production:
    process.stdin -> TextSendStream
    """

    async def _test() -> None:
        import sys
        from subprocess import PIPE

        from anyio.streams.text import TextSendStream

        # Create a real subprocess that consumes stdin (cross-platform).
        # sys.executable guarantees a Python interpreter is available; the
        # child simply drains stdin so writes never block on a full pipe.
        process = await anyio.open_process(
            [sys.executable, "-c", "import sys; sys.stdin.read()"],
            stdin=PIPE,
            stdout=PIPE,
            stderr=PIPE,
        )

        try:
            # cli_path is never executed here — the transport is wired up
            # manually below instead of via connect().
            transport = SubprocessCLITransport(
                prompt="test",
                options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
            )

            # Same setup as production: TextSendStream wrapping process.stdin.
            # returncode=None on the mock makes the transport treat the
            # process as still alive, so write() passes its liveness check.
            transport._ready = True
            transport._process = MagicMock(returncode=None)
            transport._stdin_stream = TextSendStream(process.stdin)

            # Spawn concurrent writes - the lock should serialize them
            num_writes = 10
            errors: list[Exception] = []

            async def do_write(i: int) -> None:
                # Collect failures instead of raising so one failed write
                # doesn't cancel the rest of the task group.
                try:
                    await transport.write(f'{{"msg": {i}}}\n')
                except Exception as e:
                    errors.append(e)

            async with anyio.create_task_group() as tg:
                for i in range(num_writes):
                    tg.start_soon(do_write, i)

            # All writes should succeed - the lock serializes them
            assert len(errors) == 0, f"Got errors: {errors}"
        finally:
            # Always reap the child so the test leaves no zombie process.
            process.terminate()
            await process.wait()

    # trio backend specifically: its FdStream raises BusyResourceError on
    # concurrent send, which is the failure mode the lock must prevent.
    anyio.run(_test, backend="trio")
Comment on lines +651 to +706
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4. Concurrency Tests [Test]

Two complementary tests verify the fix:

Positive test (this method): Spawns 10 concurrent write() calls against a real subprocess with TextSendStream—the same stream type used in production. All writes should succeed because the lock serializes them.

Negative test (test_concurrent_writes_fail_without_lock): Replaces the lock with a no-op, proving the race condition exists. Without the lock, trio raises BusyResourceError: another task is using this stream for send.

Both tests use a real process rather than mocks to ensure the concurrency behavior matches production.


🤖 Generated with Claude Code


def test_concurrent_writes_fail_without_lock(self) -> None:
    """Verify that without the lock, concurrent writes cause BusyResourceError.

    Uses a real subprocess with the same stream setup as production.
    Negative counterpart to test_concurrent_writes_are_serialized: it
    proves the race exists by swapping the real lock for a no-op.
    """

    async def _test() -> None:
        import sys
        from contextlib import asynccontextmanager
        from subprocess import PIPE

        from anyio.streams.text import TextSendStream

        # Create a real subprocess that consumes stdin (cross-platform)
        process = await anyio.open_process(
            [sys.executable, "-c", "import sys; sys.stdin.read()"],
            stdin=PIPE,
            stdout=PIPE,
            stderr=PIPE,
        )

        try:
            # cli_path is never executed — transport internals are wired
            # up manually below.
            transport = SubprocessCLITransport(
                prompt="test",
                options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
            )

            # Same setup as production
            transport._ready = True
            transport._process = MagicMock(returncode=None)
            transport._stdin_stream = TextSendStream(process.stdin)

            # Replace lock with no-op to trigger the race condition.
            # Only __aenter__/__aexit__ are exercised by `async with`;
            # __call__ covers the case of a lock used as a factory.
            class NoOpLock:
                @asynccontextmanager
                async def __call__(self):
                    yield

                async def __aenter__(self):
                    return self

                async def __aexit__(self, *args):
                    pass

            transport._write_lock = NoOpLock()

            # Spawn concurrent writes - should fail without lock
            num_writes = 10
            errors: list[Exception] = []

            async def do_write(i: int) -> None:
                # Collect failures rather than raising so the task group
                # isn't cancelled by the first error.
                try:
                    await transport.write(f'{{"msg": {i}}}\n')
                except Exception as e:
                    errors.append(e)

            async with anyio.create_task_group() as tg:
                for i in range(num_writes):
                    tg.start_soon(do_write, i)

            # Should have gotten errors due to concurrent access.
            # NOTE(review): relies on the writes actually overlapping on
            # trio's scheduler — presumed reliable for 10 tasks, but worth
            # confirming this cannot be flaky on fast machines.
            assert len(errors) > 0, (
                "Expected errors from concurrent access, but got none"
            )

            # Check that at least one error mentions the concurrent access
            # (trio's BusyResourceError message contains "another task").
            error_strs = [str(e) for e in errors]
            assert any("another task" in s for s in error_strs), (
                f"Expected 'another task' error, got: {error_strs}"
            )
        finally:
            # Always reap the child so the test leaves no zombie process.
            process.terminate()
            await process.wait()

    # trio backend: only trio raises BusyResourceError on concurrent send.
    anyio.run(_test, backend="trio")
Loading