|
4 | 4 |
|
5 | 5 | import pytest
|
6 | 6 | import asyncio
|
| 7 | +from types import SimpleNamespace |
7 | 8 | from unittest.mock import Mock, MagicMock, AsyncMock, patch
|
8 | 9 |
|
9 | 10 |
|
@@ -200,6 +201,55 @@ async def run_async(self, *args, **kwargs):
|
200 | 201 | # Confirm branch selection
|
201 | 202 | assert len(streaming_calls) == 1
|
202 | 203 | assert lro_calls == []
|
| 204 | + |
| 205 | + async def test_partial_final_chunk_uses_streaming_translation(self, adk_agent, sample_input): |
| 206 | + """Ensure partial chunks marked as final still use streaming translation.""" |
| 207 | + |
| 208 | + translate_calls = 0 |
| 209 | + lro_calls = 0 |
| 210 | + |
| 211 | + async def fake_translate(self, adk_event, thread_id, run_id): |
| 212 | + nonlocal translate_calls |
| 213 | + translate_calls += 1 |
| 214 | + yield TextMessageChunkEvent( |
| 215 | + type=EventType.TEXT_MESSAGE_CHUNK, |
| 216 | + message_id=adk_event.id, |
| 217 | + delta="chunk" |
| 218 | + ) |
| 219 | + |
| 220 | + async def fake_translate_lro(self, adk_event): |
| 221 | + nonlocal lro_calls |
| 222 | + lro_calls += 1 |
| 223 | + if False: |
| 224 | + yield # pragma: no cover - keeps this an async generator |
| 225 | + |
| 226 | + adk_event = SimpleNamespace( |
| 227 | + id="event-final-chunk", |
| 228 | + author="assistant", |
| 229 | + content=SimpleNamespace(parts=[SimpleNamespace(text="hello")]), |
| 230 | + partial=True, |
| 231 | + turn_complete=True, |
| 232 | + usage_metadata={"tokens": 1}, |
| 233 | + finish_reason="STOP", |
| 234 | + actions=None, |
| 235 | + custom_data=None, |
| 236 | + get_function_calls=lambda: [], |
| 237 | + get_function_responses=lambda: [], |
| 238 | + is_final_response=lambda: True |
| 239 | + ) |
| 240 | + |
| 241 | + class FakeRunner: |
| 242 | + async def run_async(self, *args, **kwargs): |
| 243 | + yield adk_event |
| 244 | + |
| 245 | + with patch("ag_ui_adk.adk_agent.EventTranslator.translate", new=fake_translate), \ |
| 246 | + patch("ag_ui_adk.adk_agent.EventTranslator.translate_lro_function_calls", new=fake_translate_lro), \ |
| 247 | + patch.object(adk_agent, "_create_runner", return_value=FakeRunner()): |
| 248 | + events = [event async for event in adk_agent.run(sample_input)] |
| 249 | + |
| 250 | + assert any(isinstance(event, TextMessageChunkEvent) for event in events) |
| 251 | + assert translate_calls == 1 |
| 252 | + assert lro_calls == 0 |
203 | 253 |
|
204 | 254 | @pytest.mark.asyncio
|
205 | 255 | async def test_session_management(self, adk_agent):
|
|
0 commit comments