Skip to content

Commit 2d67166

Browse files
authored
fix: add write lock to prevent concurrent transport writes (#391)
## TL;DR Adds a write lock to `SubprocessCLITransport` to prevent concurrent writes from parallel subagents. --- ## Overview When multiple subagents run in parallel and invoke MCP tools, the CLI sends concurrent `control_request` messages. Each handler tries to write a response back to the subprocess stdin at the same time. Trio's `TextSendStream` isn't thread-safe for concurrent access, so this causes `BusyResourceError`. This PR adds an `anyio.Lock` around all write operations (`write()`, `end_input()`, and the stdin-closing part of `close()`). The lock serializes concurrent writes so they happen one at a time. The `_ready` flag is now set inside the lock during `close()` to prevent a TOCTOU race where `write()` checks `_ready`, then `close()` sets it and closes the stream before `write()` actually sends data. --- ## Call Flow ```mermaid flowchart TD A["write()<br/>subprocess_cli.py:505"] --> B["acquire _write_lock<br/>subprocess_cli.py:507"] B --> C["check _ready & stream<br/>subprocess_cli.py:509"] C --> D["_stdin_stream.send()<br/>subprocess_cli.py:523"] E["close()<br/>subprocess_cli.py:458"] --> F["acquire _write_lock<br/>subprocess_cli.py:478"] F --> G["set _ready = False<br/>subprocess_cli.py:479"] G --> H["close _stdin_stream<br/>subprocess_cli.py:481"] I["end_input()<br/>subprocess_cli.py:531"] --> J["acquire _write_lock<br/>subprocess_cli.py:533"] J --> K["close _stdin_stream<br/>subprocess_cli.py:535"] ```
1 parent 00332f3 commit 2d67166

File tree

2 files changed

+168
-37
lines changed

2 files changed

+168
-37
lines changed

src/claude_agent_sdk/_internal/transport/subprocess_cli.py

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
else _DEFAULT_MAX_BUFFER_SIZE
6666
)
6767
self._temp_files: list[str] = [] # Track temporary files for cleanup
68+
self._write_lock: anyio.Lock = anyio.Lock()
6869

6970
def _find_cli(self) -> str:
7071
"""Find Claude Code CLI binary."""
@@ -471,15 +472,14 @@ async def _handle_stderr(self) -> None:
471472

472473
async def close(self) -> None:
473474
"""Close the transport and clean up resources."""
474-
self._ready = False
475-
476475
# Clean up temporary files first (before early return)
477476
for temp_file in self._temp_files:
478477
with suppress(Exception):
479478
Path(temp_file).unlink(missing_ok=True)
480479
self._temp_files.clear()
481480

482481
if not self._process:
482+
self._ready = False
483483
return
484484

485485
# Close stderr task group if active
@@ -489,21 +489,19 @@ async def close(self) -> None:
489489
await self._stderr_task_group.__aexit__(None, None, None)
490490
self._stderr_task_group = None
491491

492-
# Close streams
493-
if self._stdin_stream:
494-
with suppress(Exception):
495-
await self._stdin_stream.aclose()
496-
self._stdin_stream = None
492+
# Close stdin stream (acquire lock to prevent race with concurrent writes)
493+
async with self._write_lock:
494+
self._ready = False # Set inside lock to prevent TOCTOU with write()
495+
if self._stdin_stream:
496+
with suppress(Exception):
497+
await self._stdin_stream.aclose()
498+
self._stdin_stream = None
497499

498500
if self._stderr_stream:
499501
with suppress(Exception):
500502
await self._stderr_stream.aclose()
501503
self._stderr_stream = None
502504

503-
if self._process.stdin:
504-
with suppress(Exception):
505-
await self._process.stdin.aclose()
506-
507505
# Terminate and wait for process
508506
if self._process.returncode is None:
509507
with suppress(ProcessLookupError):
@@ -521,37 +519,37 @@ async def close(self) -> None:
521519

522520
async def write(self, data: str) -> None:
523521
"""Write raw data to the transport."""
524-
# Check if ready (like TypeScript)
525-
if not self._ready or not self._stdin_stream:
526-
raise CLIConnectionError("ProcessTransport is not ready for writing")
527-
528-
# Check if process is still alive (like TypeScript)
529-
if self._process and self._process.returncode is not None:
530-
raise CLIConnectionError(
531-
f"Cannot write to terminated process (exit code: {self._process.returncode})"
532-
)
522+
async with self._write_lock:
523+
# All checks inside lock to prevent TOCTOU races with close()/end_input()
524+
if not self._ready or not self._stdin_stream:
525+
raise CLIConnectionError("ProcessTransport is not ready for writing")
526+
527+
if self._process and self._process.returncode is not None:
528+
raise CLIConnectionError(
529+
f"Cannot write to terminated process (exit code: {self._process.returncode})"
530+
)
533531

534-
# Check for exit errors (like TypeScript)
535-
if self._exit_error:
536-
raise CLIConnectionError(
537-
f"Cannot write to process that exited with error: {self._exit_error}"
538-
) from self._exit_error
532+
if self._exit_error:
533+
raise CLIConnectionError(
534+
f"Cannot write to process that exited with error: {self._exit_error}"
535+
) from self._exit_error
539536

