44import logging
55import os
66import shutil
7- from collections .abc import AsyncIterator
7+ from collections .abc import AsyncIterator , AsyncIterable
88from pathlib import Path
99from subprocess import PIPE
10- from typing import Any
10+ from typing import Any , Union
1111
1212import anyio
1313from anyio .abc import Process
14- from anyio .streams .text import TextReceiveStream
14+ from anyio .streams .text import TextReceiveStream , TextSendStream
1515
1616from ..._errors import CLIConnectionError , CLINotFoundError , ProcessError
1717from ..._errors import CLIJSONDecodeError as SDKJSONDecodeError
2626class SubprocessCLITransport (Transport ):
2727 """Subprocess transport using Claude Code CLI."""
2828
29- def __init__ (self , cli_path : str | Path | None = None ):
29+
30+ def __init__ (
31+ self ,
32+ prompt : Union [str , AsyncIterable [dict [str , Any ]]],
33+ options : ClaudeCodeOptions ,
34+ cli_path : str | Path | None = None ,
35+ ):
36+ self ._prompt = prompt
37+ self ._is_streaming = not isinstance (prompt , str )
38+ self ._options = options
3039 self ._cli_path = str (cli_path ) if cli_path else self ._find_cli ()
31- self ._prompt : str | None = None
32- self ._options : ClaudeCodeOptions | None = None
3340 self ._cwd : str | None = None
3441 self ._process : Process | None = None
3542 self ._stdout_stream : TextReceiveStream | None = None
3643 self ._stderr_stream : TextReceiveStream | None = None
44+ self ._stdin_stream : TextSendStream | None = None
45+ self ._pending_control_responses : dict [str , Any ] = {}
46+ self ._request_counter = 0
3747
3848 def configure (self , prompt : str , options : ClaudeCodeOptions ) -> None :
3949 """Configure transport with prompt and options."""
@@ -120,7 +130,14 @@ def _build_command(self) -> list[str]:
120130 ["--mcp-config" , json .dumps ({"mcpServers" : self ._options .mcp_servers })]
121131 )
122132
123- cmd .extend (["--print" , self ._prompt ])
133+ # Add prompt handling based on mode
134+ if self ._is_streaming :
135+ # Streaming mode: use --input-format stream-json
136+ cmd .extend (["--input-format" , "stream-json" ])
137+ else :
138+ # String mode: use --print with the prompt
139+ cmd .extend (["--print" , self ._prompt ])
140+
124141 return cmd
125142
126143 async def connect (self ) -> None :
@@ -130,9 +147,10 @@ async def connect(self) -> None:
130147
131148 cmd = self ._build_command ()
132149 try :
150+ # Enable stdin pipe for both modes (but we'll close it for string mode)
133151 self ._process = await anyio .open_process (
134152 cmd ,
135- stdin = None ,
153+ stdin = PIPE ,
136154 stdout = PIPE ,
137155 stderr = PIPE ,
138156 cwd = self ._cwd ,
@@ -143,6 +161,18 @@ async def connect(self) -> None:
143161 self ._stdout_stream = TextReceiveStream (self ._process .stdout )
144162 if self ._process .stderr :
145163 self ._stderr_stream = TextReceiveStream (self ._process .stderr )
164+
165+ # Handle stdin based on mode
166+ if self ._is_streaming :
167+ # Streaming mode: keep stdin open and start streaming task
168+ if self ._process .stdin :
169+ self ._stdin_stream = TextSendStream (self ._process .stdin )
170+ # Start streaming messages to stdin
171+ anyio .start_soon (self ._stream_to_stdin )
172+ else :
173+ # String mode: close stdin immediately (backward compatible)
174+ if self ._process .stdin :
175+ await self ._process .stdin .aclose ()
146176
147177 except FileNotFoundError as e :
148178 # Check if the error comes from the working directory or the CLI
@@ -173,10 +203,32 @@ async def disconnect(self) -> None:
173203 self ._process = None
174204 self ._stdout_stream = None
175205 self ._stderr_stream = None
206+ self ._stdin_stream = None
176207
177208 async def send_request (self , messages : list [Any ], options : dict [str , Any ]) -> None :
178209 """Not used for CLI transport - args passed via command line."""
179210
211+ async def _stream_to_stdin (self ) -> None :
212+ """Stream messages to stdin for streaming mode."""
213+ if not self ._stdin_stream or not isinstance (self ._prompt , AsyncIterable ):
214+ return
215+
216+ try :
217+ async for message in self ._prompt :
218+ if not self ._stdin_stream :
219+ break
220+ await self ._stdin_stream .send (json .dumps (message ) + "\n " )
221+
222+ # Close stdin when done
223+ if self ._stdin_stream :
224+ await self ._stdin_stream .aclose ()
225+ self ._stdin_stream = None
226+ except Exception as e :
227+ logger .debug (f"Error streaming to stdin: { e } " )
228+ if self ._stdin_stream :
229+ await self ._stdin_stream .aclose ()
230+ self ._stdin_stream = None
231+
180232 async def receive_messages (self ) -> AsyncIterator [dict [str , Any ]]:
181233 """Receive messages from CLI."""
182234 if not self ._process or not self ._stdout_stream :
@@ -217,6 +269,15 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
217269 try :
218270 data = json .loads (json_buffer )
219271 json_buffer = ""
272+
273+ # Handle control responses separately
274+ if data .get ("type" ) == "control_response" :
275+ request_id = data .get ("response" , {}).get ("request_id" )
276+ if request_id and request_id in self ._pending_control_responses :
277+ # Store the response for the pending request
278+ self ._pending_control_responses [request_id ] = data .get ("response" , {})
279+ continue
280+
220281 try :
221282 yield data
222283 except GeneratorExit :
@@ -284,3 +345,47 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
284345 def is_connected (self ) -> bool :
285346 """Check if subprocess is running."""
286347 return self ._process is not None and self ._process .returncode is None
348+
349+ async def interrupt (self ) -> None :
350+ """Send interrupt control request (only works in streaming mode)."""
351+ if not self ._is_streaming :
352+ raise CLIConnectionError ("Interrupt requires streaming mode (AsyncIterable prompt)" )
353+
354+ if not self ._stdin_stream :
355+ raise CLIConnectionError ("Not connected or stdin not available" )
356+
357+ await self ._send_control_request ({"subtype" : "interrupt" })
358+
359+ async def _send_control_request (self , request : dict [str , Any ]) -> dict [str , Any ]:
360+ """Send a control request and wait for response."""
361+ if not self ._stdin_stream :
362+ raise CLIConnectionError ("Stdin not available" )
363+
364+ # Generate unique request ID
365+ self ._request_counter += 1
366+ request_id = f"req_{ self ._request_counter } _{ os .urandom (4 ).hex ()} "
367+
368+ # Build control request
369+ control_request = {
370+ "type" : "control_request" ,
371+ "request_id" : request_id ,
372+ "request" : request
373+ }
374+
375+ # Send request
376+ await self ._stdin_stream .send (json .dumps (control_request ) + "\n " )
377+
378+ # Wait for response with timeout
379+ try :
380+ with anyio .fail_after (30.0 ): # 30 second timeout
381+ while request_id not in self ._pending_control_responses :
382+ await anyio .sleep (0.1 )
383+
384+ response = self ._pending_control_responses .pop (request_id )
385+
386+ if response .get ("subtype" ) == "error" :
387+ raise CLIConnectionError (f"Control request failed: { response .get ('error' )} " )
388+
389+ return response
390+ except TimeoutError :
391+ raise CLIConnectionError ("Control request timed out" ) from None
0 commit comments