|
| 1 | +# Copyright 2024 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import asyncio |
| 16 | +import httpx |
| 17 | +import uuid # For generating unique IDs in the test block |
| 18 | + |
| 19 | +# Core imports from the a2a framework |
| 20 | +from a2a.client.client import A2AClient, A2AClientTaskInfo |
| 21 | +from a2a.server.agent_execution.agent_executor import AgentExecutor |
| 22 | +from a2a.types import Message, Part, Role, TextPart # Core types |
| 23 | +from a2a.utils.message import new_agent_text_message, get_message_text, new_user_text_message # Message utilities |
| 24 | + |
| 25 | + |
| 26 | +class HostAgent(AgentExecutor): |
| 27 | + """ |
| 28 | + An agent that orchestrates calls to PlanAgent, SearchAgent, and ReportAgent |
| 29 | + to process a user's task. |
| 30 | + """ |
| 31 | + |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + plan_agent_url: str, |
| 35 | + search_agent_url: str, |
| 36 | + report_agent_url: str, |
| 37 | + name: str = "HostAgent", |
| 38 | + ): |
| 39 | + super().__init__(name=name) |
| 40 | + self.plan_agent_url = plan_agent_url |
| 41 | + self.search_agent_url = search_agent_url |
| 42 | + self.report_agent_url = report_agent_url |
| 43 | + # A2AClients will be initialized within execute, along with httpx.AsyncClient |
| 44 | + |
| 45 | + async def _call_sub_agent( |
| 46 | + self, |
| 47 | + client: A2AClient, |
| 48 | + agent_name: str, # For logging/error messages |
| 49 | + input_text: str, |
| 50 | + original_message: Message, # To carry over contextId, taskId |
| 51 | + ) -> str: |
| 52 | + """Helper to call a sub-agent and extract its text response.""" |
| 53 | + # Create a message to send to the sub-agent. |
| 54 | + # It's a "user" message from the perspective of the sub-agent. |
| 55 | + # However, the A2AClient might wrap this in a Task structure. |
| 56 | + # The A2AClient's execute_agent_task expects a list of Message objects as input. |
| 57 | + sub_agent_input_message = new_user_text_message( # HostAgent acts as a "user" to sub-agents |
| 58 | + text=input_text, |
| 59 | + context_id=original_message.contextId, # Propagate context |
| 60 | + task_id=original_message.taskId, # Propagate task |
| 61 | + ) |
| 62 | + |
| 63 | + try: |
| 64 | + # The A2AClient.execute_agent_task expects a list of Messages |
| 65 | + # and returns an A2AClientTaskInfo object. |
| 66 | + task_info: A2AClientTaskInfo = await client.execute_agent_task( |
| 67 | + messages=[sub_agent_input_message] |
| 68 | + ) |
| 69 | + |
| 70 | + # The final message from the sub-agent is often in task_info.result.messages |
| 71 | + if task_info.result and task_info.result.messages: |
| 72 | + # Assuming the last message is the agent's response |
| 73 | + agent_response_message = task_info.result.messages[-1] |
| 74 | + if agent_response_message.role == Role.AGENT: |
| 75 | + return get_message_text(agent_response_message) |
| 76 | + else: |
| 77 | + return f"Error: {agent_name} did not respond with an AGENT message." |
| 78 | + else: |
| 79 | + return f"Error: No response messages from {agent_name}." |
| 80 | + |
| 81 | + except Exception as e: |
| 82 | + # Log the exception or handle it more gracefully |
| 83 | + print(f"Error calling {agent_name} at {client._server_url}: {e}") |
| 84 | + return f"Error: Could not get response from {agent_name} due to {type(e).__name__}." |
| 85 | + |
| 86 | + |
| 87 | + async def execute(self, message: Message) -> Message: |
| 88 | + """ |
| 89 | + Orchestrates the sub-agents to process the task. |
| 90 | + """ |
| 91 | + task_description = get_message_text(message) |
| 92 | + if not task_description: |
| 93 | + return new_agent_text_message( |
| 94 | + text="Error: HostAgent received a message with no task description.", |
| 95 | + context_id=message.contextId, |
| 96 | + task_id=message.taskId, |
| 97 | + ) |
| 98 | + |
| 99 | + final_report = "Error: Orchestration failed." # Default error message |
| 100 | + |
| 101 | + async with httpx.AsyncClient() as http_client: |
| 102 | + plan_agent_client = A2AClient(server_url=self.plan_agent_url, http_client=http_client) |
| 103 | + search_agent_client = A2AClient(server_url=self.search_agent_url, http_client=http_client) |
| 104 | + report_agent_client = A2AClient(server_url=self.report_agent_url, http_client=http_client) |
| 105 | + |
| 106 | + # 1. Call PlanAgent |
| 107 | + plan = await self._call_sub_agent( |
| 108 | + plan_agent_client, "PlanAgent", task_description, message |
| 109 | + ) |
| 110 | + if plan.startswith("Error:"): |
| 111 | + return new_agent_text_message(text=plan, context_id=message.contextId, task_id=message.taskId) |
| 112 | + |
| 113 | + # 2. Call SearchAgent |
| 114 | + # For simplicity, using the original task description as the search query. |
| 115 | + # A more advanced version might parse the plan to create specific queries. |
| 116 | + search_query = task_description |
| 117 | + search_results = await self._call_sub_agent( |
| 118 | + search_agent_client, "SearchAgent", search_query, message |
| 119 | + ) |
| 120 | + if search_results.startswith("Error:"): |
| 121 | + # Proceed with reporting what we have, or return error |
| 122 | + combined_input_for_report = f"Plan:\n{plan}\n\nSearch Results: Failed - {search_results}" |
| 123 | + else: |
| 124 | + combined_input_for_report = f"Plan:\n{plan}\n\nSearch Results:\n{search_results}" |
| 125 | + |
| 126 | + # 3. Call ReportAgent |
| 127 | + final_report = await self._call_sub_agent( |
| 128 | + report_agent_client, "ReportAgent", combined_input_for_report, message |
| 129 | + ) |
| 130 | + # If final_report itself is an error string from _call_sub_agent, it will be returned. |
| 131 | + |
| 132 | + # Return the final report from ReportAgent |
| 133 | + return new_agent_text_message( |
| 134 | + text=final_report, |
| 135 | + context_id=message.contextId, |
| 136 | + task_id=message.taskId, |
| 137 | + ) |
| 138 | + |
| 139 | + async def cancel(self, interaction_id: str) -> None: |
| 140 | + """ |
| 141 | + Cancels an ongoing task. |
| 142 | + For HostAgent, this would ideally involve propagating cancellations to sub-agents. |
| 143 | + """ |
| 144 | + print(f"Cancellation requested for interaction/context/task '{interaction_id}' in {self.name}.") |
| 145 | + # TODO: Implement cancellation propagation to sub-agents if their A2AClient interface supports it. |
| 146 | + # For now, this is a placeholder. |
| 147 | + raise NotImplementedError( |
| 148 | + "HostAgent cancellation requires propagation to sub-agents, which is not yet implemented." |
| 149 | + ) |
| 150 | + |
| 151 | + |
| 152 | +if __name__ == "__main__": |
| 153 | + # This example is more complex to run directly as it involves HTTP calls |
| 154 | + # to other agents. For a simple test, we would mock A2AClient. |
| 155 | + |
| 156 | + # --- Mocking section --- |
| 157 | + class MockA2AClient: |
| 158 | + def __init__(self, server_url: str, http_client=None): |
| 159 | + self._server_url = server_url |
| 160 | + self.http_client = http_client # Keep httpx.AsyncClient for realism if used by HostAgent |
| 161 | + |
| 162 | + async def execute_agent_task(self, messages: list[Message]) -> A2AClientTaskInfo: |
| 163 | + input_text = get_message_text(messages[0]) |
| 164 | + # Simulate responses based on the agent URL or input |
| 165 | + response_text = "" |
| 166 | + if "plan" in self._server_url: |
| 167 | + response_text = f"Plan for '{input_text}': Step 1, Step 2." |
| 168 | + elif "search" in self._server_url: |
| 169 | + response_text = f"Search results for '{input_text}': Result A, Result B." |
| 170 | + elif "report" in self._server_url: |
| 171 | + response_text = f"Report based on: {input_text}" |
| 172 | + |
| 173 | + # Simulate A2AClientTaskInfo structure |
| 174 | + response_message = new_agent_text_message( |
| 175 | + text=response_text, |
| 176 | + context_id=messages[0].contextId, |
| 177 | + task_id=messages[0].taskId |
| 178 | + ) |
| 179 | + # Simplified TaskResult and A2AClientTaskInfo |
| 180 | + class MockTaskResult: |
| 181 | + def __init__(self, messages): |
| 182 | + self.messages = messages |
| 183 | + class MockA2AClientTaskInfo(A2AClientTaskInfo): |
| 184 | + def __init__(self, messages): |
| 185 | + super().__init__(task_id="", status="", messages=messages, result=MockTaskResult(messages=messages)) |
| 186 | + |
| 187 | + return MockA2AClientTaskInfo(messages=[response_message]) |
| 188 | + |
| 189 | + # Store original and apply mock |
| 190 | + original_a2a_client = A2AClient |
| 191 | + A2AClient = MockA2AClient # type: ignore |
| 192 | + |
| 193 | + # Mock AgentExecutor for HostAgent itself |
| 194 | + class MockAgentExecutor: |
| 195 | + def __init__(self, name: str): |
| 196 | + self.name = name |
| 197 | + original_agent_executor = AgentExecutor |
| 198 | + AgentExecutor = MockAgentExecutor # type: ignore |
| 199 | + # --- End Mocking section --- |
| 200 | + |
| 201 | + async def main_test(): |
| 202 | + # Dummy URLs for the mocked clients |
| 203 | + plan_url = "http://mockplanagent.test" |
| 204 | + search_url = "http://mocksearchagent.test" |
| 205 | + report_url = "http://mockreportagent.test" |
| 206 | + |
| 207 | + host_agent = HostAgent( |
| 208 | + plan_agent_url=plan_url, |
| 209 | + search_agent_url=search_url, |
| 210 | + report_agent_url=report_url, |
| 211 | + ) |
| 212 | + |
| 213 | + user_task = "Research benefits of async programming and report them." |
| 214 | + test_message = new_user_text_message( |
| 215 | + text=user_task, |
| 216 | + context_id=str(uuid.uuid4()), |
| 217 | + task_id=str(uuid.uuid4()) |
| 218 | + ) |
| 219 | + |
| 220 | + print(f"HostAgent processing task: '{user_task}'") |
| 221 | + final_response = await host_agent.execute(test_message) |
| 222 | + |
| 223 | + print("\nHostAgent Final Response:") |
| 224 | + print(get_message_text(final_response)) |
| 225 | + |
| 226 | + # Test cancellation (will raise NotImplementedError as per implementation) |
| 227 | + try: |
| 228 | + print("\nTesting HostAgent cancellation...") |
| 229 | + await host_agent.cancel(test_message.contextId) |
| 230 | + except NotImplementedError as e: |
| 231 | + print(f"Cancellation test: Caught expected error - {e}") |
| 232 | + |
| 233 | + try: |
| 234 | + asyncio.run(main_test()) |
| 235 | + finally: |
| 236 | + # Restore original classes |
| 237 | + A2AClient = original_a2a_client # type: ignore |
| 238 | + AgentExecutor = original_agent_executor # type: ignore |
| 239 | + print("\nRestored A2AClient and AgentExecutor. HostAgent example finished.") |
0 commit comments