diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 566e3161..e48995f9 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -107,6 +107,12 @@ def __init__( self._closed = False self._initialization_result: dict[str, Any] | None = None + # Track first result for proper stream closure with SDK MCP servers + self._first_result_event = anyio.Event() + self._stream_close_timeout = ( + float(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000")) / 1000.0 + ) # Convert ms to seconds + async def initialize(self) -> dict[str, Any] | None: """Initialize control protocol if in streaming mode. @@ -195,6 +201,10 @@ async def _read_messages(self) -> None: # TODO: Implement cancellation support continue + # Track results for proper stream closure + if msg_type == "result": + self._first_result_event.set() + # Regular SDK messages go to the stream await self._message_send.send(message) @@ -525,13 +535,35 @@ async def set_model(self, model: str | None) -> None: ) async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None: - """Stream input messages to transport.""" + """Stream input messages to transport. + + If SDK MCP servers or hooks are present, waits for the first result + before closing stdin to allow bidirectional control protocol communication. + """ try: async for message in stream: if self._closed: break await self.transport.write(json.dumps(message) + "\n") - # After all messages sent, end input + + # If we have SDK MCP servers or hooks that need bidirectional communication, + # wait for first result before closing the channel + has_hooks = bool(self.hooks) + if self.sdk_mcp_servers or has_hooks: + logger.debug( + f"Waiting for first result before closing stdin " + f"(sdk_mcp_servers={len(self.sdk_mcp_servers)}, has_hooks={has_hooks})" + ) + try: + with anyio.move_on_after(self._stream_close_timeout): + await self._first_result_event.wait() + logger.debug("Received first result, closing input stream") + except Exception: + logger.debug( + "Timed out waiting for first result, closing input stream" + ) + + # After all messages sent (and result received if needed), end input await self.transport.end_input() except Exception as e: logger.debug(f"Error streaming input: {e}")