33import asyncio
44import sys
55import tempfile
6- import textwrap
76from pathlib import Path
87from unittest .mock import AsyncMock , patch
98
@@ -373,80 +372,6 @@ async def get_next_message():
373372class TestQueryWithAsyncIterable :
374373 """Test query() function with async iterable inputs."""
375374
376- def _create_test_script (
377- self , expected_messages = None , response = None , should_error = False
378- ):
379- """Create a test script that validates CLI args and stdin messages.
380-
381- Args:
382- expected_messages: List of expected message content strings, or None to skip validation
383- response: Custom response to output, defaults to a success result
384- should_error: If True, script will exit with error after reading stdin
385-
386- Returns:
387- Path to the test script
388- """
389- if response is None :
390- response = '{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}'
391-
392- script_content = textwrap .dedent (
393- """
394- #!/usr/bin/env python3
395- import sys
396- import json
397- import time
398-
399- # Check command line args
400- args = sys.argv[1:]
401- assert "--output-format" in args
402- assert "stream-json" in args
403-
404- # Read stdin messages
405- stdin_messages = []
406- stdin_closed = False
407- try:
408- while True:
409- line = sys.stdin.readline()
410- if not line:
411- stdin_closed = True
412- break
413- stdin_messages.append(line.strip())
414- except:
415- stdin_closed = True
416- """ ,
417- )
418-
419- if expected_messages is not None :
420- script_content += textwrap .dedent (
421- f"""
422- # Verify we got the expected messages
423- assert len(stdin_messages) == { len (expected_messages )}
424- """ ,
425- )
426- for i , msg in enumerate (expected_messages ):
427- script_content += f'''assert '"{ msg } "' in stdin_messages[{ i } ]\n '''
428-
429- if should_error :
430- script_content += textwrap .dedent (
431- """
432- sys.exit(1)
433- """ ,
434- )
435- else :
436- script_content += textwrap .dedent (
437- f"""
438- # Output response
439- print('{ response } ')
440- """ ,
441- )
442-
443- with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" , delete = False ) as f :
444- test_script = f .name
445- f .write (script_content )
446-
447- Path (test_script ).chmod (0o755 )
448- return test_script
449-
450375 def test_query_with_async_iterable (self ):
451376 """Test query with async iterable of messages."""
452377
@@ -455,32 +380,63 @@ async def message_stream():
455380 yield {"type" : "user" , "message" : {"role" : "user" , "content" : "First" }}
456381 yield {"type" : "user" , "message" : {"role" : "user" , "content" : "Second" }}
457382
458- test_script = self ._create_test_script (
459- expected_messages = ["First" , "Second" ]
460- )
383+ # Create a simple test script that validates stdin and outputs a result
384+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" , delete = False ) as f :
385+ test_script = f .name
386+ f .write ("""#!/usr/bin/env python3
387+ import sys
388+ import json
389+
390+ # Read stdin messages
391+ stdin_messages = []
392+ while True:
393+ line = sys.stdin.readline()
394+ if not line:
395+ break
396+ stdin_messages.append(line.strip())
397+
398+ # Verify we got 2 messages
399+ assert len(stdin_messages) == 2
400+ assert '"First"' in stdin_messages[0]
401+ assert '"Second"' in stdin_messages[1]
402+
403+ # Output a valid result
404+ print('{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}')
405+ """ )
406+
407+ Path (test_script ).chmod (0o755 )
461408
462409 try :
463- # Mock _build_command to return our test script
410+ # Mock _find_cli to return python executing our test script
464411 with patch .object (
465412 SubprocessCLITransport ,
466- "_build_command" ,
467- return_value = [
468- sys .executable ,
469- test_script ,
470- "--output-format" ,
471- "stream-json" ,
472- "--verbose" ,
473- ],
413+ "_find_cli" ,
414+ return_value = sys .executable
474415 ):
475- # Run query with async iterable
476- messages = []
477- async for msg in query (prompt = message_stream ()):
478- messages .append (msg )
479-
480- # Should get the result message
481- assert len (messages ) == 1
482- assert isinstance (messages [0 ], ResultMessage )
483- assert messages [0 ].subtype == "success"
416+ # Mock _build_command to add our test script as first argument
417+ original_build_command = SubprocessCLITransport ._build_command
418+
419+ def mock_build_command (self ):
420+ # Get original command
421+ cmd = original_build_command (self )
422+ # Replace the CLI path with python + script
423+ cmd [0 ] = test_script
424+ return cmd
425+
426+ with patch .object (
427+ SubprocessCLITransport ,
428+ "_build_command" ,
429+ mock_build_command
430+ ):
431+ # Run query with async iterable
432+ messages = []
433+ async for msg in query (prompt = message_stream ()):
434+ messages .append (msg )
435+
436+ # Should get the result message
437+ assert len (messages ) == 1
438+ assert isinstance (messages [0 ], ResultMessage )
439+ assert messages [0 ].subtype == "success"
484440 finally :
485441 # Clean up
486442 Path (test_script ).unlink ()
0 commit comments