Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ def __init__(
self._closed = False
self._initialization_result: dict[str, Any] | None = None

# Track first result for proper stream closure with SDK MCP servers
self._first_result_event = anyio.Event()
self._stream_close_timeout = (
float(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000")) / 1000.0
) # Convert ms to seconds

async def initialize(self) -> dict[str, Any] | None:
"""Initialize control protocol if in streaming mode.

Expand Down Expand Up @@ -195,6 +201,10 @@ async def _read_messages(self) -> None:
# TODO: Implement cancellation support
continue

# Track results for proper stream closure
if msg_type == "result":
self._first_result_event.set()

# Regular SDK messages go to the stream
await self._message_send.send(message)

Expand Down Expand Up @@ -525,13 +535,35 @@ async def set_model(self, model: str | None) -> None:
)

async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
"""Stream input messages to transport."""
"""Stream input messages to transport.

If SDK MCP servers or hooks are present, waits for the first result
before closing stdin to allow bidirectional control protocol communication.
"""
try:
async for message in stream:
if self._closed:
break
await self.transport.write(json.dumps(message) + "\n")
# After all messages sent, end input

# If we have SDK MCP servers or hooks that need bidirectional communication,
# wait for first result before closing the channel
has_hooks = bool(self.hooks)
if self.sdk_mcp_servers or has_hooks:
logger.debug(
f"Waiting for first result before closing stdin "
f"(sdk_mcp_servers={len(self.sdk_mcp_servers)}, has_hooks={has_hooks})"
)
try:
with anyio.move_on_after(self._stream_close_timeout):
await self._first_result_event.wait()
logger.debug("Received first result, closing input stream")
except Exception:
logger.debug(
"Timed out waiting for first result, closing input stream"
)

# After all messages sent (and result received if needed), end input
await self.transport.end_input()
except Exception as e:
logger.debug(f"Error streaming input: {e}")
Expand Down
Loading