Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 12 additions & 24 deletions src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,6 @@ def __init__(
self._closed = False
self._initialization_result: dict[str, Any] | None = None

# Track first result for proper stream closure with SDK MCP servers
self._first_result_event = anyio.Event()
self._stream_close_timeout = (
float(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000")) / 1000.0
) # Convert ms to seconds

async def initialize(self) -> dict[str, Any] | None:
"""Initialize control protocol if in streaming mode.

Expand Down Expand Up @@ -206,10 +200,6 @@ async def _read_messages(self) -> None:
# TODO: Implement cancellation support
continue

# Track results for proper stream closure
if msg_type == "result":
self._first_result_event.set()

# Regular SDK messages go to the stream
await self._message_send.send(message)

Expand Down Expand Up @@ -570,34 +560,32 @@ async def rewind_files(self, user_message_id: str) -> None:
async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
"""Stream input messages to transport.

If SDK MCP servers or hooks are present, waits for the first result
before closing stdin to allow bidirectional control protocol communication.
If SDK MCP servers or hooks are present, keeps stdin open for the
entire session so the bidirectional control channel stays alive.
"""
try:
async for message in stream:
if self._closed:
break
await self.transport.write(json.dumps(message) + "\n")

# If we have SDK MCP servers or hooks that need bidirectional communication,
# wait for first result before closing the channel
# If we have SDK MCP servers or hooks that need bidirectional
# communication, keep stdin open for the entire session. Closing
# after the first result breaks hook callbacks on subsequent
# agent turns. The task group cancellation in close() will break
# us out, and transport.close() handles stdin cleanup.
has_hooks = bool(self.hooks)
if self.sdk_mcp_servers or has_hooks:
logger.debug(
f"Waiting for first result before closing stdin "
f"Keeping stdin open for bidirectional communication "
f"(sdk_mcp_servers={len(self.sdk_mcp_servers)}, has_hooks={has_hooks})"
)
try:
with anyio.move_on_after(self._stream_close_timeout):
await self._first_result_event.wait()
logger.debug("Received first result, closing input stream")
except Exception:
logger.debug(
"Timed out waiting for first result, closing input stream"
)
await anyio.sleep_forever()

# After all messages sent (and result received if needed), end input
# No hooks or MCP servers — safe to close stdin immediately.
await self.transport.end_input()
except anyio.get_cancelled_exc_class():
logger.debug("Stream input cancelled")
except Exception as e:
logger.debug(f"Error streaming input: {e}")

Expand Down
50 changes: 50 additions & 0 deletions tests/test_streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,3 +833,53 @@ async def mock_receive():
assert isinstance(messages[-1], ResultMessage)

anyio.run(_test)


class TestStreamInputStdinLifecycle:
"""Test that stream_input keeps stdin open when hooks need it."""

@pytest.mark.asyncio
async def test_stream_input_calls_end_input_without_hooks(self):
"""Without hooks, stream_input closes stdin after the generator finishes."""
transport = AsyncMock()
transport.write = AsyncMock()
transport.end_input = AsyncMock()

from claude_agent_sdk._internal.query import Query

q = Query(
transport=transport, is_streaming_mode=True, can_use_tool=None, hooks={}
)

async def single_message():
yield {"type": "user", "message": {"role": "user", "content": "hello"}}

await q.stream_input(single_message())
transport.end_input.assert_called_once()

@pytest.mark.asyncio
async def test_stream_input_keeps_stdin_open_with_hooks(self):
"""With hooks, stream_input blocks instead of closing stdin."""
transport = AsyncMock()
transport.write = AsyncMock()
transport.end_input = AsyncMock()

from claude_agent_sdk._internal.query import Query

hooks = {
"tool_use_start": [{"matcher": {"tool": "Bash"}, "hooks": [AsyncMock()]}]
}
q = Query(
transport=transport, is_streaming_mode=True, can_use_tool=None, hooks=hooks
)

async def single_message():
yield {"type": "user", "message": {"role": "user", "content": "hello"}}

# stream_input should block indefinitely with hooks. Cancel it after a
# short delay and verify end_input was never called.
with anyio.CancelScope() as scope:
scope.cancel()
await q.stream_input(single_message())

transport.end_input.assert_not_called()