diff --git a/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py b/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py index 922bd0279..2bcfbe16d 100644 --- a/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py +++ b/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py @@ -223,10 +223,16 @@ async def translate( # Handle state changes - if hasattr(adk_event, 'actions') and adk_event.actions and hasattr(adk_event.actions, 'state_delta') and adk_event.actions.state_delta: - yield self._create_state_delta_event( - adk_event.actions.state_delta, thread_id, run_id - ) + if hasattr(adk_event, 'actions') and adk_event.actions: + if hasattr(adk_event.actions, 'state_delta') and adk_event.actions.state_delta: + yield self._create_state_delta_event( + adk_event.actions.state_delta, thread_id, run_id + ) + + if hasattr(adk_event.actions, 'state_snapshot'): + state_snapshot = adk_event.actions.state_snapshot + if state_snapshot is not None: + yield self._create_state_snapshot_event(state_snapshot) # Handle custom events or metadata @@ -583,23 +589,9 @@ def _create_state_snapshot_event( A StateSnapshotEvent """ - FullSnapShot = { - "context": { - "conversation": [], - "user": { - "name": state_snapshot.get("user_name", ""), - "timezone": state_snapshot.get("timezone", "UTC") - }, - "app": { - "version": state_snapshot.get("app_version", "unknown") - } - }, - "state": state_snapshot.get("custom_state", {}) - } - return StateSnapshotEvent( type=EventType.STATE_SNAPSHOT, - snapshot=FullSnapShot + snapshot=state_snapshot ) async def force_close_streaming_message(self) -> AsyncGenerator[BaseEvent, None]: diff --git a/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py b/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py index 8ef9d17ea..7e7ea8d5b 100644 --- a/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py +++ b/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py @@ -13,7 +13,7 @@ from ag_ui.core import ( EventType, TextMessageStartEvent, TextMessageContentEvent, TextMessageEndEvent, ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent, - StateDeltaEvent, CustomEvent + StateDeltaEvent, StateSnapshotEvent, CustomEvent ) from google.adk.events import Event as ADKEvent from ag_ui_adk.event_translator import EventTranslator @@ -185,6 +185,7 @@ async def test_translate_state_delta_event(self, translator, mock_adk_event): # Mock event with state delta mock_actions = MagicMock() mock_actions.state_delta = {"key1": "value1", "key2": "value2"} + mock_actions.state_snapshot = None mock_adk_event.actions = mock_actions events = [] @@ -201,6 +202,55 @@ async def test_translate_state_delta_event(self, translator, mock_adk_event): assert any(patch["path"] == "/key1" and patch["value"] == "value1" for patch in patches) assert any(patch["path"] == "/key2" and patch["value"] == "value2" for patch in patches) + @pytest.mark.asyncio + async def test_translate_state_snapshot_event_passthrough(self, translator, mock_adk_event): + """Test state snapshot events preserve the ADK payload.""" + + state_snapshot = { + "user_name": "Alice", + "timezone": "UTC", + "custom_state": { + "view": {"active_tab": "details"}, + "progress": 0.75, + }, + "extra_field": [1, 2, 3], + } + + mock_adk_event.actions = SimpleNamespace( + state_delta=None, + state_snapshot=state_snapshot, + ) + + events = [] + async for event in translator.translate(mock_adk_event, "thread_1", "run_1"): + events.append(event) + + snapshot_events = [event for event in events if isinstance(event, StateSnapshotEvent)] + assert snapshot_events, "Expected a StateSnapshotEvent to be emitted" + + snapshot_event = snapshot_events[0] + assert snapshot_event.type == EventType.STATE_SNAPSHOT + assert snapshot_event.snapshot == state_snapshot + assert snapshot_event.snapshot["user_name"] == "Alice" + assert snapshot_event.snapshot["custom_state"]["view"]["active_tab"] == "details" + assert "extra_field" in snapshot_event.snapshot + + def test_create_state_snapshot_event_passthrough(self, translator): + """Direct helper should forward the snapshot unchanged.""" + + state_snapshot = { + "user_name": "Bob", + "custom_state": {"step": 3}, + "timezone": "PST", + } + + event = translator._create_state_snapshot_event(state_snapshot) + + assert isinstance(event, StateSnapshotEvent) + assert event.type == EventType.STATE_SNAPSHOT + assert event.snapshot == state_snapshot + assert set(event.snapshot.keys()) == {"user_name", "custom_state", "timezone"} + @pytest.mark.asyncio async def test_translate_custom_event(self, translator, mock_adk_event): """Test custom event creation.""" @@ -857,14 +907,15 @@ async def test_complex_event_with_multiple_features(self, translator, mock_adk_e async for event in translator.translate(mock_adk_event, "thread_1", "run_1"): events.append(event) - # Should have text events, state delta, and custom event - assert len(events) == 5 # START, CONTENT, STATE_DELTA, CUSTOM , END + # Should have text events, state delta, state snapshot, and custom event + assert len(events) == 6 # START, CONTENT, STATE_DELTA, STATE_SNAPSHOT, CUSTOM, END # Check event types event_types = [type(event) for event in events] assert TextMessageStartEvent in event_types assert TextMessageContentEvent in event_types assert StateDeltaEvent in event_types + assert StateSnapshotEvent in event_types assert CustomEvent in event_types assert TextMessageEndEvent in event_types