Serialize writes to the CLI subprocess's stdin.

SubprocessCLITransport gains a `_write_lock` (anyio.Lock) created in __init__.
write() and end_input() now run entirely under that lock, and close() takes the
same lock while it clears `_ready` and closes the stdin stream, so the readiness
checks in write() can no longer interleave with a concurrent close()/end_input().
close() also stops closing `self._process.stdin` separately and only clears
`_ready` early on the "no process" return path.

Two tests (trio backend, real subprocess, TextSendStream over process.stdin)
are added: one checks that ten concurrent write() calls succeed with the lock,
the other swaps in a no-op lock and expects an "another task" error.

NOTE(review): in the second test, NoOpLock.__call__ and its asynccontextmanager
import look unused, because `async with` only needs __aenter__/__aexit__.
They are left as-is here so the patch content stays unchanged.

diff --git a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py
index 26bd2ec4..c7c74203 100644
--- a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py
+++ b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py
@@ -65,6 +65,7 @@ def __init__(
             else _DEFAULT_MAX_BUFFER_SIZE
         )
         self._temp_files: list[str] = []  # Track temporary files for cleanup
+        self._write_lock: anyio.Lock = anyio.Lock()
 
     def _find_cli(self) -> str:
         """Find Claude Code CLI binary."""
@@ -471,8 +472,6 @@ async def _handle_stderr(self) -> None:
 
     async def close(self) -> None:
         """Close the transport and clean up resources."""
-        self._ready = False
-
         # Clean up temporary files first (before early return)
         for temp_file in self._temp_files:
             with suppress(Exception):
@@ -480,6 +479,7 @@ async def close(self) -> None:
         self._temp_files.clear()
 
         if not self._process:
+            self._ready = False
             return
 
         # Close stderr task group if active
@@ -489,21 +489,19 @@ async def close(self) -> None:
                 await self._stderr_task_group.__aexit__(None, None, None)
             self._stderr_task_group = None
 
-        # Close streams
-        if self._stdin_stream:
-            with suppress(Exception):
-                await self._stdin_stream.aclose()
-            self._stdin_stream = None
+        # Close stdin stream (acquire lock to prevent race with concurrent writes)
+        async with self._write_lock:
+            self._ready = False  # Set inside lock to prevent TOCTOU with write()
+            if self._stdin_stream:
+                with suppress(Exception):
+                    await self._stdin_stream.aclose()
+                self._stdin_stream = None
 
         if self._stderr_stream:
             with suppress(Exception):
                 await self._stderr_stream.aclose()
             self._stderr_stream = None
 
-        if self._process.stdin:
-            with suppress(Exception):
-                await self._process.stdin.aclose()
-
         # Terminate and wait for process
         if self._process.returncode is None:
             with suppress(ProcessLookupError):
@@ -521,37 +519,37 @@ async def close(self) -> None:
 
     async def write(self, data: str) -> None:
         """Write raw data to the transport."""
-        # Check if ready (like TypeScript)
-        if not self._ready or not self._stdin_stream:
-            raise CLIConnectionError("ProcessTransport is not ready for writing")
-
-        # Check if process is still alive (like TypeScript)
-        if self._process and self._process.returncode is not None:
-            raise CLIConnectionError(
-                f"Cannot write to terminated process (exit code: {self._process.returncode})"
-            )
+        async with self._write_lock:
+            # All checks inside lock to prevent TOCTOU races with close()/end_input()
+            if not self._ready or not self._stdin_stream:
+                raise CLIConnectionError("ProcessTransport is not ready for writing")
+
+            if self._process and self._process.returncode is not None:
+                raise CLIConnectionError(
+                    f"Cannot write to terminated process (exit code: {self._process.returncode})"
+                )
 
-        # Check for exit errors (like TypeScript)
-        if self._exit_error:
-            raise CLIConnectionError(
-                f"Cannot write to process that exited with error: {self._exit_error}"
-            ) from self._exit_error
+            if self._exit_error:
+                raise CLIConnectionError(
+                    f"Cannot write to process that exited with error: {self._exit_error}"
+                ) from self._exit_error
 
-        try:
-            await self._stdin_stream.send(data)
-        except Exception as e:
-            self._ready = False  # Mark as not ready (like TypeScript)
-            self._exit_error = CLIConnectionError(
-                f"Failed to write to process stdin: {e}"
-            )
-            raise self._exit_error from e
+            try:
+                await self._stdin_stream.send(data)
+            except Exception as e:
+                self._ready = False
+                self._exit_error = CLIConnectionError(
+                    f"Failed to write to process stdin: {e}"
+                )
+                raise self._exit_error from e
 
     async def end_input(self) -> None:
         """End the input stream (close stdin)."""
