Skip to content

Commit 8ea3a12

Browse files
dicksontsairushilpatel0
authored andcommitted
Fix test
Signed-off-by: Rushil Patel <[email protected]>
1 parent 0d23841 commit 8ea3a12

File tree

1 file changed

+52
-96
lines changed

1 file changed

+52
-96
lines changed

tests/test_streaming_client.py

Lines changed: 52 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import asyncio
44
import sys
55
import tempfile
6-
import textwrap
76
from pathlib import Path
87
from unittest.mock import AsyncMock, patch
98

@@ -373,80 +372,6 @@ async def get_next_message():
373372
class TestQueryWithAsyncIterable:
374373
"""Test query() function with async iterable inputs."""
375374

376-
def _create_test_script(
377-
self, expected_messages=None, response=None, should_error=False
378-
):
379-
"""Create a test script that validates CLI args and stdin messages.
380-
381-
Args:
382-
expected_messages: List of expected message content strings, or None to skip validation
383-
response: Custom response to output, defaults to a success result
384-
should_error: If True, script will exit with error after reading stdin
385-
386-
Returns:
387-
Path to the test script
388-
"""
389-
if response is None:
390-
response = '{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}'
391-
392-
script_content = textwrap.dedent(
393-
"""
394-
#!/usr/bin/env python3
395-
import sys
396-
import json
397-
import time
398-
399-
# Check command line args
400-
args = sys.argv[1:]
401-
assert "--output-format" in args
402-
assert "stream-json" in args
403-
404-
# Read stdin messages
405-
stdin_messages = []
406-
stdin_closed = False
407-
try:
408-
while True:
409-
line = sys.stdin.readline()
410-
if not line:
411-
stdin_closed = True
412-
break
413-
stdin_messages.append(line.strip())
414-
except:
415-
stdin_closed = True
416-
""",
417-
)
418-
419-
if expected_messages is not None:
420-
script_content += textwrap.dedent(
421-
f"""
422-
# Verify we got the expected messages
423-
assert len(stdin_messages) == {len(expected_messages)}
424-
""",
425-
)
426-
for i, msg in enumerate(expected_messages):
427-
script_content += f'''assert '"{msg}"' in stdin_messages[{i}]\n'''
428-
429-
if should_error:
430-
script_content += textwrap.dedent(
431-
"""
432-
sys.exit(1)
433-
""",
434-
)
435-
else:
436-
script_content += textwrap.dedent(
437-
f"""
438-
# Output response
439-
print('{response}')
440-
""",
441-
)
442-
443-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
444-
test_script = f.name
445-
f.write(script_content)
446-
447-
Path(test_script).chmod(0o755)
448-
return test_script
449-
450375
def test_query_with_async_iterable(self):
451376
"""Test query with async iterable of messages."""
452377

@@ -455,32 +380,63 @@ async def message_stream():
455380
yield {"type": "user", "message": {"role": "user", "content": "First"}}
456381
yield {"type": "user", "message": {"role": "user", "content": "Second"}}
457382

458-
test_script = self._create_test_script(
459-
expected_messages=["First", "Second"]
460-
)
383+
# Create a simple test script that validates stdin and outputs a result
384+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
385+
test_script = f.name
386+
f.write("""#!/usr/bin/env python3
387+
import sys
388+
import json
389+
390+
# Read stdin messages
391+
stdin_messages = []
392+
while True:
393+
line = sys.stdin.readline()
394+
if not line:
395+
break
396+
stdin_messages.append(line.strip())
397+
398+
# Verify we got 2 messages
399+
assert len(stdin_messages) == 2
400+
assert '"First"' in stdin_messages[0]
401+
assert '"Second"' in stdin_messages[1]
402+
403+
# Output a valid result
404+
print('{"type": "result", "subtype": "success", "duration_ms": 100, "duration_api_ms": 50, "is_error": false, "num_turns": 1, "session_id": "test", "total_cost_usd": 0.001}')
405+
""")
406+
407+
Path(test_script).chmod(0o755)
461408

462409
try:
463-
# Mock _build_command to return our test script
410+
# Mock _find_cli to return python executing our test script
464411
with patch.object(
465412
SubprocessCLITransport,
466-
"_build_command",
467-
return_value=[
468-
sys.executable,
469-
test_script,
470-
"--output-format",
471-
"stream-json",
472-
"--verbose",
473-
],
413+
"_find_cli",
414+
return_value=sys.executable
474415
):
475-
# Run query with async iterable
476-
messages = []
477-
async for msg in query(prompt=message_stream()):
478-
messages.append(msg)
479-
480-
# Should get the result message
481-
assert len(messages) == 1
482-
assert isinstance(messages[0], ResultMessage)
483-
assert messages[0].subtype == "success"
416+
# Mock _build_command to add our test script as first argument
417+
original_build_command = SubprocessCLITransport._build_command
418+
419+
def mock_build_command(self):
420+
# Get original command
421+
cmd = original_build_command(self)
422+
# Replace the CLI path with python + script
423+
cmd[0] = test_script
424+
return cmd
425+
426+
with patch.object(
427+
SubprocessCLITransport,
428+
"_build_command",
429+
mock_build_command
430+
):
431+
# Run query with async iterable
432+
messages = []
433+
async for msg in query(prompt=message_stream()):
434+
messages.append(msg)
435+
436+
# Should get the result message
437+
assert len(messages) == 1
438+
assert isinstance(messages[0], ResultMessage)
439+
assert messages[0].subtype == "success"
484440
finally:
485441
# Clean up
486442
Path(test_script).unlink()

0 commit comments

Comments
 (0)