Skip to content

Commit e7b9e8c

Browse files
dicksontsairushilpatel0
authored andcommitted
Close stdin for query()
1 parent f8cab19 commit e7b9e8c

File tree

3 files changed

+55
-14
lines changed

3 files changed

+55
-14
lines changed

src/claude_code_sdk/_internal/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@ async def process_query(
3939
if transport is not None:
4040
chosen_transport = transport
4141
else:
42-
chosen_transport = SubprocessCLITransport(prompt, options)
42+
chosen_transport = SubprocessCLITransport(
43+
prompt=prompt,
44+
options=options,
45+
close_stdin_after_prompt=True
46+
)
4347

4448
try:
4549
# Configure the transport with prompt and options

src/claude_code_sdk/_internal/transport/subprocess_cli.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
prompt: str | AsyncIterable[dict[str, Any]],
3333
options: ClaudeCodeOptions,
3434
cli_path: str | Path | None = None,
35+
close_stdin_after_prompt: bool = False,
3536
):
3637
self._prompt = prompt
3738
self._is_streaming = not isinstance(prompt, str)
@@ -44,6 +45,7 @@ def __init__(
4445
self._stdin_stream: TextSendStream | None = None
4546
self._pending_control_responses: dict[str, Any] = {}
4647
self._request_counter = 0
48+
self._close_stdin_after_prompt = close_stdin_after_prompt
4749

4850
def configure(self, prompt: str, options: ClaudeCodeOptions) -> None:
4951
"""Configure transport with prompt and options."""
@@ -238,8 +240,11 @@ async def _stream_to_stdin(self) -> None:
238240
break
239241
await self._stdin_stream.send(json.dumps(message) + "\n")
240242

241-
# Don't close stdin - keep it open for send_request
242-
# Users can explicitly call disconnect() when done
243+
# Close stdin after prompt if requested (e.g., for query() one-shot mode)
244+
if self._close_stdin_after_prompt and self._stdin_stream:
245+
await self._stdin_stream.aclose()
246+
self._stdin_stream = None
247+
# Otherwise keep stdin open for send_request (ClaudeSDKClient interactive mode)
243248
except Exception as e:
244249
logger.debug(f"Error streaming to stdin: {e}")
245250
if self._stdin_stream:

tests/test_streaming_client.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class TestClaudeSDKClientStreaming:
2424

2525
def test_auto_connect_with_context_manager(self):
2626
"""Test automatic connection when using context manager."""
27+
2728
async def _test():
2829
with patch(
2930
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -43,6 +44,7 @@ async def _test():
4344

4445
def test_manual_connect_disconnect(self):
4546
"""Test manual connect and disconnect."""
47+
4648
async def _test():
4749
with patch(
4850
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -66,6 +68,7 @@ async def _test():
6668

6769
def test_connect_with_string_prompt(self):
6870
"""Test connecting with a string prompt."""
71+
6972
async def _test():
7073
with patch(
7174
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -84,6 +87,7 @@ async def _test():
8487

8588
def test_connect_with_async_iterable(self):
8689
"""Test connecting with an async iterable."""
90+
8791
async def _test():
8892
with patch(
8993
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -93,7 +97,10 @@ async def _test():
9397

9498
async def message_stream():
9599
yield {"type": "user", "message": {"role": "user", "content": "Hi"}}
96-
yield {"type": "user", "message": {"role": "user", "content": "Bye"}}
100+
yield {
101+
"type": "user",
102+
"message": {"role": "user", "content": "Bye"},
103+
}
97104

98105
client = ClaudeSDKClient()
99106
stream = message_stream()
@@ -108,6 +115,7 @@ async def message_stream():
108115

109116
def test_send_message(self):
110117
"""Test sending a message."""
118+
111119
async def _test():
112120
with patch(
113121
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -131,6 +139,7 @@ async def _test():
131139

132140
def test_send_message_with_session_id(self):
133141
"""Test sending a message with custom session ID."""
142+
134143
async def _test():
135144
with patch(
136145
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -150,6 +159,7 @@ async def _test():
150159

151160
def test_send_message_not_connected(self):
152161
"""Test sending message when not connected raises error."""
162+
153163
async def _test():
154164
client = ClaudeSDKClient()
155165
with pytest.raises(CLIConnectionError, match="Not connected"):
@@ -159,6 +169,7 @@ async def _test():
159169

160170
def test_receive_messages(self):
161171
"""Test receiving messages."""
172+
162173
async def _test():
163174
with patch(
164175
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -200,6 +211,7 @@ async def mock_receive():
200211

201212
def test_receive_response(self):
202213
"""Test receive_response stops at ResultMessage."""
214+
203215
async def _test():
204216
with patch(
205217
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -231,7 +243,9 @@ async def mock_receive():
231243
"type": "assistant",
232244
"message": {
233245
"role": "assistant",
234-
"content": [{"type": "text", "text": "Should not see this"}],
246+
"content": [
247+
{"type": "text", "text": "Should not see this"}
248+
],
235249
},
236250
}
237251

@@ -251,6 +265,7 @@ async def mock_receive():
251265

252266
def test_interrupt(self):
253267
"""Test interrupt functionality."""
268+
254269
async def _test():
255270
with patch(
256271
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -266,6 +281,7 @@ async def _test():
266281

267282
def test_interrupt_not_connected(self):
268283
"""Test interrupt when not connected raises error."""
284+
269285
async def _test():
270286
client = ClaudeSDKClient()
271287
with pytest.raises(CLIConnectionError, match="Not connected"):
@@ -275,6 +291,7 @@ async def _test():
275291

276292
def test_client_with_options(self):
277293
"""Test client initialization with options."""
294+
278295
async def _test():
279296
options = ClaudeCodeOptions(
280297
cwd="/custom/path",
@@ -299,6 +316,7 @@ async def _test():
299316

300317
def test_concurrent_send_receive(self):
301318
"""Test concurrent sending and receiving messages."""
319+
302320
async def _test():
303321
with patch(
304322
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -334,7 +352,7 @@ async def mock_receive():
334352
# Helper to get next message
335353
async def get_next_message():
336354
return await client.receive_response().__anext__()
337-
355+
338356
# Start receiving in background
339357
receive_task = asyncio.create_task(get_next_message())
340358

@@ -353,13 +371,14 @@ class TestQueryWithAsyncIterable:
353371

354372
def test_query_with_async_iterable(self):
355373
"""Test query with async iterable of messages."""
374+
356375
async def _test():
357376
async def message_stream():
358377
yield {"type": "user", "message": {"role": "user", "content": "First"}}
359378
yield {"type": "user", "message": {"role": "user", "content": "Second"}}
360379

361380
with patch(
362-
"claude_code_sdk.query.InternalClient"
381+
"claude_code_sdk._internal.client.InternalClient"
363382
) as mock_client_class:
364383
mock_client = MagicMock()
365384
mock_client_class.return_value = mock_client
@@ -399,6 +418,7 @@ async def mock_process():
399418

400419
def test_query_async_iterable_with_options(self):
401420
"""Test query with async iterable and custom options."""
421+
402422
async def _test():
403423
async def complex_stream():
404424
yield {
@@ -421,7 +441,7 @@ async def complex_stream():
421441
)
422442

423443
with patch(
424-
"claude_code_sdk.query.InternalClient"
444+
"claude_code_sdk._internal.client.InternalClient"
425445
) as mock_client_class:
426446
mock_client = MagicMock()
427447
mock_client_class.return_value = mock_client
@@ -445,23 +465,23 @@ async def mock_process():
445465

446466
def test_query_empty_async_iterable(self):
447467
"""Test query with empty async iterable."""
468+
448469
async def _test():
449470
async def empty_stream():
450471
# Never yields anything
451472
if False:
452473
yield
453474

454475
with patch(
455-
"claude_code_sdk.query.InternalClient"
476+
"claude_code_sdk._internal.client.InternalClient"
456477
) as mock_client_class:
457478
mock_client = MagicMock()
458479
mock_client_class.return_value = mock_client
459480

460481
# Mock response
461482
async def mock_process():
462483
yield SystemMessage(
463-
subtype="info",
464-
data={"message": "No input provided"}
484+
subtype="info", data={"message": "No input provided"}
465485
)
466486

467487
mock_client.process_query.return_value = mock_process()
@@ -478,6 +498,7 @@ async def mock_process():
478498

479499
def test_query_async_iterable_with_delay(self):
480500
"""Test query with async iterable that has delays between yields."""
501+
481502
async def _test():
482503
async def delayed_stream():
483504
yield {"type": "user", "message": {"role": "user", "content": "Start"}}
@@ -487,7 +508,7 @@ async def delayed_stream():
487508
yield {"type": "user", "message": {"role": "user", "content": "End"}}
488509

489510
with patch(
490-
"claude_code_sdk.query.InternalClient"
511+
"claude_code_sdk._internal.client.InternalClient"
491512
) as mock_client_class:
492513
mock_client = MagicMock()
493514
mock_client_class.return_value = mock_client
@@ -512,6 +533,7 @@ async def mock_process_query(prompt, options):
512533

513534
# Time the execution
514535
import time
536+
515537
start_time = time.time()
516538
messages = []
517539
async for msg in query(prompt=delayed_stream()):
@@ -527,13 +549,14 @@ async def mock_process_query(prompt, options):
527549

528550
def test_query_async_iterable_exception_handling(self):
529551
"""Test query handles exceptions in async iterable."""
552+
530553
async def _test():
531554
async def failing_stream():
532555
yield {"type": "user", "message": {"role": "user", "content": "First"}}
533556
raise ValueError("Stream error")
534557

535558
with patch(
536-
"claude_code_sdk.query.InternalClient"
559+
"claude_code_sdk._internal.client.InternalClient"
537560
) as mock_client_class:
538561
mock_client = MagicMock()
539562
mock_client_class.return_value = mock_client
@@ -561,6 +584,7 @@ class TestClaudeSDKClientEdgeCases:
561584

562585
def test_receive_messages_not_connected(self):
563586
"""Test receiving messages when not connected."""
587+
564588
async def _test():
565589
client = ClaudeSDKClient()
566590
with pytest.raises(CLIConnectionError, match="Not connected"):
@@ -571,6 +595,7 @@ async def _test():
571595

572596
def test_receive_response_not_connected(self):
573597
"""Test receive_response when not connected."""
598+
574599
async def _test():
575600
client = ClaudeSDKClient()
576601
with pytest.raises(CLIConnectionError, match="Not connected"):
@@ -581,6 +606,7 @@ async def _test():
581606

582607
def test_double_connect(self):
583608
"""Test connecting twice."""
609+
584610
async def _test():
585611
with patch(
586612
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -600,6 +626,7 @@ async def _test():
600626

601627
def test_disconnect_without_connect(self):
602628
"""Test disconnecting without connecting first."""
629+
603630
async def _test():
604631
client = ClaudeSDKClient()
605632
# Should not raise error
@@ -609,6 +636,7 @@ async def _test():
609636

610637
def test_context_manager_with_exception(self):
611638
"""Test context manager cleans up on exception."""
639+
612640
async def _test():
613641
with patch(
614642
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -627,6 +655,7 @@ async def _test():
627655

628656
def test_receive_response_list_comprehension(self):
629657
"""Test collecting messages with list comprehension as shown in examples."""
658+
630659
async def _test():
631660
with patch(
632661
"claude_code_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
@@ -668,7 +697,10 @@ async def mock_receive():
668697
messages = [msg async for msg in client.receive_response()]
669698

670699
assert len(messages) == 3
671-
assert all(isinstance(msg, AssistantMessage | ResultMessage) for msg in messages)
700+
assert all(
701+
isinstance(msg, AssistantMessage | ResultMessage)
702+
for msg in messages
703+
)
672704
assert isinstance(messages[-1], ResultMessage)
673705

674706
anyio.run(_test)

0 commit comments

Comments
 (0)