-        if self._stdin_stream:
-            with suppress(Exception):
-                await self._stdin_stream.aclose()
-            self._stdin_stream = None
+        async with self._write_lock:
+            if self._stdin_stream:
+                with suppress(Exception):
+                    await self._stdin_stream.aclose()
+                self._stdin_stream = None
 
     def read_messages(self) -> AsyncIterator[dict[str, Any]]:
         """Read and parse messages from the transport."""
diff --git a/tests/test_transport.py b/tests/test_transport.py
index c634fc2e..fe9b6b22 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -693,3 +693,136 @@ def test_build_command_without_tools(self):
 
         cmd = transport._build_command()
         assert "--tools" not in cmd
+
+    def test_concurrent_writes_are_serialized(self):
+        """Test that concurrent write() calls are serialized by the lock.
+
+        When parallel subagents invoke MCP tools, they trigger concurrent write()
+        calls. Without the _write_lock, trio raises BusyResourceError.
+
+        Uses a real subprocess with the same stream setup as production:
+        process.stdin -> TextSendStream
+        """
+
+        async def _test():
+            import sys
+            from subprocess import PIPE
+
+            from anyio.streams.text import TextSendStream
+
+            # Create a real subprocess that consumes stdin (cross-platform)
+            process = await anyio.open_process(
+                [sys.executable, "-c", "import sys; sys.stdin.read()"],
+                stdin=PIPE,
+                stdout=PIPE,
+                stderr=PIPE,
+            )
+
+            try:
+                transport = SubprocessCLITransport(
+                    prompt="test",
+                    options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
+                )
+
+                # Same setup as production: TextSendStream wrapping process.stdin
+                transport._ready = True
+                transport._process = MagicMock(returncode=None)
+                transport._stdin_stream = TextSendStream(process.stdin)
+
+                # Spawn concurrent writes - the lock should serialize them
+                num_writes = 10
+                errors: list[Exception] = []
+
+                async def do_write(i: int):
+                    try:
+                        await transport.write(f'{{"msg": {i}}}\n')
+                    except Exception as e:
+                        errors.append(e)
+
+                async with anyio.create_task_group() as tg:
+                    for i in range(num_writes):
+                        tg.start_soon(do_write, i)
+
+                # All writes should succeed - the lock serializes them
+                assert len(errors) == 0, f"Got errors: {errors}"
+            finally:
+                process.terminate()
+                await process.wait()
+
+        anyio.run(_test, backend="trio")
+
+    def test_concurrent_writes_fail_without_lock(self):
+        """Verify that without the lock, concurrent writes cause BusyResourceError.
+
+        Uses a real subprocess with the same stream setup as production.
+        """
+
+        async def _test():
+            import sys
+            from contextlib import asynccontextmanager
+            from subprocess import PIPE
+
+            from anyio.streams.text import TextSendStream
+
+            # Create a real subprocess that consumes stdin (cross-platform)
+            process = await anyio.open_process(
+                [sys.executable, "-c", "import sys; sys.stdin.read()"],
+                stdin=PIPE,
+                stdout=PIPE,
+                stderr=PIPE,
+            )
+
+            try:
+                transport = SubprocessCLITransport(
+                    prompt="test",
+                    options=ClaudeAgentOptions(cli_path="/usr/bin/claude"),
+                )
+
+                # Same setup as production
+                transport._ready = True
+                transport._process = MagicMock(returncode=None)
+                transport._stdin_stream = TextSendStream(process.stdin)
+
+                # Replace lock with no-op to trigger the race condition
+                class NoOpLock:
+                    @asynccontextmanager
+                    async def __call__(self):
+                        yield
+
+                    async def __aenter__(self):
+                        return self
+
+                    async def __aexit__(self, *args):
+                        pass
+
+                transport._write_lock = NoOpLock()
+
+                # Spawn concurrent writes - should fail without lock
+                num_writes = 10
+                errors: list[Exception] = []
+
+                async def do_write(i: int):
+                    try:
+                        await transport.write(f'{{"msg": {i}}}\n')
+                    except Exception as e:
+                        errors.append(e)
+
+                async with anyio.create_task_group() as tg:
+                    for i in range(num_writes):
+                        tg.start_soon(do_write, i)
+
+                # Should have gotten errors due to concurrent access
+                assert len(errors) > 0, (
+                    "Expected errors from concurrent access, but got none"
+                )
+
+                # Check that at least one error mentions the concurrent access
+                error_strs = [str(e) for e in errors]
+                assert any("another task" in s for s in error_strs), (
+                    f"Expected 'another task' error, got: {error_strs}"
+                )
+            finally:
+                process.terminate()
+                await process.wait()
+
+        anyio.run(_test, backend="trio")