Skip to content

Commit 5d1055a

Browse files
fix(adk): improve tool response serialization (#428)
1 parent 89dff48 commit 5d1055a

File tree

2 files changed

+167
-7
lines changed

2 files changed

+167
-7
lines changed

typescript-sdk/integrations/adk-middleware/python/src/ag_ui_adk/event_translator.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
"""Event translator for converting ADK events to AG-UI protocol events."""
44

5+
import dataclasses
6+
from collections.abc import Iterable, Mapping
57
from typing import AsyncGenerator, Optional, Dict, Any , List
68
import uuid
79

@@ -21,6 +23,106 @@
2123
logger = logging.getLogger(__name__)
2224

2325

26+
def _coerce_tool_response(value: Any, _visited: Optional[set[int]] = None) -> Any:
27+
"""Recursively convert arbitrary tool responses into JSON-serializable structures."""
28+
29+
if isinstance(value, (str, int, float, bool)) or value is None:
30+
return value
31+
32+
if isinstance(value, (bytes, bytearray, memoryview)):
33+
try:
34+
return value.decode() # type: ignore[union-attr]
35+
except Exception:
36+
return list(value)
37+
38+
if _visited is None:
39+
_visited = set()
40+
41+
obj_id = id(value)
42+
if obj_id in _visited:
43+
return str(value)
44+
45+
_visited.add(obj_id)
46+
try:
47+
if dataclasses.is_dataclass(value) and not isinstance(value, type):
48+
return {
49+
field.name: _coerce_tool_response(getattr(value, field.name), _visited)
50+
for field in dataclasses.fields(value)
51+
}
52+
53+
if hasattr(value, "_asdict") and callable(getattr(value, "_asdict")):
54+
try:
55+
return {
56+
str(k): _coerce_tool_response(v, _visited)
57+
for k, v in value._asdict().items() # type: ignore[attr-defined]
58+
}
59+
except Exception:
60+
pass
61+
62+
for method_name in ("model_dump", "to_dict"):
63+
method = getattr(value, method_name, None)
64+
if callable(method):
65+
try:
66+
dumped = method()
67+
except TypeError:
68+
try:
69+
dumped = method(exclude_none=False)
70+
except Exception:
71+
continue
72+
except Exception:
73+
continue
74+
75+
return _coerce_tool_response(dumped, _visited)
76+
77+
if isinstance(value, Mapping):
78+
return {
79+
str(k): _coerce_tool_response(v, _visited)
80+
for k, v in value.items()
81+
}
82+
83+
if isinstance(value, (list, tuple, set, frozenset)):
84+
return [_coerce_tool_response(item, _visited) for item in value]
85+
86+
if isinstance(value, Iterable):
87+
try:
88+
return [_coerce_tool_response(item, _visited) for item in list(value)]
89+
except TypeError:
90+
pass
91+
92+
try:
93+
obj_vars = vars(value)
94+
except TypeError:
95+
obj_vars = None
96+
97+
if obj_vars:
98+
coerced = {
99+
key: _coerce_tool_response(val, _visited)
100+
for key, val in obj_vars.items()
101+
if not key.startswith("_")
102+
}
103+
if coerced:
104+
return coerced
105+
106+
return str(value)
107+
finally:
108+
_visited.discard(obj_id)
109+
110+
111+
def _serialize_tool_response(response: Any) -> str:
112+
"""Serialize a tool response into a JSON string."""
113+
114+
try:
115+
coerced = _coerce_tool_response(response)
116+
return json.dumps(coerced, ensure_ascii=False)
117+
except Exception as exc:
118+
logger.warning("Failed to coerce tool response to JSON: %s", exc, exc_info=True)
119+
try:
120+
return json.dumps(str(response), ensure_ascii=False)
121+
except Exception:
122+
logger.warning("Failed to stringify tool response; returning empty string.")
123+
return json.dumps("", ensure_ascii=False)
124+
125+
24126
class EventTranslator:
25127
"""Translates Google ADK events to AG-UI protocol events.
26128
@@ -377,7 +479,7 @@ async def _translate_function_response(
377479
message_id=str(uuid.uuid4()),
378480
type=EventType.TOOL_CALL_RESULT,
379481
tool_call_id=tool_call_id,
380-
content=json.dumps(func_response.response)
482+
content=_serialize_tool_response(func_response.response)
381483
)
382484
else:
383485
logger.debug(f"Skipping ToolCallResultEvent for long-running tool: {tool_call_id}")

typescript-sdk/integrations/adk-middleware/python/tests/test_event_translator_comprehensive.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
#!/usr/bin/env python
22
"""Comprehensive tests for EventTranslator, focusing on untested paths."""
33

4+
import json
5+
from dataclasses import asdict, dataclass
6+
from types import SimpleNamespace
7+
48
import pytest
59
import uuid
610
from unittest.mock import MagicMock, patch, AsyncMock
711

812
from ag_ui.core import (
913
EventType, TextMessageStartEvent, TextMessageContentEvent, TextMessageEndEvent,
10-
ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, StateDeltaEvent, CustomEvent
14+
ToolCallStartEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallResultEvent,
15+
StateDeltaEvent, CustomEvent
1116
)
1217
from google.adk.events import Event as ADKEvent
1318
from ag_ui_adk.event_translator import EventTranslator
@@ -109,15 +114,68 @@ async def test_translate_function_calls_detection(self, translator, mock_adk_eve
109114
async def test_translate_function_responses_handling(self, translator, mock_adk_event):
110115
"""Test function responses handling."""
111116
# Mock event with function responses
112-
mock_function_response = MagicMock()
113-
mock_adk_event.get_function_responses = MagicMock(return_value=[mock_function_response])
117+
function_response = SimpleNamespace(id="tool-1", response={"ok": True})
118+
mock_adk_event.get_function_calls = MagicMock(return_value=[])
119+
mock_adk_event.get_function_responses = MagicMock(return_value=[function_response])
114120

115121
events = []
116122
async for event in translator.translate(mock_adk_event, "thread_1", "run_1"):
117123
events.append(event)
118124

119-
# Function responses should be handled but not emit events
120-
assert len(events) == 0
125+
assert len(events) == 1
126+
event = events[0]
127+
assert isinstance(event, ToolCallResultEvent)
128+
assert json.loads(event.content) == {"ok": True}
129+
130+
@pytest.mark.asyncio
131+
async def test_translate_function_response_with_call_tool_result_payload(self, translator):
132+
"""Ensure complex CallToolResult payloads are serialized correctly."""
133+
134+
@dataclass
135+
class TextContent:
136+
type: str = "text"
137+
text: str = ""
138+
annotations: list | None = None
139+
meta: dict | None = None
140+
141+
@dataclass
142+
class CallToolResult:
143+
meta: dict | None
144+
structuredContent: dict | None
145+
isError: bool
146+
content: list[TextContent]
147+
148+
repeated_text_entries = [
149+
"Primary Task: Provide a detailed walkthrough for the requested topic.",
150+
"Primary Task: Provide a detailed walkthrough for the requested topic.",
151+
"Constraints: Ensure clarity and maintain a concise explanation.",
152+
"Constraints: Ensure clarity and maintain a concise explanation.",
153+
]
154+
155+
payload = CallToolResult(
156+
meta=None,
157+
structuredContent=None,
158+
isError=False,
159+
content=[TextContent(text=text) for text in repeated_text_entries],
160+
)
161+
162+
function_response = SimpleNamespace(
163+
id="tool-structured-1",
164+
response={"result": payload},
165+
)
166+
167+
events = []
168+
async for event in translator._translate_function_response([function_response]):
169+
events.append(event)
170+
171+
assert len(events) == 1
172+
event = events[0]
173+
assert isinstance(event, ToolCallResultEvent)
174+
175+
content = json.loads(event.content)
176+
assert content["result"]["isError"] is False
177+
assert content["result"]["structuredContent"] is None
178+
assert [item["text"] for item in content["result"]["content"]] == repeated_text_entries
121179

122180
@pytest.mark.asyncio
123181
async def test_translate_state_delta_event(self, translator, mock_adk_event):
@@ -781,4 +839,4 @@ async def test_partial_streaming_continuation(self, translator, mock_adk_event_w
781839

782840
# Should reset streaming state
783841
assert translator._is_streaming is False
784-
assert translator._streaming_message_id is None
842+
assert translator._streaming_message_id is None

0 commit comments

Comments
 (0)