-
Notifications
You must be signed in to change notification settings - Fork 539
fix: add write lock to prevent concurrent transport writes #391
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,7 @@ def __init__( | |
| else _DEFAULT_MAX_BUFFER_SIZE | ||
| ) | ||
| self._temp_files: list[str] = [] # Track temporary files for cleanup | ||
| self._write_lock: anyio.Lock = anyio.Lock() | ||
|
|
||
| def _find_cli(self) -> str: | ||
| """Find Claude Code CLI binary.""" | ||
|
|
@@ -471,15 +472,14 @@ async def _handle_stderr(self) -> None: | |
|
|
||
| async def close(self) -> None: | ||
| """Close the transport and clean up resources.""" | ||
| self._ready = False | ||
|
|
||
| # Clean up temporary files first (before early return) | ||
| for temp_file in self._temp_files: | ||
| with suppress(Exception): | ||
| Path(temp_file).unlink(missing_ok=True) | ||
| self._temp_files.clear() | ||
|
|
||
| if not self._process: | ||
| self._ready = False | ||
| return | ||
|
|
||
| # Close stderr task group if active | ||
|
|
@@ -489,21 +489,19 @@ async def close(self) -> None: | |
| await self._stderr_task_group.__aexit__(None, None, None) | ||
| self._stderr_task_group = None | ||
|
|
||
| # Close streams | ||
| if self._stdin_stream: | ||
| with suppress(Exception): | ||
| await self._stdin_stream.aclose() | ||
| self._stdin_stream = None | ||
| # Close stdin stream (acquire lock to prevent race with concurrent writes) | ||
| async with self._write_lock: | ||
| self._ready = False # Set inside lock to prevent TOCTOU with write() | ||
| if self._stdin_stream: | ||
| with suppress(Exception): | ||
| await self._stdin_stream.aclose() | ||
| self._stdin_stream = None | ||
|
|
||
| if self._stderr_stream: | ||
| with suppress(Exception): | ||
| await self._stderr_stream.aclose() | ||
| self._stderr_stream = None | ||
|
|
||
| if self._process.stdin: | ||
| with suppress(Exception): | ||
| await self._process.stdin.aclose() | ||
|
|
||
| # Terminate and wait for process | ||
| if self._process.returncode is None: | ||
| with suppress(ProcessLookupError): | ||
|
|
@@ -521,37 +519,37 @@ async def close(self) -> None: | |
|
|
||
| async def write(self, data: str) -> None: | ||
| """Write raw data to the transport.""" | ||
| # Check if ready (like TypeScript) | ||
| if not self._ready or not self._stdin_stream: | ||
| raise CLIConnectionError("ProcessTransport is not ready for writing") | ||
|
|
||
| # Check if process is still alive (like TypeScript) | ||
| if self._process and self._process.returncode is not None: | ||
| raise CLIConnectionError( | ||
| f"Cannot write to terminated process (exit code: {self._process.returncode})" | ||
| ) | ||
| async with self._write_lock: | ||
| # All checks inside lock to prevent TOCTOU races with close()/end_input() | ||
| if not self._ready or not self._stdin_stream: | ||
| raise CLIConnectionError("ProcessTransport is not ready for writing") | ||
|
|
||
| if self._process and self._process.returncode is not None: | ||
| raise CLIConnectionError( | ||
| f"Cannot write to terminated process (exit code: {self._process.returncode})" | ||
| ) | ||
|
|
||
| # Check for exit errors (like TypeScript) | ||
| if self._exit_error: | ||
| raise CLIConnectionError( | ||
| f"Cannot write to process that exited with error: {self._exit_error}" | ||
| ) from self._exit_error | ||
| if self._exit_error: | ||
| raise CLIConnectionError( | ||
| f"Cannot write to process that exited with error: {self._exit_error}" | ||
| ) from self._exit_error | ||
|
|
||
| try: | ||
| await self._stdin_stream.send(data) | ||
| except Exception as e: | ||
| self._ready = False # Mark as not ready (like TypeScript) | ||
| self._exit_error = CLIConnectionError( | ||
| f"Failed to write to process stdin: {e}" | ||
| ) | ||
| raise self._exit_error from e | ||
| try: | ||
| await self._stdin_stream.send(data) | ||
| except Exception as e: | ||
| self._ready = False | ||
| self._exit_error = CLIConnectionError( | ||
| f"Failed to write to process stdin: {e}" | ||
| ) | ||
| raise self._exit_error from e | ||
|
Comment on lines
520
to
+544
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 2. Write Method with LockThe entire
Related: See comment 1 for lock initialization. 🤖 Generated with Claude Code |
||
|
|
||
| async def end_input(self) -> None: | ||
| """End the input stream (close stdin).""" | ||
| if self._stdin_stream: | ||
| with suppress(Exception): | ||
| await self._stdin_stream.aclose() | ||
| self._stdin_stream = None | ||
| async with self._write_lock: | ||
| if self._stdin_stream: | ||
| with suppress(Exception): | ||
| await self._stdin_stream.aclose() | ||
| self._stdin_stream = None | ||
|
|
||
| def read_messages(self) -> AsyncIterator[dict[str, Any]]: | ||
| """Read and parse messages from the transport.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -693,3 +693,136 @@ def test_build_command_without_tools(self): | |
|
|
||
| cmd = transport._build_command() | ||
| assert "--tools" not in cmd | ||
|
|
||
| def test_concurrent_writes_are_serialized(self): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 3. Concurrent Write TestsTwo tests verify the lock works correctly:
Both tests use a real subprocess with 🤖 Generated with Claude Code |
||
| """Test that concurrent write() calls are serialized by the lock. | ||
|
|
||
| When parallel subagents invoke MCP tools, they trigger concurrent write() | ||
| calls. Without the _write_lock, trio raises BusyResourceError. | ||
|
|
||
| Uses a real subprocess with the same stream setup as production: | ||
| process.stdin -> TextSendStream | ||
| """ | ||
|
|
||
| async def _test(): | ||
| import sys | ||
| from subprocess import PIPE | ||
|
|
||
| from anyio.streams.text import TextSendStream | ||
|
|
||
| # Create a real subprocess that consumes stdin (cross-platform) | ||
| process = await anyio.open_process( | ||
| [sys.executable, "-c", "import sys; sys.stdin.read()"], | ||
| stdin=PIPE, | ||
| stdout=PIPE, | ||
| stderr=PIPE, | ||
| ) | ||
|
|
||
| try: | ||
| transport = SubprocessCLITransport( | ||
| prompt="test", | ||
| options=ClaudeAgentOptions(cli_path="/usr/bin/claude"), | ||
| ) | ||
|
|
||
| # Same setup as production: TextSendStream wrapping process.stdin | ||
| transport._ready = True | ||
| transport._process = MagicMock(returncode=None) | ||
| transport._stdin_stream = TextSendStream(process.stdin) | ||
|
|
||
| # Spawn concurrent writes - the lock should serialize them | ||
| num_writes = 10 | ||
| errors: list[Exception] = [] | ||
|
|
||
| async def do_write(i: int): | ||
| try: | ||
| await transport.write(f'{{"msg": {i}}}\n') | ||
| except Exception as e: | ||
| errors.append(e) | ||
|
|
||
| async with anyio.create_task_group() as tg: | ||
| for i in range(num_writes): | ||
| tg.start_soon(do_write, i) | ||
|
|
||
| # All writes should succeed - the lock serializes them | ||
| assert len(errors) == 0, f"Got errors: {errors}" | ||
| finally: | ||
| process.terminate() | ||
| await process.wait() | ||
|
|
||
| anyio.run(_test, backend="trio") | ||
|
|
||
| def test_concurrent_writes_fail_without_lock(self): | ||
| """Verify that without the lock, concurrent writes cause BusyResourceError. | ||
|
|
||
| Uses a real subprocess with the same stream setup as production. | ||
| """ | ||
|
|
||
| async def _test(): | ||
| import sys | ||
| from contextlib import asynccontextmanager | ||
| from subprocess import PIPE | ||
|
|
||
| from anyio.streams.text import TextSendStream | ||
|
|
||
| # Create a real subprocess that consumes stdin (cross-platform) | ||
| process = await anyio.open_process( | ||
| [sys.executable, "-c", "import sys; sys.stdin.read()"], | ||
| stdin=PIPE, | ||
| stdout=PIPE, | ||
| stderr=PIPE, | ||
| ) | ||
|
|
||
| try: | ||
| transport = SubprocessCLITransport( | ||
| prompt="test", | ||
| options=ClaudeAgentOptions(cli_path="/usr/bin/claude"), | ||
| ) | ||
|
|
||
| # Same setup as production | ||
| transport._ready = True | ||
| transport._process = MagicMock(returncode=None) | ||
| transport._stdin_stream = TextSendStream(process.stdin) | ||
|
|
||
| # Replace lock with no-op to trigger the race condition | ||
| class NoOpLock: | ||
| @asynccontextmanager | ||
| async def __call__(self): | ||
| yield | ||
|
|
||
| async def __aenter__(self): | ||
| return self | ||
|
|
||
| async def __aexit__(self, *args): | ||
| pass | ||
|
|
||
| transport._write_lock = NoOpLock() | ||
|
|
||
| # Spawn concurrent writes - should fail without lock | ||
| num_writes = 10 | ||
| errors: list[Exception] = [] | ||
|
|
||
| async def do_write(i: int): | ||
| try: | ||
| await transport.write(f'{{"msg": {i}}}\n') | ||
| except Exception as e: | ||
| errors.append(e) | ||
|
|
||
| async with anyio.create_task_group() as tg: | ||
| for i in range(num_writes): | ||
| tg.start_soon(do_write, i) | ||
|
|
||
| # Should have gotten errors due to concurrent access | ||
| assert len(errors) > 0, ( | ||
| "Expected errors from concurrent access, but got none" | ||
| ) | ||
|
|
||
| # Check that at least one error mentions the concurrent access | ||
| error_strs = [str(e) for e in errors] | ||
| assert any("another task" in s for s in error_strs), ( | ||
| f"Expected 'another task' error, got: {error_strs}" | ||
| ) | ||
| finally: | ||
| process.terminate() | ||
| await process.wait() | ||
|
|
||
| anyio.run(_test, backend="trio") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1. Write Lock Initialization
The lock serializes concurrent writes to stdin. When parallel subagents invoke MCP tools, multiple handlers try to write responses at the same time. Trio's
TextSendStream.send()isn't safe for concurrent use and raisesBusyResourceError.🤖 Generated with Claude Code