Skip to content

Commit 4175e4a

Browse files
dicksontsairushilpatel0
authored andcommitted
Fix test
Signed-off-by: Rushil Patel <[email protected]>
1 parent e7b9e8c commit 4175e4a

File tree

1 file changed

+102
-197
lines changed

1 file changed

+102
-197
lines changed

tests/test_streaming_client.py

Lines changed: 102 additions & 197 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""Tests for ClaudeSDKClient streaming functionality and query() with async iterables."""
22

33
import asyncio
4-
from unittest.mock import AsyncMock, MagicMock, patch
4+
import sys
5+
import tempfile
6+
import textwrap
7+
from pathlib import Path
8+
from unittest.mock import AsyncMock, patch
59

610
import anyio
711
import pytest
@@ -12,11 +16,11 @@
1216
ClaudeSDKClient,
1317
CLIConnectionError,
1418
ResultMessage,
15-
SystemMessage,
1619
TextBlock,
1720
UserMessage,
1821
query,
1922
)
23+
from claude_code_sdk._internal.transport.subprocess_cli import SubprocessCLITransport
2024

2125

2226
class TestClaudeSDKClientStreaming:
@@ -369,212 +373,113 @@ async def get_next_message():
369373
class TestQueryWithAsyncIterable:
370374
"""Test query() function with async iterable inputs."""
371375

372-
def test_query_with_async_iterable(self):
373-
"""Test query with async iterable of messages."""
374-
375-
async def _test():
376-
async def message_stream():
377-
yield {"type": "user", "message": {"role": "user", "content": "First"}}
378-
yield {"type": "user", "message": {"role": "user", "content": "Second"}}
379-
380-
with patch(
381-
"claude_code_sdk._internal.client.InternalClient"
382-
) as mock_client_class:
383-
mock_client = MagicMock()
384-
mock_client_class.return_value = mock_client
385-
386-
# Mock the async generator response
387-
async def mock_process():
388-
yield AssistantMessage(
389-
content=[TextBlock(text="Response to both messages")]
390-
)
391-
yield ResultMessage(
392-
subtype="success",
393-
duration_ms=1000,
394-
duration_api_ms=800,
395-
is_error=False,
396-
num_turns=2,
397-
session_id="test",
398-
total_cost_usd=0.002,
399-
)
400-
401-
mock_client.process_query.return_value = mock_process()
402-
403-
# Run query with async iterable
404-
messages = []
405-
async for msg in query(prompt=message_stream()):
406-
messages.append(msg)
407-
408-
assert len(messages) == 2
409-
assert isinstance(messages[0], AssistantMessage)
410-
assert isinstance(messages[1], ResultMessage)
411-
412-
# Verify process_query was called with async iterable
413-
call_kwargs = mock_client.process_query.call_args.kwargs
414-
# The prompt should be an async generator
415-
assert hasattr(call_kwargs["prompt"], "__aiter__")
416-
417-
anyio.run(_test)
418-
419-
def test_query_async_iterable_with_options(self):
420-
"""Test query with async iterable and custom options."""
421-
422-
async def _test():
423-
async def complex_stream():
424-
yield {
425-
"type": "user",
426-
"message": {"role": "user", "content": "Setup"},
427-
"parent_tool_use_id": None,
428-
"session_id": "session-1",
429-
}
430-
yield {
431-
"type": "user",
432-
"message": {"role": "user", "content": "Execute"},
433-
"parent_tool_use_id": None,
434-
"session_id": "session-1",
435-
}
436-
437-
options = ClaudeCodeOptions(
438-
cwd="/workspace",
439-
permission_mode="acceptEdits",
440-
max_turns=10,
376+
def _create_test_script(
377+
self, expected_messages=None, response=None, should_error=False
378+
):
379+
"""Create a test script that validates CLI args and stdin messages.
380+
381+
Args:
382+
expected_messages: List of expected message content strings, or None to skip validation
383+
response: Custom response to output, defaults to a success result
384+
should_error: If True, script will exit with error after reading stdin
385+
386+
Returns:
387+
Path to the test script
388+
"""
389+
if response is None:
390+
response = '{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}'
391+
392+
script_content = textwrap.dedent("""
393+
#!/usr/bin/env python3
394+
import sys
395+
import json
396+
import time
397+
398+
# Check command line args
399+
args = sys.argv[1:]
400+
assert "--output-format" in args
401+
assert "stream-json" in args
402+
403+
# Read stdin messages
404+
stdin_messages = []
405+
stdin_closed = False
406+
try:
407+
while True:
408+
line = sys.stdin.readline()
409+
if not line:
410+
stdin_closed = True
411+
break
412+
stdin_messages.append(line.strip())
413+
except:
414+
stdin_closed = True
415+
""",
416+
)
417+
418+
if expected_messages is not None:
419+
script_content += textwrap.dedent(f"""
420+
# Verify we got the expected messages
421+
assert len(stdin_messages) == {len(expected_messages)}
422+
""",
441423
)
424+
for i, msg in enumerate(expected_messages):
425+
script_content += f'''assert '"{msg}"' in stdin_messages[{i}]\n'''
442426

443-
with patch(
444-
"claude_code_sdk._internal.client.InternalClient"
445-
) as mock_client_class:
446-
mock_client = MagicMock()
447-
mock_client_class.return_value = mock_client
448-
449-
# Mock response
450-
async def mock_process():
451-
yield AssistantMessage(content=[TextBlock(text="Done")])
452-
453-
mock_client.process_query.return_value = mock_process()
454-
455-
# Run query
456-
messages = []
457-
async for msg in query(prompt=complex_stream(), options=options):
458-
messages.append(msg)
459-
460-
# Verify options were passed
461-
call_kwargs = mock_client.process_query.call_args.kwargs
462-
assert call_kwargs["options"] is options
463-
464-
anyio.run(_test)
465-
466-
def test_query_empty_async_iterable(self):
467-
"""Test query with empty async iterable."""
468-
469-
async def _test():
470-
async def empty_stream():
471-
# Never yields anything
472-
if False:
473-
yield
474-
475-
with patch(
476-
"claude_code_sdk._internal.client.InternalClient"
477-
) as mock_client_class:
478-
mock_client = MagicMock()
479-
mock_client_class.return_value = mock_client
480-
481-
# Mock response
482-
async def mock_process():
483-
yield SystemMessage(
484-
subtype="info", data={"message": "No input provided"}
485-
)
486-
487-
mock_client.process_query.return_value = mock_process()
488-
489-
# Run query with empty stream
490-
messages = []
491-
async for msg in query(prompt=empty_stream()):
492-
messages.append(msg)
493-
494-
assert len(messages) == 1
495-
assert isinstance(messages[0], SystemMessage)
496-
497-
anyio.run(_test)
498-
499-
def test_query_async_iterable_with_delay(self):
500-
"""Test query with async iterable that has delays between yields."""
501-
502-
async def _test():
503-
async def delayed_stream():
504-
yield {"type": "user", "message": {"role": "user", "content": "Start"}}
505-
await asyncio.sleep(0.1)
506-
yield {"type": "user", "message": {"role": "user", "content": "Middle"}}
507-
await asyncio.sleep(0.1)
508-
yield {"type": "user", "message": {"role": "user", "content": "End"}}
509-
510-
with patch(
511-
"claude_code_sdk._internal.client.InternalClient"
512-
) as mock_client_class:
513-
mock_client = MagicMock()
514-
mock_client_class.return_value = mock_client
515-
516-
# Track if the stream was consumed
517-
stream_consumed = False
518-
519-
# Mock process_query to consume the input stream
520-
async def mock_process_query(prompt, options):
521-
nonlocal stream_consumed
522-
# Consume the async iterable to trigger delays
523-
items = []
524-
async for item in prompt:
525-
items.append(item)
526-
stream_consumed = True
527-
# Then yield response
528-
yield AssistantMessage(
529-
content=[TextBlock(text="Processing all messages")]
530-
)
531-
532-
mock_client.process_query = mock_process_query
533-
534-
# Time the execution
535-
import time
536-
537-
start_time = time.time()
538-
messages = []
539-
async for msg in query(prompt=delayed_stream()):
540-
messages.append(msg)
541-
elapsed = time.time() - start_time
427+
if should_error:
428+
script_content += textwrap.dedent("""
429+
sys.exit(1)
430+
""",
431+
)
432+
else:
433+
script_content += textwrap.dedent(f"""
434+
# Output response
435+
print('{response}')
436+
""",
437+
)
542438

543-
# Should have taken at least 0.2 seconds due to delays
544-
assert elapsed >= 0.2
545-
assert len(messages) == 1
546-
assert stream_consumed
439+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
440+
test_script = f.name
441+
f.write(script_content)
547442

548-
anyio.run(_test)
443+
Path(test_script).chmod(0o755)
444+
return test_script
549445

550-
def test_query_async_iterable_exception_handling(self):
551-
"""Test query handles exceptions in async iterable."""
446+
def test_query_with_async_iterable(self):
447+
"""Test query with async iterable of messages."""
552448

553449
async def _test():
554-
async def failing_stream():
450+
async def message_stream():
555451
yield {"type": "user", "message": {"role": "user", "content": "First"}}
556-
raise ValueError("Stream error")
557-
558-
with patch(
559-
"claude_code_sdk._internal.client.InternalClient"
560-
) as mock_client_class:
561-
mock_client = MagicMock()
562-
mock_client_class.return_value = mock_client
563-
564-
# The internal client should receive the failing stream
565-
# and handle the error appropriately
566-
async def mock_process():
567-
# Simulate processing until error
568-
yield AssistantMessage(content=[TextBlock(text="Error occurred")])
452+
yield {"type": "user", "message": {"role": "user", "content": "Second"}}
569453

570-
mock_client.process_query.return_value = mock_process()
454+
test_script = self._create_test_script(
455+
expected_messages=["First", "Second"]
456+
)
571457

572-
# Query should handle the error gracefully
573-
messages = []
574-
async for msg in query(prompt=failing_stream()):
575-
messages.append(msg)
458+
try:
459+
# Mock _build_command to return our test script
460+
with patch.object(
461+
SubprocessCLITransport,
462+
"_build_command",
463+
return_value=[
464+
sys.executable,
465+
test_script,
466+
"--output-format",
467+
"stream-json",
468+
"--verbose",
469+
],
470+
):
471+
# Run query with async iterable
472+
messages = []
473+
async for msg in query(prompt=message_stream()):
474+
messages.append(msg)
576475

577-
assert len(messages) == 1
476+
# Should get the result message
477+
assert len(messages) == 1
478+
assert isinstance(messages[0], ResultMessage)
479+
assert messages[0].subtype == "success"
480+
finally:
481+
# Clean up
482+
Path(test_script).unlink()
578483

579484
anyio.run(_test)
580485

0 commit comments

Comments
 (0)