Skip to content

Commit c0841d8

Browse files
dicksontsairushilpatel0
authored andcommitted
Initial implementation of bidi streaming
1 parent e12fc7c commit c0841d8

File tree

1 file changed

+113
-8
lines changed

1 file changed

+113
-8
lines changed

src/claude_code_sdk/_internal/transport/subprocess_cli.py

Lines changed: 113 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
import logging
55
import os
66
import shutil
7-
from collections.abc import AsyncIterator
7+
from collections.abc import AsyncIterator, AsyncIterable
88
from pathlib import Path
99
from subprocess import PIPE
10-
from typing import Any
10+
from typing import Any, Union
1111

1212
import anyio
1313
from anyio.abc import Process
14-
from anyio.streams.text import TextReceiveStream
14+
from anyio.streams.text import TextReceiveStream, TextSendStream
1515

1616
from ..._errors import CLIConnectionError, CLINotFoundError, ProcessError
1717
from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError
@@ -26,14 +26,24 @@
2626
class SubprocessCLITransport(Transport):
2727
"""Subprocess transport using Claude Code CLI."""
2828

29-
def __init__(self, cli_path: str | Path | None = None):
29+
30+
def __init__(
31+
self,
32+
prompt: Union[str, AsyncIterable[dict[str, Any]]],
33+
options: ClaudeCodeOptions,
34+
cli_path: str | Path | None = None,
35+
):
36+
self._prompt = prompt
37+
self._is_streaming = not isinstance(prompt, str)
38+
self._options = options
3039
self._cli_path = str(cli_path) if cli_path else self._find_cli()
31-
self._prompt: str | None = None
32-
self._options: ClaudeCodeOptions | None = None
3340
self._cwd: str | None = None
3441
self._process: Process | None = None
3542
self._stdout_stream: TextReceiveStream | None = None
3643
self._stderr_stream: TextReceiveStream | None = None
44+
self._stdin_stream: TextSendStream | None = None
45+
self._pending_control_responses: dict[str, Any] = {}
46+
self._request_counter = 0
3747

3848
def configure(self, prompt: str, options: ClaudeCodeOptions) -> None:
3949
"""Configure transport with prompt and options."""
@@ -120,7 +130,14 @@ def _build_command(self) -> list[str]:
120130
["--mcp-config", json.dumps({"mcpServers": self._options.mcp_servers})]
121131
)
122132

123-
cmd.extend(["--print", self._prompt])
133+
# Add prompt handling based on mode
134+
if self._is_streaming:
135+
# Streaming mode: use --input-format stream-json
136+
cmd.extend(["--input-format", "stream-json"])
137+
else:
138+
# String mode: use --print with the prompt
139+
cmd.extend(["--print", self._prompt])
140+
124141
return cmd
125142

126143
async def connect(self) -> None:
@@ -130,9 +147,10 @@ async def connect(self) -> None:
130147

131148
cmd = self._build_command()
132149
try:
150+
# Enable stdin pipe for both modes (but we'll close it for string mode)
133151
self._process = await anyio.open_process(
134152
cmd,
135-
stdin=None,
153+
stdin=PIPE,
136154
stdout=PIPE,
137155
stderr=PIPE,
138156
cwd=self._cwd,
@@ -143,6 +161,18 @@ async def connect(self) -> None:
143161
self._stdout_stream = TextReceiveStream(self._process.stdout)
144162
if self._process.stderr:
145163
self._stderr_stream = TextReceiveStream(self._process.stderr)
164+
165+
# Handle stdin based on mode
166+
if self._is_streaming:
167+
# Streaming mode: keep stdin open and start streaming task
168+
if self._process.stdin:
169+
self._stdin_stream = TextSendStream(self._process.stdin)
170+
# Start streaming messages to stdin
171+
anyio.start_soon(self._stream_to_stdin)
172+
else:
173+
# String mode: close stdin immediately (backward compatible)
174+
if self._process.stdin:
175+
await self._process.stdin.aclose()
146176

147177
except FileNotFoundError as e:
148178
# Check if the error comes from the working directory or the CLI
@@ -173,10 +203,32 @@ async def disconnect(self) -> None:
173203
self._process = None
174204
self._stdout_stream = None
175205
self._stderr_stream = None
206+
self._stdin_stream = None
176207

177208
async def send_request(self, messages: list[Any], options: dict[str, Any]) -> None:
178209
"""Not used for CLI transport - args passed via command line."""
179210

211+
async def _stream_to_stdin(self) -> None:
212+
"""Stream messages to stdin for streaming mode."""
213+
if not self._stdin_stream or not isinstance(self._prompt, AsyncIterable):
214+
return
215+
216+
try:
217+
async for message in self._prompt:
218+
if not self._stdin_stream:
219+
break
220+
await self._stdin_stream.send(json.dumps(message) + "\n")
221+
222+
# Close stdin when done
223+
if self._stdin_stream:
224+
await self._stdin_stream.aclose()
225+
self._stdin_stream = None
226+
except Exception as e:
227+
logger.debug(f"Error streaming to stdin: {e}")
228+
if self._stdin_stream:
229+
await self._stdin_stream.aclose()
230+
self._stdin_stream = None
231+
180232
async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
181233
"""Receive messages from CLI."""
182234
if not self._process or not self._stdout_stream:
@@ -217,6 +269,15 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
217269
try:
218270
data = json.loads(json_buffer)
219271
json_buffer = ""
272+
273+
# Handle control responses separately
274+
if data.get("type") == "control_response":
275+
request_id = data.get("response", {}).get("request_id")
276+
if request_id and request_id in self._pending_control_responses:
277+
# Store the response for the pending request
278+
self._pending_control_responses[request_id] = data.get("response", {})
279+
continue
280+
220281
try:
221282
yield data
222283
except GeneratorExit:
@@ -284,3 +345,47 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
284345
def is_connected(self) -> bool:
285346
"""Check if subprocess is running."""
286347
return self._process is not None and self._process.returncode is None
348+
349+
async def interrupt(self) -> None:
350+
"""Send interrupt control request (only works in streaming mode)."""
351+
if not self._is_streaming:
352+
raise CLIConnectionError("Interrupt requires streaming mode (AsyncIterable prompt)")
353+
354+
if not self._stdin_stream:
355+
raise CLIConnectionError("Not connected or stdin not available")
356+
357+
await self._send_control_request({"subtype": "interrupt"})
358+
359+
async def _send_control_request(self, request: dict[str, Any]) -> dict[str, Any]:
360+
"""Send a control request and wait for response."""
361+
if not self._stdin_stream:
362+
raise CLIConnectionError("Stdin not available")
363+
364+
# Generate unique request ID
365+
self._request_counter += 1
366+
request_id = f"req_{self._request_counter}_{os.urandom(4).hex()}"
367+
368+
# Build control request
369+
control_request = {
370+
"type": "control_request",
371+
"request_id": request_id,
372+
"request": request
373+
}
374+
375+
# Send request
376+
await self._stdin_stream.send(json.dumps(control_request) + "\n")
377+
378+
# Wait for response with timeout
379+
try:
380+
with anyio.fail_after(30.0): # 30 second timeout
381+
while request_id not in self._pending_control_responses:
382+
await anyio.sleep(0.1)
383+
384+
response = self._pending_control_responses.pop(request_id)
385+
386+
if response.get("subtype") == "error":
387+
raise CLIConnectionError(f"Control request failed: {response.get('error')}")
388+
389+
return response
390+
except TimeoutError:
391+
raise CLIConnectionError("Control request timed out") from None

0 commit comments

Comments
 (0)