|
4 | 4 |
|
5 | 5 | import pytest
|
6 | 6 | import asyncio
|
| 7 | +from types import SimpleNamespace |
7 | 8 | from unittest.mock import Mock, MagicMock, AsyncMock, patch
|
8 | 9 |
|
9 | 10 |
|
@@ -137,6 +138,56 @@ async def mock_run_async(*args, **kwargs):
|
137 | 138 | assert events[0].type == EventType.RUN_STARTED
|
138 | 139 | assert events[-1].type == EventType.RUN_FINISHED
|
139 | 140 |
|
| 141 | + @pytest.mark.asyncio |
| 142 | + async def test_partial_final_chunk_uses_streaming_translation(self, adk_agent, sample_input): |
| 143 | + """Ensure partial chunks marked as final still use streaming translation.""" |
| 144 | + |
| 145 | + translate_calls = 0 |
| 146 | + lro_calls = 0 |
| 147 | + |
| 148 | + async def fake_translate(self, adk_event, thread_id, run_id): |
| 149 | + nonlocal translate_calls |
| 150 | + translate_calls += 1 |
| 151 | + yield TextMessageContentEvent( |
| 152 | + type=EventType.TEXT_MESSAGE_CONTENT, |
| 153 | + message_id=adk_event.id, |
| 154 | + delta="chunk" |
| 155 | + ) |
| 156 | + |
| 157 | + async def fake_translate_lro(self, adk_event): |
| 158 | + nonlocal lro_calls |
| 159 | + lro_calls += 1 |
| 160 | + if False: |
| 161 | + yield # pragma: no cover - keeps this an async generator |
| 162 | + |
| 163 | + adk_event = SimpleNamespace( |
| 164 | + id="event-final-chunk", |
| 165 | + author="assistant", |
| 166 | + content=SimpleNamespace(parts=[SimpleNamespace(text="hello")]), |
| 167 | + partial=True, |
| 168 | + turn_complete=True, |
| 169 | + usage_metadata={"tokens": 1}, |
| 170 | + finish_reason="STOP", |
| 171 | + actions=None, |
| 172 | + custom_data=None, |
| 173 | + get_function_calls=lambda: [], |
| 174 | + get_function_responses=lambda: [], |
| 175 | + is_final_response=lambda: True |
| 176 | + ) |
| 177 | + |
| 178 | + class FakeRunner: |
| 179 | + async def run_async(self, *args, **kwargs): |
| 180 | + yield adk_event |
| 181 | + |
| 182 | + with patch("ag_ui_adk.adk_agent.EventTranslator.translate", new=fake_translate), \ |
| 183 | + patch("ag_ui_adk.adk_agent.EventTranslator.translate_lro_function_calls", new=fake_translate_lro), \ |
| 184 | + patch.object(adk_agent, "_create_runner", return_value=FakeRunner()): |
| 185 | + events = [event async for event in adk_agent.run(sample_input)] |
| 186 | + |
| 187 | + assert any(isinstance(event, TextMessageContentEvent) for event in events) |
| 188 | + assert translate_calls == 1 |
| 189 | + assert lro_calls == 0 |
| 190 | + |
140 | 191 | @pytest.mark.asyncio
|
141 | 192 | async def test_session_management(self, adk_agent):
|
142 | 193 | """Test session lifecycle management."""
|
|
0 commit comments