|
| 1 | +# mypy: ignore-errors |
| 2 | +import asyncio |
| 3 | +import logging |
| 4 | + |
| 5 | +from collections.abc import AsyncGenerator, AsyncIterable |
| 6 | +from typing import Any |
| 7 | +from uuid import uuid4 |
| 8 | + |
| 9 | +from google.adk import Runner |
| 10 | +from google.adk.agents import LlmAgent, RunConfig |
| 11 | +from google.adk.artifacts import InMemoryArtifactService |
| 12 | +from google.adk.events import Event |
| 13 | +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService |
| 14 | +from google.adk.sessions import InMemorySessionService |
| 15 | +from google.genai import types as genai_types |
| 16 | +from pydantic import ConfigDict |
| 17 | + |
| 18 | +from a2a.client import A2AClient |
| 19 | +from a2a.server.agent_execution import AgentExecutor, RequestContext |
| 20 | +from a2a.server.events.event_queue import EventQueue |
| 21 | +from a2a.server.tasks import TaskUpdater |
| 22 | +from a2a.types import ( |
| 23 | + Artifact, |
| 24 | + FilePart, |
| 25 | + FileWithBytes, |
| 26 | + FileWithUri, |
| 27 | + GetTaskRequest, |
| 28 | + GetTaskSuccessResponse, |
| 29 | + Message, |
| 30 | + MessageSendParams, |
| 31 | + Part, |
| 32 | + Role, |
| 33 | + SendMessageRequest, |
| 34 | + SendMessageSuccessResponse, |
| 35 | + Task, |
| 36 | + TaskQueryParams, |
| 37 | + TaskState, |
| 38 | + TaskStatus, |
| 39 | + TextPart, |
| 40 | + UnsupportedOperationError, |
| 41 | +) |
| 42 | +from a2a.utils import get_text_parts |
| 43 | +from a2a.utils.errors import ServerError |
| 44 | + |
| 45 | +from adk_agent import create_french_translation_agent |
| 46 | + |
| 47 | + |
| 48 | +logger = logging.getLogger(__name__) |
| 49 | +logger.setLevel(logging.DEBUG) |
| 50 | + |
| 51 | + |
| 52 | +class ADKFrenchTranslationAgentExecutor(AgentExecutor): |
| 53 | + """An AgentExecutor that runs an ADK-based French Translation Agent.""" |
| 54 | + |
| 55 | + def __init__(self): |
| 56 | + # Initialize the ADK agent and runner. |
| 57 | + self._agent = asyncio.run(create_french_translation_agent()) |
| 58 | + self.runner = Runner( |
| 59 | + app_name=self._agent.name, |
| 60 | + agent=self._agent, |
| 61 | + artifact_service=InMemoryArtifactService(), |
| 62 | + session_service=InMemorySessionService(), |
| 63 | + memory_service=InMemoryMemoryService(), |
| 64 | + ) |
| 65 | + |
| 66 | + def _run_agent( |
| 67 | + self, |
| 68 | + session_id: str, |
| 69 | + new_message: genai_types.Content, |
| 70 | + task_updater: TaskUpdater, # This parameter is not used in this method. |
| 71 | + ) -> AsyncGenerator[Event, None]: |
| 72 | + """Runs the ADK agent with the given message.""" |
| 73 | + return self.runner.run_async( |
| 74 | + session_id=session_id, |
| 75 | + user_id='self', # The user ID for the ADK session. |
| 76 | + new_message=new_message, |
| 77 | + ) |
| 78 | + |
| 79 | + async def _process_request( |
| 80 | + self, |
| 81 | + new_message: genai_types.Content, |
| 82 | + session_id: str, |
| 83 | + task_updater: TaskUpdater, |
| 84 | + ) -> AsyncIterable[TaskStatus | Artifact]: |
| 85 | + """Processes the incoming request by running the ADK agent.""" |
| 86 | + session = await self._upsert_session( |
| 87 | + session_id, |
| 88 | + ) |
| 89 | + session_id = session.id |
| 90 | + async for event in self._run_agent( |
| 91 | + session_id, new_message, task_updater # Pass task_updater to _run_agent |
| 92 | + ): |
| 93 | + logger.debug('Received ADK event: %s', event) |
| 94 | + if event.is_final_response(): |
| 95 | + # If the ADK agent provides a final response, convert it to A2A artifact and complete the task. |
| 96 | + response = convert_genai_parts_to_a2a(event.content.parts) |
| 97 | + logger.debug('Yielding final response: %s', response) |
| 98 | + task_updater.add_artifact(response) |
| 99 | + task_updater.complete() |
| 100 | + break |
| 101 | + elif not event.get_function_calls(): |
| 102 | + # If it's not a final response and no function calls, it's an interim update. |
| 103 | + logger.debug('Yielding update response') |
| 104 | + task_updater.update_status( |
| 105 | + TaskState.working, |
| 106 | + message=task_updater.new_agent_message( |
| 107 | + convert_genai_parts_to_a2a(event.content.parts) |
| 108 | + ), |
| 109 | + ) |
| 110 | + else: |
| 111 | + # This agent does not use tools, so function calls are unexpected. |
| 112 | + logger.debug('Skipping event with function call: %s', event.get_function_calls()) |
| 113 | + |
| 114 | + |
| 115 | + async def execute( |
| 116 | + self, |
| 117 | + context: RequestContext, |
| 118 | + event_queue: EventQueue, |
| 119 | + ): |
| 120 | + """Executes the agent's logic based on the incoming A2A request.""" |
| 121 | + updater = TaskUpdater(event_queue, context.task_id, context.context_id) |
| 122 | + if not context.current_task: |
| 123 | + updater.submit() |
| 124 | + updater.start_work() |
| 125 | + await self._process_request( |
| 126 | + genai_types.UserContent( |
| 127 | + parts=convert_a2a_parts_to_genai(context.message.parts), |
| 128 | + ), |
| 129 | + context.context_id, |
| 130 | + updater, |
| 131 | + ) |
| 132 | + |
| 133 | + async def cancel(self, context: RequestContext, event_queue: EventQueue): |
| 134 | + raise ServerError(error=UnsupportedOperationError()) |
| 135 | + |
| 136 | + async def _upsert_session(self, session_id: str): |
| 137 | + """Retrieves or creates an ADK session.""" |
| 138 | + return await self.runner.session_service.get_session( |
| 139 | + app_name=self.runner.app_name, user_id='self', session_id=session_id |
| 140 | + ) or await self.runner.session_service.create_session( |
| 141 | + app_name=self.runner.app_name, user_id='self', session_id=session_id |
| 142 | + ) |
| 143 | + |
| 144 | + |
| 145 | +def convert_a2a_parts_to_genai(parts: list[Part]) -> list[genai_types.Part]: |
| 146 | + """Converts a list of A2A Part objects to a list of Google GenAI Part objects.""" |
| 147 | + return [convert_a2a_part_to_genai(part) for part in parts] |
| 148 | + |
| 149 | + |
| 150 | +def convert_a2a_part_to_genai(part: Part) -> genai_types.Part: |
| 151 | + """Converts a single A2A Part object to a Google GenAI Part object.""" |
| 152 | + part = part.root |
| 153 | + if isinstance(part, TextPart): |
| 154 | + return genai_types.Part(text=part.text) |
| 155 | + if isinstance(part, FilePart): |
| 156 | + if isinstance(part.file, FileWithUri): |
| 157 | + return genai_types.Part( |
| 158 | + file_data=genai_types.FileData( |
| 159 | + file_uri=part.file.uri, mime_type=part.file.mime_type |
| 160 | + ) |
| 161 | + ) |
| 162 | + if isinstance(part.file, FileWithBytes): |
| 163 | + return genai_types.Part( |
| 164 | + inline_data=genai_types.Blob( |
| 165 | + data=part.file.bytes, mime_type=part.file.mime_type |
| 166 | + ) |
| 167 | + ) |
| 168 | + raise ValueError(f'Unsupported file type: {type(part.file)}') |
| 169 | + raise ValueError(f'Unsupported part type: {type(part)}') |
| 170 | + |
| 171 | + |
| 172 | +def convert_genai_parts_to_a2a(parts: list[genai_types.Part]) -> list[Part]: |
| 173 | + """Converts a list of Google GenAI Part objects to a list of A2A Part objects.""" |
| 174 | + return [ |
| 175 | + convert_genai_part_to_a2a(part) |
| 176 | + for part in parts |
| 177 | + if (part.text or part.file_data or part.inline_data) |
| 178 | + ] |
| 179 | + |
| 180 | + |
| 181 | +def convert_genai_part_to_a2a(part: genai_types.Part) -> Part: |
| 182 | + """Converts a single Google GenAI Part object to an A2A Part object.""" |
| 183 | + if part.text: |
| 184 | + return TextPart(text=part.text) |
| 185 | + if part.file_data: |
| 186 | + return FilePart( |
| 187 | + file=FileWithUri( |
| 188 | + uri=part.file_data.file_uri, |
| 189 | + mime_type=part.file_data.mime_type, |
| 190 | + ) |
| 191 | + ) |
| 192 | + if part.inline_data: |
| 193 | + return Part( |
| 194 | + root=FilePart( |
| 195 | + file=FileWithBytes( |
| 196 | + bytes=part.inline_data.data, |
| 197 | + mime_type=part.inline_data.mime_type, |
| 198 | + ) |
| 199 | + ) |
| 200 | + ) |
| 201 | + raise ValueError(f'Unsupported part type: {part}') |
0 commit comments