Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

"""Event translator for converting ADK events to AG-UI protocol events."""

import dataclasses
from collections.abc import Iterable, Mapping
from typing import AsyncGenerator, Optional, Dict, Any , List
import uuid

Expand All @@ -21,6 +23,106 @@
logger = logging.getLogger(__name__)


def _coerce_tool_response(value: Any, _visited: Optional[set[int]] = None) -> Any:
"""Recursively convert arbitrary tool responses into JSON-serializable structures."""

if isinstance(value, (str, int, float, bool)) or value is None:
return value

if isinstance(value, (bytes, bytearray, memoryview)):
try:
return value.decode() # type: ignore[union-attr]
except Exception:
return list(value)

if _visited is None:
_visited = set()

obj_id = id(value)
if obj_id in _visited:
return str(value)

_visited.add(obj_id)
try:
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return {
field.name: _coerce_tool_response(getattr(value, field.name), _visited)
for field in dataclasses.fields(value)
}

if hasattr(value, "_asdict") and callable(getattr(value, "_asdict")):
try:
return {
str(k): _coerce_tool_response(v, _visited)
for k, v in value._asdict().items() # type: ignore[attr-defined]
}
except Exception:
pass

for method_name in ("model_dump", "to_dict"):
method = getattr(value, method_name, None)
if callable(method):
try:
dumped = method()
except TypeError:
try:
dumped = method(exclude_none=False)
except Exception:
continue
except Exception:
continue

return _coerce_tool_response(dumped, _visited)

if isinstance(value, Mapping):
return {
str(k): _coerce_tool_response(v, _visited)
for k, v in value.items()
}

if isinstance(value, (list, tuple, set, frozenset)):
return [_coerce_tool_response(item, _visited) for item in value]

if isinstance(value, Iterable):
try:
return [_coerce_tool_response(item, _visited) for item in list(value)]
except TypeError:
pass

try:
obj_vars = vars(value)
except TypeError:
obj_vars = None

if obj_vars:
coerced = {
key: _coerce_tool_response(val, _visited)
for key, val in obj_vars.items()
if not key.startswith("_")
}
if coerced:
return coerced

return str(value)
finally:
_visited.discard(obj_id)


def _serialize_tool_response(response: Any) -> str:
"""Serialize a tool response into a JSON string."""

try:
coerced = _coerce_tool_response(response)
return json.dumps(coerced, ensure_ascii=False)
except Exception as exc:
logger.warning("Failed to coerce tool response to JSON: %s", exc, exc_info=True)
try:
return json.dumps(str(response), ensure_ascii=False)
except Exception:
logger.warning("Failed to stringify tool response; returning empty string.")
return json.dumps("", ensure_ascii=False)


class EventTranslator:
"""Translates Google ADK events to AG-UI protocol events.

Expand Down Expand Up @@ -377,7 +479,7 @@ async def _translate_function_response(
message_id=str(uuid.uuid4()),
type=EventType.TOOL_CALL_RESULT,
tool_call_id=tool_call_id,
content=json.dumps(func_response.response)
content=_serialize_tool_response(func_response.response)
)
else:
logger.debug(f"Skipping ToolCallResultEvent for long-running tool: {tool_call_id}")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
#!/usr/bin/env python
"""Comprehensive tests for EventTranslator, focusing on untested paths."""

import json
from dataclasses import asdict, dataclass
from types import SimpleNamespace

import pytest
import uuid
from unittest.mock import MagicMock, patch, AsyncMock

from ag_ui.core import (
EventType, TextMessageStartEvent, TextMessageContentEvent, TextMessageEndEvent,
ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, StateDeltaEvent, CustomEvent
ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent,
StateDeltaEvent, CustomEvent
)
from google.adk.events import Event as ADKEvent
from ag_ui_adk.event_translator import EventTranslator
Expand Down Expand Up @@ -109,15 +114,68 @@ async def test_translate_function_calls_detection(self, translator, mock_adk_eve
async def test_translate_function_responses_handling(self, translator, mock_adk_event):
"""Test function responses handling."""
# Mock event with function responses
mock_function_response = MagicMock()
mock_adk_event.get_function_responses = MagicMock(return_value=[mock_function_response])
function_response = SimpleNamespace(id="tool-1", response={"ok": True})
mock_adk_event.get_function_calls = MagicMock(return_value=[])
mock_adk_event.get_function_responses = MagicMock(return_value=[function_response])

events = []
async for event in translator.translate(mock_adk_event, "thread_1", "run_1"):
events.append(event)

# Function responses should be handled but not emit events
assert len(events) == 0
assert len(events) == 1
event = events[0]
assert isinstance(event, ToolCallResultEvent)
assert json.loads(event.content) == {"ok": True}

@pytest.mark.asyncio
async def test_translate_function_response_with_call_tool_result_payload(self, translator):
"""Ensure complex CallToolResult payloads are serialized correctly."""

@dataclass
class TextContent:
type: str = "text"
text: str = ""
annotations: list | None = None
meta: dict | None = None

@dataclass
class CallToolResult:
meta: dict | None
structuredContent: dict | None
isError: bool
content: list[TextContent]

repeated_text_entries = [
"Primary Task: Provide a detailed walkthrough for the requested topic.",
"Primary Task: Provide a detailed walkthrough for the requested topic.",
"Constraints: Ensure clarity and maintain a concise explanation.",
"Constraints: Ensure clarity and maintain a concise explanation.",
]

payload = CallToolResult(
meta=None,
structuredContent=None,
isError=False,
content=[TextContent(text=text) for text in repeated_text_entries],
)

function_response = SimpleNamespace(
id="tool-structured-1",
response={"result": payload},
)

events = []
async for event in translator._translate_function_response([function_response]):
events.append(event)

assert len(events) == 1
event = events[0]
assert isinstance(event, ToolCallResultEvent)

content = json.loads(event.content)
assert content["result"]["isError"] is False
assert content["result"]["structuredContent"] is None
assert [item["text"] for item in content["result"]["content"]] == repeated_text_entries

@pytest.mark.asyncio
async def test_translate_state_delta_event(self, translator, mock_adk_event):
Expand Down Expand Up @@ -781,4 +839,4 @@ async def test_partial_streaming_continuation(self, translator, mock_adk_event_w

# Should reset streaming state
assert translator._is_streaming is False
assert translator._streaming_message_id is None
assert translator._streaming_message_id is None
Loading