|
8 | 8 | from unittest.mock import Mock, MagicMock, AsyncMock, patch
|
9 | 9 |
|
10 | 10 |
|
11 |
| -from ag_ui_adk import ADKAgent, SessionManager |
| 11 | +from ag_ui_adk import ADKAgent, SessionManager, EventTranslator |
12 | 12 | from ag_ui.core import (
|
13 | 13 | RunAgentInput, EventType, UserMessage, Context,
|
14 | 14 | RunStartedEvent, RunFinishedEvent, TextMessageChunkEvent, SystemMessage
|
@@ -139,6 +139,69 @@ async def mock_run_async(*args, **kwargs):
|
139 | 139 | assert events[-1].type == EventType.RUN_FINISHED
|
140 | 140 |
|
141 | 141 | @pytest.mark.asyncio
|
| 142 | + async def test_turn_complete_falls_back_to_streaming_translator( |
| 143 | + self, |
| 144 | + adk_agent, |
| 145 | + sample_input, |
| 146 | + ): |
| 147 | + """Ensure turn_complete=False triggers streaming translation path.""" |
| 148 | + |
| 149 | + streaming_calls = [] |
| 150 | + lro_calls = [] |
| 151 | + |
| 152 | + async def fake_translate(self, adk_event, thread_id, run_id): |
| 153 | + streaming_calls.append((adk_event, thread_id, run_id)) |
| 154 | + yield TextMessageChunkEvent( |
| 155 | + message_id=adk_event.id, |
| 156 | + role="assistant", |
| 157 | + delta="streamed chunk", |
| 158 | + ) |
| 159 | + |
| 160 | + async def fake_translate_lro(self, adk_event): |
| 161 | + lro_calls.append(adk_event) |
| 162 | + if False: # pragma: no cover - required to keep async generator signature |
| 163 | + yield None |
| 164 | + |
| 165 | + mock_event = Mock() |
| 166 | + mock_event.id = "event_stream" |
| 167 | + mock_event.author = "assistant" |
| 168 | + mock_event.partial = False |
| 169 | + mock_event.turn_complete = False |
| 170 | + mock_event.finish_reason = "STOP" |
| 171 | + mock_event.usage_metadata = {"tokens": 5} |
| 172 | + mock_event.is_final_response = Mock(return_value=True) |
| 173 | + mock_event.content = Mock() |
| 174 | + mock_event.content.parts = [Mock(text="Final response chunk")] |
| 175 | + mock_event.actions = None |
| 176 | + mock_event.get_function_calls = Mock(return_value=[]) |
| 177 | + mock_event.get_function_responses = Mock(return_value=[]) |
| 178 | + mock_event.custom_data = None |
| 179 | + |
| 180 | + class DummyRunner: |
| 181 | + async def run_async(self, *args, **kwargs): |
| 182 | + yield mock_event |
| 183 | + |
| 184 | + with patch.object(adk_agent, '_create_runner', return_value=DummyRunner()), \ |
| 185 | + patch.object(EventTranslator, 'translate', new=fake_translate), \ |
| 186 | + patch.object(EventTranslator, 'translate_lro_function_calls', new=fake_translate_lro): |
| 187 | + |
| 188 | + events = [] |
| 189 | + async for event in adk_agent.run(sample_input): |
| 190 | + events.append(event) |
| 191 | + |
| 192 | + # Verify run lifecycle events emitted |
| 193 | + assert events[0].type == EventType.RUN_STARTED |
| 194 | + assert events[-1].type == EventType.RUN_FINISHED |
| 195 | + |
| 196 | + # Ensure streaming translator branch handled the event |
| 197 | + chunk_events = [event for event in events if isinstance(event, TextMessageChunkEvent)] |
| 198 | + assert chunk_events, "Expected translated chunk event" |
| 199 | + assert chunk_events[0].delta == "streamed chunk" |
| 200 | + |
| 201 | + # Confirm branch selection |
| 202 | + assert len(streaming_calls) == 1 |
| 203 | + assert lro_calls == [] |
| 204 | + |
142 | 205 | async def test_partial_final_chunk_uses_streaming_translation(self, adk_agent, sample_input):
|
143 | 206 | """Ensure partial chunks marked as final still use streaming translation."""
|
144 | 207 |
|
|
0 commit comments