Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 37 additions & 36 deletions src/claude_code_sdk/_internal/transport/subprocess_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
import os
import shutil
import tempfile
from collections import deque
from collections.abc import AsyncIterable, AsyncIterator
from pathlib import Path
from subprocess import PIPE
Expand Down Expand Up @@ -46,6 +48,7 @@ def __init__(
self._request_counter = 0
self._close_stdin_after_prompt = close_stdin_after_prompt
self._task_group: anyio.abc.TaskGroup | None = None
self._stderr_file: Any = None # tempfile.NamedTemporaryFile

def _find_cli(self) -> str:
"""Find Claude Code CLI binary."""
Expand Down Expand Up @@ -143,20 +146,24 @@ async def connect(self) -> None:

cmd = self._build_command()
try:
# Create a temp file for stderr to avoid pipe buffer deadlock
# We can't use context manager as we need it for the subprocess lifetime
self._stderr_file = tempfile.NamedTemporaryFile( # noqa: SIM115
mode="w+", prefix="claude_stderr_", suffix=".log", delete=False
)

# Enable stdin pipe for both modes (but we'll close it for string mode)
self._process = await anyio.open_process(
cmd,
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
stderr=self._stderr_file,
cwd=self._cwd,
env={**os.environ, "CLAUDE_CODE_ENTRYPOINT": "sdk-py"},
)

if self._process.stdout:
self._stdout_stream = TextReceiveStream(self._process.stdout)
if self._process.stderr:
self._stderr_stream = TextReceiveStream(self._process.stderr)

# Handle stdin based on mode
if self._is_streaming:
Expand Down Expand Up @@ -204,6 +211,15 @@ async def disconnect(self) -> None:
except ProcessLookupError:
pass

# Clean up temp file
if self._stderr_file:
try:
self._stderr_file.close()
Path(self._stderr_file.name).unlink()
except Exception:
pass
self._stderr_file = None

self._process = None
self._stdout_stream = None
self._stderr_stream = None
Expand Down Expand Up @@ -257,10 +273,6 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
if not self._process or not self._stdout_stream:
raise CLIConnectionError("Not connected")

# Safety constants
max_stderr_size = 10 * 1024 * 1024 # 10MB
stderr_timeout = 30.0 # 30 seconds

json_buffer = ""

# Process stdout messages first
Expand Down Expand Up @@ -318,36 +330,19 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
# Client disconnected - still need to clean up
pass

# Process stderr with safety limits
stderr_lines = []
stderr_size = 0

if self._stderr_stream:
# Read stderr from temp file (keep only last N lines for memory efficiency)
stderr_lines: deque[str] = deque(maxlen=100) # Keep last 100 lines
if self._stderr_file:
try:
# Use timeout to prevent hanging
with anyio.fail_after(stderr_timeout):
async for line in self._stderr_stream:
line_text = line.strip()
line_size = len(line_text)

# Enforce memory limit
if stderr_size + line_size > max_stderr_size:
stderr_lines.append(
f"[stderr truncated after {stderr_size} bytes]"
)
# Drain rest of stream without storing
async for _ in self._stderr_stream:
pass
break

# Flush any pending writes
self._stderr_file.flush()
# Read from the beginning
self._stderr_file.seek(0)
for line in self._stderr_file:
line_text = line.strip()
if line_text:
stderr_lines.append(line_text)
stderr_size += line_size

except TimeoutError:
stderr_lines.append(
f"[stderr collection timed out after {stderr_timeout}s]"
)
except anyio.ClosedResourceError:
except Exception:
pass

# Check process completion and handle errors
Expand All @@ -356,7 +351,13 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
except Exception:
returncode = -1

stderr_output = "\n".join(stderr_lines) if stderr_lines else ""
# Convert deque to string for error reporting
stderr_output = "\n".join(list(stderr_lines)) if stderr_lines else ""
if len(stderr_lines) == stderr_lines.maxlen:
stderr_output = (
f"[stderr truncated, showing last {stderr_lines.maxlen} lines]\n"
+ stderr_output
)

# Use exit code for error detection, not string matching
if returncode is not None and returncode != 0:
Expand Down