|
1 | 1 | """Tests for ClaudeSDKClient streaming functionality and query() with async iterables.""" |
2 | 2 |
|
3 | 3 | import asyncio |
4 | | -from unittest.mock import AsyncMock, MagicMock, patch |
| 4 | +import sys |
| 5 | +import tempfile |
| 6 | +import textwrap |
| 7 | +from pathlib import Path |
| 8 | +from unittest.mock import AsyncMock, patch |
5 | 9 |
|
6 | 10 | import anyio |
7 | 11 | import pytest |
|
12 | 16 | ClaudeSDKClient, |
13 | 17 | CLIConnectionError, |
14 | 18 | ResultMessage, |
15 | | - SystemMessage, |
16 | 19 | TextBlock, |
17 | 20 | UserMessage, |
18 | 21 | query, |
19 | 22 | ) |
| 23 | +from claude_code_sdk._internal.transport.subprocess_cli import SubprocessCLITransport |
20 | 24 |
|
21 | 25 |
|
22 | 26 | class TestClaudeSDKClientStreaming: |
@@ -369,212 +373,113 @@ async def get_next_message(): |
369 | 373 | class TestQueryWithAsyncIterable: |
370 | 374 | """Test query() function with async iterable inputs.""" |
371 | 375 |
|
372 | | - def test_query_with_async_iterable(self): |
373 | | - """Test query with async iterable of messages.""" |
374 | | - |
375 | | - async def _test(): |
376 | | - async def message_stream(): |
377 | | - yield {"type": "user", "message": {"role": "user", "content": "First"}} |
378 | | - yield {"type": "user", "message": {"role": "user", "content": "Second"}} |
379 | | - |
380 | | - with patch( |
381 | | - "claude_code_sdk._internal.client.InternalClient" |
382 | | - ) as mock_client_class: |
383 | | - mock_client = MagicMock() |
384 | | - mock_client_class.return_value = mock_client |
385 | | - |
386 | | - # Mock the async generator response |
387 | | - async def mock_process(): |
388 | | - yield AssistantMessage( |
389 | | - content=[TextBlock(text="Response to both messages")] |
390 | | - ) |
391 | | - yield ResultMessage( |
392 | | - subtype="success", |
393 | | - duration_ms=1000, |
394 | | - duration_api_ms=800, |
395 | | - is_error=False, |
396 | | - num_turns=2, |
397 | | - session_id="test", |
398 | | - total_cost_usd=0.002, |
399 | | - ) |
400 | | - |
401 | | - mock_client.process_query.return_value = mock_process() |
402 | | - |
403 | | - # Run query with async iterable |
404 | | - messages = [] |
405 | | - async for msg in query(prompt=message_stream()): |
406 | | - messages.append(msg) |
407 | | - |
408 | | - assert len(messages) == 2 |
409 | | - assert isinstance(messages[0], AssistantMessage) |
410 | | - assert isinstance(messages[1], ResultMessage) |
411 | | - |
412 | | - # Verify process_query was called with async iterable |
413 | | - call_kwargs = mock_client.process_query.call_args.kwargs |
414 | | - # The prompt should be an async generator |
415 | | - assert hasattr(call_kwargs["prompt"], "__aiter__") |
416 | | - |
417 | | - anyio.run(_test) |
418 | | - |
419 | | - def test_query_async_iterable_with_options(self): |
420 | | - """Test query with async iterable and custom options.""" |
421 | | - |
422 | | - async def _test(): |
423 | | - async def complex_stream(): |
424 | | - yield { |
425 | | - "type": "user", |
426 | | - "message": {"role": "user", "content": "Setup"}, |
427 | | - "parent_tool_use_id": None, |
428 | | - "session_id": "session-1", |
429 | | - } |
430 | | - yield { |
431 | | - "type": "user", |
432 | | - "message": {"role": "user", "content": "Execute"}, |
433 | | - "parent_tool_use_id": None, |
434 | | - "session_id": "session-1", |
435 | | - } |
436 | | - |
437 | | - options = ClaudeCodeOptions( |
438 | | - cwd="/workspace", |
439 | | - permission_mode="acceptEdits", |
440 | | - max_turns=10, |
| 376 | + def _create_test_script( |
| 377 | + self, expected_messages=None, response=None, should_error=False |
| 378 | + ): |
| 379 | + """Create a test script that validates CLI args and stdin messages. |
| 380 | +
|
| 381 | + Args: |
| 382 | + expected_messages: List of expected message content strings, or None to skip validation |
| 383 | + response: Custom response to output, defaults to a success result |
| 384 | + should_error: If True, script will exit with error after reading stdin |
| 385 | +
|
| 386 | + Returns: |
| 387 | + Path to the test script |
| 388 | + """ |
| 389 | + if response is None: |
| 390 | + response = '{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}' |
| 391 | + |
| 392 | + script_content = textwrap.dedent(""" |
| 393 | + #!/usr/bin/env python3 |
| 394 | + import sys |
| 395 | + import json |
| 396 | + import time |
| 397 | +
|
| 398 | + # Check command line args |
| 399 | + args = sys.argv[1:] |
| 400 | + assert "--output-format" in args |
| 401 | + assert "stream-json" in args |
| 402 | +
|
| 403 | + # Read stdin messages |
| 404 | + stdin_messages = [] |
| 405 | + stdin_closed = False |
| 406 | + try: |
| 407 | + while True: |
| 408 | + line = sys.stdin.readline() |
| 409 | + if not line: |
| 410 | + stdin_closed = True |
| 411 | + break |
| 412 | + stdin_messages.append(line.strip()) |
| 413 | + except: |
| 414 | + stdin_closed = True |
| 415 | + """, |
| 416 | + ) |
| 417 | + |
| 418 | + if expected_messages is not None: |
| 419 | + script_content += textwrap.dedent(f""" |
| 420 | + # Verify we got the expected messages |
| 421 | + assert len(stdin_messages) == {len(expected_messages)} |
| 422 | + """, |
441 | 423 | ) |
| 424 | + for i, msg in enumerate(expected_messages): |
| 425 | + script_content += f'''assert '"{msg}"' in stdin_messages[{i}]\n''' |
442 | 426 |
|
443 | | - with patch( |
444 | | - "claude_code_sdk._internal.client.InternalClient" |
445 | | - ) as mock_client_class: |
446 | | - mock_client = MagicMock() |
447 | | - mock_client_class.return_value = mock_client |
448 | | - |
449 | | - # Mock response |
450 | | - async def mock_process(): |
451 | | - yield AssistantMessage(content=[TextBlock(text="Done")]) |
452 | | - |
453 | | - mock_client.process_query.return_value = mock_process() |
454 | | - |
455 | | - # Run query |
456 | | - messages = [] |
457 | | - async for msg in query(prompt=complex_stream(), options=options): |
458 | | - messages.append(msg) |
459 | | - |
460 | | - # Verify options were passed |
461 | | - call_kwargs = mock_client.process_query.call_args.kwargs |
462 | | - assert call_kwargs["options"] is options |
463 | | - |
464 | | - anyio.run(_test) |
465 | | - |
466 | | - def test_query_empty_async_iterable(self): |
467 | | - """Test query with empty async iterable.""" |
468 | | - |
469 | | - async def _test(): |
470 | | - async def empty_stream(): |
471 | | - # Never yields anything |
472 | | - if False: |
473 | | - yield |
474 | | - |
475 | | - with patch( |
476 | | - "claude_code_sdk._internal.client.InternalClient" |
477 | | - ) as mock_client_class: |
478 | | - mock_client = MagicMock() |
479 | | - mock_client_class.return_value = mock_client |
480 | | - |
481 | | - # Mock response |
482 | | - async def mock_process(): |
483 | | - yield SystemMessage( |
484 | | - subtype="info", data={"message": "No input provided"} |
485 | | - ) |
486 | | - |
487 | | - mock_client.process_query.return_value = mock_process() |
488 | | - |
489 | | - # Run query with empty stream |
490 | | - messages = [] |
491 | | - async for msg in query(prompt=empty_stream()): |
492 | | - messages.append(msg) |
493 | | - |
494 | | - assert len(messages) == 1 |
495 | | - assert isinstance(messages[0], SystemMessage) |
496 | | - |
497 | | - anyio.run(_test) |
498 | | - |
499 | | - def test_query_async_iterable_with_delay(self): |
500 | | - """Test query with async iterable that has delays between yields.""" |
501 | | - |
502 | | - async def _test(): |
503 | | - async def delayed_stream(): |
504 | | - yield {"type": "user", "message": {"role": "user", "content": "Start"}} |
505 | | - await asyncio.sleep(0.1) |
506 | | - yield {"type": "user", "message": {"role": "user", "content": "Middle"}} |
507 | | - await asyncio.sleep(0.1) |
508 | | - yield {"type": "user", "message": {"role": "user", "content": "End"}} |
509 | | - |
510 | | - with patch( |
511 | | - "claude_code_sdk._internal.client.InternalClient" |
512 | | - ) as mock_client_class: |
513 | | - mock_client = MagicMock() |
514 | | - mock_client_class.return_value = mock_client |
515 | | - |
516 | | - # Track if the stream was consumed |
517 | | - stream_consumed = False |
518 | | - |
519 | | - # Mock process_query to consume the input stream |
520 | | - async def mock_process_query(prompt, options): |
521 | | - nonlocal stream_consumed |
522 | | - # Consume the async iterable to trigger delays |
523 | | - items = [] |
524 | | - async for item in prompt: |
525 | | - items.append(item) |
526 | | - stream_consumed = True |
527 | | - # Then yield response |
528 | | - yield AssistantMessage( |
529 | | - content=[TextBlock(text="Processing all messages")] |
530 | | - ) |
531 | | - |
532 | | - mock_client.process_query = mock_process_query |
533 | | - |
534 | | - # Time the execution |
535 | | - import time |
536 | | - |
537 | | - start_time = time.time() |
538 | | - messages = [] |
539 | | - async for msg in query(prompt=delayed_stream()): |
540 | | - messages.append(msg) |
541 | | - elapsed = time.time() - start_time |
| 427 | + if should_error: |
| 428 | + script_content += textwrap.dedent(""" |
| 429 | + sys.exit(1) |
| 430 | + """, |
| 431 | + ) |
| 432 | + else: |
| 433 | + script_content += textwrap.dedent(f""" |
| 434 | + # Output response |
| 435 | + print('{response}') |
| 436 | + """, |
| 437 | + ) |
542 | 438 |
|
543 | | - # Should have taken at least 0.2 seconds due to delays |
544 | | - assert elapsed >= 0.2 |
545 | | - assert len(messages) == 1 |
546 | | - assert stream_consumed |
| 439 | + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: |
| 440 | + test_script = f.name |
| 441 | + f.write(script_content) |
547 | 442 |
|
548 | | - anyio.run(_test) |
| 443 | + Path(test_script).chmod(0o755) |
| 444 | + return test_script |
549 | 445 |
|
550 | | - def test_query_async_iterable_exception_handling(self): |
551 | | - """Test query handles exceptions in async iterable.""" |
| 446 | + def test_query_with_async_iterable(self): |
| 447 | + """Test query with async iterable of messages.""" |
552 | 448 |
|
553 | 449 | async def _test(): |
554 | | - async def failing_stream(): |
| 450 | + async def message_stream(): |
555 | 451 | yield {"type": "user", "message": {"role": "user", "content": "First"}} |
556 | | - raise ValueError("Stream error") |
557 | | - |
558 | | - with patch( |
559 | | - "claude_code_sdk._internal.client.InternalClient" |
560 | | - ) as mock_client_class: |
561 | | - mock_client = MagicMock() |
562 | | - mock_client_class.return_value = mock_client |
563 | | - |
564 | | - # The internal client should receive the failing stream |
565 | | - # and handle the error appropriately |
566 | | - async def mock_process(): |
567 | | - # Simulate processing until error |
568 | | - yield AssistantMessage(content=[TextBlock(text="Error occurred")]) |
| 452 | + yield {"type": "user", "message": {"role": "user", "content": "Second"}} |
569 | 453 |
|
570 | | - mock_client.process_query.return_value = mock_process() |
| 454 | + test_script = self._create_test_script( |
| 455 | + expected_messages=["First", "Second"] |
| 456 | + ) |
571 | 457 |
|
572 | | - # Query should handle the error gracefully |
573 | | - messages = [] |
574 | | - async for msg in query(prompt=failing_stream()): |
575 | | - messages.append(msg) |
| 458 | + try: |
| 459 | + # Mock _build_command to return our test script |
| 460 | + with patch.object( |
| 461 | + SubprocessCLITransport, |
| 462 | + "_build_command", |
| 463 | + return_value=[ |
| 464 | + sys.executable, |
| 465 | + test_script, |
| 466 | + "--output-format", |
| 467 | + "stream-json", |
| 468 | + "--verbose", |
| 469 | + ], |
| 470 | + ): |
| 471 | + # Run query with async iterable |
| 472 | + messages = [] |
| 473 | + async for msg in query(prompt=message_stream()): |
| 474 | + messages.append(msg) |
576 | 475 |
|
577 | | - assert len(messages) == 1 |
| 476 | + # Should get the result message |
| 477 | + assert len(messages) == 1 |
| 478 | + assert isinstance(messages[0], ResultMessage) |
| 479 | + assert messages[0].subtype == "success" |
| 480 | + finally: |
| 481 | + # Clean up |
| 482 | + Path(test_script).unlink() |
578 | 483 |
|
579 | 484 | anyio.run(_test) |
580 | 485 |
|
|
0 commit comments