diff --git a/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py b/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py index efb674e17..d08e13ec6 100644 --- a/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py +++ b/typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py @@ -2,6 +2,8 @@ """Event translator for converting ADK events to AG-UI protocol events.""" +import dataclasses +from collections.abc import Iterable, Mapping from typing import AsyncGenerator, Optional, Dict, Any , List import uuid @@ -21,6 +23,106 @@ logger = logging.getLogger(__name__) +def _coerce_tool_response(value: Any, _visited: Optional[set[int]] = None) -> Any: + """Recursively convert arbitrary tool responses into JSON-serializable structures.""" + + if isinstance(value, (str, int, float, bool)) or value is None: + return value + + if isinstance(value, (bytes, bytearray, memoryview)): + try: + return value.decode() # type: ignore[union-attr] + except Exception: + return list(value) + + if _visited is None: + _visited = set() + + obj_id = id(value) + if obj_id in _visited: + return str(value) + + _visited.add(obj_id) + try: + if dataclasses.is_dataclass(value) and not isinstance(value, type): + return { + field.name: _coerce_tool_response(getattr(value, field.name), _visited) + for field in dataclasses.fields(value) + } + + if hasattr(value, "_asdict") and callable(getattr(value, "_asdict")): + try: + return { + str(k): _coerce_tool_response(v, _visited) + for k, v in value._asdict().items() # type: ignore[attr-defined] + } + except Exception: + pass + + for method_name in ("model_dump", "to_dict"): + method = getattr(value, method_name, None) + if callable(method): + try: + dumped = method() + except TypeError: + try: + dumped = method(exclude_none=False) + except Exception: + continue + except Exception: + continue + + return _coerce_tool_response(dumped, _visited) + + if isinstance(value, Mapping): + return { + str(k): _coerce_tool_response(v, _visited) + for k, v in value.items() + } + + if isinstance(value, (list, tuple, set, frozenset)): + return [_coerce_tool_response(item, _visited) for item in value] + + if isinstance(value, Iterable): + try: + return [_coerce_tool_response(item, _visited) for item in list(value)] + except TypeError: + pass + + try: + obj_vars = vars(value) + except TypeError: + obj_vars = None + + if obj_vars: + coerced = { + key: _coerce_tool_response(val, _visited) + for key, val in obj_vars.items() + if not key.startswith("_") + } + if coerced: + return coerced + + return str(value) + finally: + _visited.discard(obj_id) + + +def _serialize_tool_response(response: Any) -> str: + """Serialize a tool response into a JSON string.""" + + try: + coerced = _coerce_tool_response(response) + return json.dumps(coerced, ensure_ascii=False) + except Exception as exc: + logger.warning("Failed to coerce tool response to JSON: %s", exc, exc_info=True) + try: + return json.dumps(str(response), ensure_ascii=False) + except Exception: + logger.warning("Failed to stringify tool response; returning empty string.") + return json.dumps("", ensure_ascii=False) + + class EventTranslator: """Translates Google ADK events to AG-UI protocol events. @@ -377,7 +479,7 @@ async def _translate_function_response( message_id=str(uuid.uuid4()), type=EventType.TOOL_CALL_RESULT, tool_call_id=tool_call_id, - content=json.dumps(func_response.response) + content=_serialize_tool_response(func_response.response) ) else: logger.debug(f"Skipping ToolCallResultEvent for long-running tool: {tool_call_id}") diff --git a/typescript-sdk/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py b/typescript-sdk/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py index 8475cc8cb..6444c0781 100644 --- a/typescript-sdk/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py +++ b/typescript-sdk/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py @@ -1,13 +1,18 @@ #!/usr/bin/env python """Comprehensive tests for EventTranslator, focusing on untested paths.""" +import json +from dataclasses import asdict, dataclass +from types import SimpleNamespace + import pytest import uuid from unittest.mock import MagicMock, patch, AsyncMock from ag_ui.core import ( EventType, TextMessageStartEvent, TextMessageContentEvent, TextMessageEndEvent, - ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, StateDeltaEvent, CustomEvent + ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent, + StateDeltaEvent, CustomEvent ) from google.adk.events import Event as ADKEvent from ag_ui_adk.event_translator import EventTranslator @@ -109,15 +114,68 @@ async def test_translate_function_calls_detection(self, translator, mock_adk_eve async def test_translate_function_responses_handling(self, translator, mock_adk_event): """Test function responses handling.""" # Mock event with function responses - mock_function_response = MagicMock() - mock_adk_event.get_function_responses = MagicMock(return_value=[mock_function_response]) + function_response = SimpleNamespace(id="tool-1", response={"ok": True}) + mock_adk_event.get_function_calls = MagicMock(return_value=[]) + mock_adk_event.get_function_responses = MagicMock(return_value=[function_response]) events = [] async for event in translator.translate(mock_adk_event, "thread_1", "run_1"): events.append(event) - # Function responses should be handled but not emit events - assert len(events) == 0 + assert len(events) == 1 + event = events[0] + assert isinstance(event, ToolCallResultEvent) + assert json.loads(event.content) == {"ok": True} + + @pytest.mark.asyncio + async def test_translate_function_response_with_call_tool_result_payload(self, translator): + """Ensure complex CallToolResult payloads are serialized correctly.""" + + @dataclass + class TextContent: + type: str = "text" + text: str = "" + annotations: list | None = None + meta: dict | None = None + + @dataclass + class CallToolResult: + meta: dict | None + structuredContent: dict | None + isError: bool + content: list[TextContent] + + repeated_text_entries = [ + "Primary Task: Provide a detailed walkthrough for the requested topic.", + "Primary Task: Provide a detailed walkthrough for the requested topic.", + "Constraints: Ensure clarity and maintain a concise explanation.", + "Constraints: Ensure clarity and maintain a concise explanation.", + ] + + payload = CallToolResult( + meta=None, + structuredContent=None, + isError=False, + content=[TextContent(text=text) for text in repeated_text_entries], + ) + + function_response = SimpleNamespace( + id="tool-structured-1", + response={"result": payload}, + ) + + events = [] + async for event in translator._translate_function_response([function_response]): + events.append(event) + + assert len(events) == 1 + event = events[0] + assert isinstance(event, ToolCallResultEvent) + + content = json.loads(event.content) + assert content["result"]["isError"] is False + assert content["result"]["structuredContent"] is None + assert [item["text"] for item in content["result"]["content"]] == repeated_text_entries @pytest.mark.asyncio async def test_translate_state_delta_event(self, translator, mock_adk_event): @@ -781,4 +839,4 @@ async def test_partial_streaming_continuation(self, translator, mock_adk_event_w # Should reset streaming state assert translator._is_streaming is False - assert translator._streaming_message_id is None \ No newline at end of file + assert translator._streaming_message_id is None