|
2 | 2 |
|
3 | 3 | """Event translator for converting ADK events to AG-UI protocol events."""
|
4 | 4 |
|
| 5 | +import dataclasses |
| 6 | +from collections.abc import Iterable, Mapping |
5 | 7 | from typing import AsyncGenerator, Optional, Dict, Any , List
|
6 | 8 | import uuid
|
7 | 9 |
|
|
21 | 23 | logger = logging.getLogger(__name__)
|
22 | 24 |
|
23 | 25 |
|
| 26 | +def _coerce_tool_response(value: Any, _visited: Optional[set[int]] = None) -> Any: |
| 27 | + """Recursively convert arbitrary tool responses into JSON-serializable structures.""" |
| 28 | + |
| 29 | + if isinstance(value, (str, int, float, bool)) or value is None: |
| 30 | + return value |
| 31 | + |
| 32 | + if isinstance(value, (bytes, bytearray, memoryview)): |
| 33 | + try: |
| 34 | + return value.decode() # type: ignore[union-attr] |
| 35 | + except Exception: |
| 36 | + return list(value) |
| 37 | + |
| 38 | + if _visited is None: |
| 39 | + _visited = set() |
| 40 | + |
| 41 | + obj_id = id(value) |
| 42 | + if obj_id in _visited: |
| 43 | + return str(value) |
| 44 | + |
| 45 | + _visited.add(obj_id) |
| 46 | + try: |
| 47 | + if dataclasses.is_dataclass(value) and not isinstance(value, type): |
| 48 | + return { |
| 49 | + field.name: _coerce_tool_response(getattr(value, field.name), _visited) |
| 50 | + for field in dataclasses.fields(value) |
| 51 | + } |
| 52 | + |
| 53 | + if hasattr(value, "_asdict") and callable(getattr(value, "_asdict")): |
| 54 | + try: |
| 55 | + return { |
| 56 | + str(k): _coerce_tool_response(v, _visited) |
| 57 | + for k, v in value._asdict().items() # type: ignore[attr-defined] |
| 58 | + } |
| 59 | + except Exception: |
| 60 | + pass |
| 61 | + |
| 62 | + for method_name in ("model_dump", "to_dict"): |
| 63 | + method = getattr(value, method_name, None) |
| 64 | + if callable(method): |
| 65 | + try: |
| 66 | + dumped = method() |
| 67 | + except TypeError: |
| 68 | + try: |
| 69 | + dumped = method(exclude_none=False) |
| 70 | + except Exception: |
| 71 | + continue |
| 72 | + except Exception: |
| 73 | + continue |
| 74 | + |
| 75 | + return _coerce_tool_response(dumped, _visited) |
| 76 | + |
| 77 | + if isinstance(value, Mapping): |
| 78 | + return { |
| 79 | + str(k): _coerce_tool_response(v, _visited) |
| 80 | + for k, v in value.items() |
| 81 | + } |
| 82 | + |
| 83 | + if isinstance(value, (list, tuple, set, frozenset)): |
| 84 | + return [_coerce_tool_response(item, _visited) for item in value] |
| 85 | + |
| 86 | + if isinstance(value, Iterable): |
| 87 | + try: |
| 88 | + return [_coerce_tool_response(item, _visited) for item in list(value)] |
| 89 | + except TypeError: |
| 90 | + pass |
| 91 | + |
| 92 | + try: |
| 93 | + obj_vars = vars(value) |
| 94 | + except TypeError: |
| 95 | + obj_vars = None |
| 96 | + |
| 97 | + if obj_vars: |
| 98 | + coerced = { |
| 99 | + key: _coerce_tool_response(val, _visited) |
| 100 | + for key, val in obj_vars.items() |
| 101 | + if not key.startswith("_") |
| 102 | + } |
| 103 | + if coerced: |
| 104 | + return coerced |
| 105 | + |
| 106 | + return str(value) |
| 107 | + finally: |
| 108 | + _visited.discard(obj_id) |
| 109 | + |
| 110 | + |
| 111 | +def _serialize_tool_response(response: Any) -> str: |
| 112 | + """Serialize a tool response into a JSON string.""" |
| 113 | + |
| 114 | + try: |
| 115 | + coerced = _coerce_tool_response(response) |
| 116 | + return json.dumps(coerced, ensure_ascii=False) |
| 117 | + except Exception as exc: |
| 118 | + logger.warning("Failed to coerce tool response to JSON: %s", exc, exc_info=True) |
| 119 | + try: |
| 120 | + return json.dumps(str(response), ensure_ascii=False) |
| 121 | + except Exception: |
| 122 | + logger.warning("Failed to stringify tool response; returning empty string.") |
| 123 | + return json.dumps("", ensure_ascii=False) |
| 124 | + |
| 125 | + |
24 | 126 | class EventTranslator:
|
25 | 127 | """Translates Google ADK events to AG-UI protocol events.
|
26 | 128 |
|
@@ -377,7 +479,7 @@ async def _translate_function_response(
|
377 | 479 | message_id=str(uuid.uuid4()),
|
378 | 480 | type=EventType.TOOL_CALL_RESULT,
|
379 | 481 | tool_call_id=tool_call_id,
|
380 |
| - content=json.dumps(func_response.response) |
| 482 | + content=_serialize_tool_response(func_response.response) |
381 | 483 | )
|
382 | 484 | else:
|
383 | 485 | logger.debug(f"Skipping ToolCallResultEvent for long-running tool: {tool_call_id}")
|
|
0 commit comments