540-
try:
541-
await self._stdin_stream.send(data)
542-
except Exception as e:
543-
self._ready = False # Mark as not ready (like TypeScript)
544-
self._exit_error = CLIConnectionError(
545-
f"Failed to write to process stdin: {e}"
546-
)
547-
raise self._exit_error from e
537+
try:
538+
await self._stdin_stream.send(data)
539+
except Exception as e:
540+
self._ready = False
541+
self._exit_error = CLIConnectionError(
542+
f"Failed to write to process stdin: {e}"
543+
)
544+
raise self._exit_error from e
548545

549546
async def end_input(self) -> None:
550547
"""End the input stream (close stdin)."""
551-
if self._stdin_stream:
552-
with suppress(Exception):
553-
await self._stdin_stream.aclose()
554-
self._stdin_stream = None
548+
async with self._write_lock:
549+
if self._stdin_stream:
550+
with suppress(Exception):
551+
await self._stdin_stream.aclose()
552+
self._stdin_stream = None
555553

556554
def read_messages(self) -> AsyncIterator[dict[str, Any]]:
557555
"""Read and parse messages from the transport."""

tests/test_transport.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,3 +693,136 @@ def test_build_command_without_tools(self):
693693

694694
cmd = transport._build_command()
695695
assert "--tools" not in cmd
696+
697+
def test_concurrent_writes_are_serialized(self):
698+
"""Test that concurrent write() calls are serialized by the lock.
699+
700+
When parallel subagents invoke MCP tools, they trigger concurrent write()
701+
calls. Without the _write_lock, trio raises BusyResourceError.
702+
703+
Uses a real subprocess with the same stream setup as production:
704+
process.stdin -> TextSendStream
705+
"""
706+
707+
async def _test():
708+
import sys
709+
from subprocess import PIPE
710+
711+
from anyio.streams.text import TextSendStream
712+
713+
# Create a real subprocess that consumes stdin (cross-platform)
714+
process = await anyio.open_process(
715+
[sys.executable, "-c", "import sys; sys.stdin.read()"],
716+
stdin=PIPE,
717+
stdout=PIPE,
718+
stderr=PIPE,
719+
)
720+
721+
try:
722+
transport = SubprocessCLITransport(
723+
prompt="test",
724+
options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
725+
)
726+
727+
# Same setup as production: TextSendStream wrapping process.stdin
728+
transport._ready = True
729+
transport._process = MagicMock(returncode=None)
730+
transport._stdin_stream = TextSendStream(process.stdin)
731+
732+
# Spawn concurrent writes - the lock should serialize them
733+
num_writes = 10
734+
errors: list[Exception] = []
735+
736+
async def do_write(i: int):
737+
try:
738+
await transport.write(f'{{"msg": {i}}}\n')
739+
except Exception as e:
740+
errors.append(e)
741+
742+
async with anyio.create_task_group() as tg:
743+
for i in range(num_writes):
744+
tg.start_soon(do_write, i)
745+
746+
# All writes should succeed - the lock serializes them
747+
assert len(errors) == 0, f"Got errors: {errors}"
748+
finally:
749+
process.terminate()
750+
await process.wait()
751+
752+
anyio.run(_test, backend="trio")
753+
754+
def test_concurrent_writes_fail_without_lock(self):
755+
"""Verify that without the lock, concurrent writes cause BusyResourceError.
756+
757+
Uses a real subprocess with the same stream setup as production.
758+
"""
759+
760+
async def _test():
761+
import sys
762+
from contextlib import asynccontextmanager
763+
from subprocess import PIPE
764+
765+
from anyio.streams.text import TextSendStream
766+
767+
# Create a real subprocess that consumes stdin (cross-platform)
768+
process = await anyio.open_process(
769+
[sys.executable, "-c", "import sys; sys.stdin.read()"],
770+
stdin=PIPE,
771+
stdout=PIPE,
772+
stderr=PIPE,
773+
)
774+
775+
try:
776+
transport = SubprocessCLITransport(
777+
prompt="test",
778+
options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
779+
)
780+
781+
# Same setup as production
782+
transport._ready = True
783+
transport._process = MagicMock(returncode=None)
784+
transport._stdin_stream = TextSendStream(process.stdin)
785+
786+
# Replace lock with no-op to trigger the race condition
787+
class NoOpLock:
788+
@asynccontextmanager
789+
async def __call__(self):
790+
yield
791+
792+
async def __aenter__(self):
793+
return self
794+
795+
async def __aexit__(self, *args):
796+
pass
797+
798+
transport._write_lock = NoOpLock()
799+
800+
# Spawn concurrent writes - should fail without lock
801+
num_writes = 10
802+
errors: list[Exception] = []
803+
804+
async def do_write(i: int):
805+
try:
806+
await transport.write(f'{{"msg": {i}}}\n')
807+
except Exception as e:
808+
errors.append(e)
809+
810+
async with anyio.create_task_group() as tg:
811+
for i in range(num_writes):
812+
tg.start_soon(do_write, i)
813+
814+
# Should have gotten errors due to concurrent access
815+
assert len(errors) > 0, (
816+
"Expected errors from concurrent access, but got none"
817+
)
818+
819+
# Check that at least one error mentions the concurrent access
820+
error_strs = [str(e) for e in errors]
821+
assert any("another task" in s for s in error_strs), (
822+
f"Expected 'another task' error, got: {error_strs}"
823+
)
824+
finally:
825+
process.terminate()
826+
await process.wait()
827+
828+
anyio.run(_test, backend="trio")

0 commit comments

Comments
 (0)