|
| 1 | +import json |
| 2 | +import logging |
| 3 | +import traceback |
| 4 | +from abc import ABC, abstractmethod |
| 5 | + |
| 6 | +from collections.abc import AsyncGenerator |
| 7 | +from typing import Any, Optional, Union |
| 8 | + |
| 9 | +from pydantic import ValidationError |
| 10 | +from sse_starlette.sse import EventSourceResponse |
| 11 | +from starlette.applications import Starlette |
| 12 | +from fastapi import FastAPI |
| 13 | +from starlette.requests import Request |
| 14 | +from starlette.responses import JSONResponse, Response |
| 15 | + |
| 16 | +from a2a.server.request_handlers.jsonrpc_handler import ( |
| 17 | + JSONRPCHandler, |
| 18 | + RequestHandler, |
| 19 | +) |
| 20 | +from a2a.types import ( |
| 21 | + A2AError, |
| 22 | + A2ARequest, |
| 23 | + AgentCard, |
| 24 | + CancelTaskRequest, |
| 25 | + GetTaskPushNotificationConfigRequest, |
| 26 | + GetTaskRequest, |
| 27 | + InternalError, |
| 28 | + InvalidRequestError, |
| 29 | + JSONParseError, |
| 30 | + JSONRPCError, |
| 31 | + JSONRPCErrorResponse, |
| 32 | + JSONRPCResponse, |
| 33 | + SendMessageRequest, |
| 34 | + SendStreamingMessageRequest, |
| 35 | + SendStreamingMessageResponse, |
| 36 | + SetTaskPushNotificationConfigRequest, |
| 37 | + TaskResubscriptionRequest, |
| 38 | + UnsupportedOperationError, |
| 39 | +) |
| 40 | +from a2a.utils.errors import MethodNotImplementedError |
| 41 | + |
| 42 | +logger = logging.getLogger(__name__) |
| 43 | + |
| 44 | + |
| 45 | +class DefaultA2AApplication(ABC): |
| 46 | + """Base class for A2A applications. |
| 47 | +
|
| 48 | + Args: |
| 49 | + agent_card: The AgentCard describing the agent's capabilities. |
| 50 | + http_handler: The handler instance responsible for processing A2A |
| 51 | + requests via http. |
| 52 | + """ |
| 53 | + |
| 54 | + def __init__(self, agent_card: AgentCard, http_handler: RequestHandler): |
| 55 | + """Initializes the A2AApplication. |
| 56 | +
|
| 57 | + Args: |
| 58 | + agent_card: The AgentCard describing the agent's capabilities. |
| 59 | + http_handler: The handler instance responsible for processing A2A |
| 60 | + requests via http. |
| 61 | + """ |
| 62 | + self.agent_card = agent_card |
| 63 | + self.handler = JSONRPCHandler( |
| 64 | + agent_card=agent_card, request_handler=http_handler |
| 65 | + ) |
| 66 | + |
| 67 | + def _generate_error_response( |
| 68 | + self, request_id: str | int | None, error: JSONRPCError | A2AError |
| 69 | + ) -> JSONResponse: |
| 70 | + """Creates a JSONResponse for a JSON-RPC error.""" |
| 71 | + error_resp = JSONRPCErrorResponse( |
| 72 | + id=request_id, |
| 73 | + error=error if isinstance(error, JSONRPCError) else error.root, |
| 74 | + ) |
| 75 | + |
| 76 | + log_level = ( |
| 77 | + logging.ERROR |
| 78 | + if not isinstance(error, A2AError) |
| 79 | + or isinstance(error.root, InternalError) |
| 80 | + else logging.WARNING |
| 81 | + ) |
| 82 | + logger.log( |
| 83 | + log_level, |
| 84 | + f'Request Error (ID: {request_id}: ' |
| 85 | + f"Code={error_resp.error.code}, Message='{error_resp.error.message}'" |
| 86 | + f'{", Data=" + str(error_resp.error.data) if hasattr(error, "data") and error_resp.error.data else ""}', |
| 87 | + ) |
| 88 | + return JSONResponse( |
| 89 | + error_resp.model_dump(mode='json', exclude_none=True), |
| 90 | + status_code=200, |
| 91 | + ) |
| 92 | + |
| 93 | + async def _handle_requests(self, request: Request) -> Response: |
| 94 | + """Handles incoming POST requests to the main A2A endpoint. |
| 95 | +
|
| 96 | + Parses the request body as JSON, validates it against A2A request types, |
| 97 | + dispatches it to the appropriate handler method, and returns the response. |
| 98 | + Handles JSON parsing errors, validation errors, and other exceptions, |
| 99 | + returning appropriate JSON-RPC error responses. |
| 100 | + """ |
| 101 | + request_id = None |
| 102 | + body = None |
| 103 | + |
| 104 | + try: |
| 105 | + body = await request.json() |
| 106 | + a2a_request = A2ARequest.model_validate(body) |
| 107 | + |
| 108 | + request_id = a2a_request.root.id |
| 109 | + request_obj = a2a_request.root |
| 110 | + |
| 111 | + if isinstance( |
| 112 | + request_obj, |
| 113 | + TaskResubscriptionRequest | SendStreamingMessageRequest, |
| 114 | + ): |
| 115 | + return await self._process_streaming_request( |
| 116 | + request_id, a2a_request |
| 117 | + ) |
| 118 | + |
| 119 | + return await self._process_non_streaming_request( |
| 120 | + request_id, a2a_request |
| 121 | + ) |
| 122 | + except MethodNotImplementedError: |
| 123 | + traceback.print_exc() |
| 124 | + return self._generate_error_response( |
| 125 | + request_id, A2AError(root=UnsupportedOperationError()) |
| 126 | + ) |
| 127 | + except json.decoder.JSONDecodeError as e: |
| 128 | + traceback.print_exc() |
| 129 | + return self._generate_error_response( |
| 130 | + None, A2AError(root=JSONParseError(message=str(e))) |
| 131 | + ) |
| 132 | + except ValidationError as e: |
| 133 | + traceback.print_exc() |
| 134 | + return self._generate_error_response( |
| 135 | + request_id, |
| 136 | + A2AError(root=InvalidRequestError(data=json.loads(e.json()))), |
| 137 | + ) |
| 138 | + except Exception as e: |
| 139 | + logger.error(f'Unhandled exception: {e}') |
| 140 | + traceback.print_exc() |
| 141 | + return self._generate_error_response( |
| 142 | + request_id, A2AError(root=InternalError(message=str(e))) |
| 143 | + ) |
| 144 | + |
| 145 | + async def _process_streaming_request( |
| 146 | + self, request_id: str | int | None, a2a_request: A2ARequest |
| 147 | + ) -> Response: |
| 148 | + """Processes streaming requests. |
| 149 | +
|
| 150 | + Args: |
| 151 | + request_id: The ID of the request. |
| 152 | + a2a_request: The validated A2ARequest object. |
| 153 | + """ |
| 154 | + request_obj = a2a_request.root |
| 155 | + handler_result: Any = None |
| 156 | + if isinstance( |
| 157 | + request_obj, |
| 158 | + SendStreamingMessageRequest, |
| 159 | + ): |
| 160 | + handler_result = self.handler.on_message_send_stream(request_obj) |
| 161 | + elif isinstance(request_obj, TaskResubscriptionRequest): |
| 162 | + handler_result = self.handler.on_resubscribe_to_task(request_obj) |
| 163 | + |
| 164 | + return self._create_response(handler_result) |
| 165 | + |
| 166 | + async def _process_non_streaming_request( |
| 167 | + self, request_id: str | int | None, a2a_request: A2ARequest |
| 168 | + ) -> Response: |
| 169 | + """Processes non-streaming requests. |
| 170 | +
|
| 171 | + Args: |
| 172 | + request_id: The ID of the request. |
| 173 | + a2a_request: The validated A2ARequest object. |
| 174 | + """ |
| 175 | + request_obj = a2a_request.root |
| 176 | + handler_result: Any = None |
| 177 | + match request_obj: |
| 178 | + case SendMessageRequest(): |
| 179 | + handler_result = await self.handler.on_message_send(request_obj) |
| 180 | + case CancelTaskRequest(): |
| 181 | + handler_result = await self.handler.on_cancel_task(request_obj) |
| 182 | + case GetTaskRequest(): |
| 183 | + handler_result = await self.handler.on_get_task(request_obj) |
| 184 | + case SetTaskPushNotificationConfigRequest(): |
| 185 | + handler_result = await self.handler.set_push_notification( |
| 186 | + request_obj |
| 187 | + ) |
| 188 | + case GetTaskPushNotificationConfigRequest(): |
| 189 | + handler_result = await self.handler.get_push_notification( |
| 190 | + request_obj |
| 191 | + ) |
| 192 | + case _: |
| 193 | + logger.error( |
| 194 | + f'Unhandled validated request type: {type(request_obj)}' |
| 195 | + ) |
| 196 | + error = UnsupportedOperationError( |
| 197 | + message=f'Request type {type(request_obj).__name__} is unknown.' |
| 198 | + ) |
| 199 | + handler_result = JSONRPCErrorResponse( |
| 200 | + id=request_id, error=error |
| 201 | + ) |
| 202 | + |
| 203 | + return self._create_response(handler_result) |
| 204 | + |
| 205 | + def _create_response( |
| 206 | + self, |
| 207 | + handler_result: ( |
| 208 | + AsyncGenerator[SendStreamingMessageResponse, None] |
| 209 | + | JSONRPCErrorResponse |
| 210 | + | JSONRPCResponse |
| 211 | + ), |
| 212 | + ) -> Response: |
| 213 | + """Creates a Starlette Response based on the result from the request handler. |
| 214 | +
|
| 215 | + Handles: |
| 216 | + - AsyncGenerator for Server-Sent Events (SSE). |
| 217 | + - JSONRPCErrorResponse for explicit errors returned by handlers. |
| 218 | + - Pydantic RootModels (like GetTaskResponse) containing success or error |
| 219 | + payloads. |
| 220 | + - Unexpected types by returning an InternalError. |
| 221 | +
|
| 222 | + Args: |
| 223 | + handler_result: AsyncGenerator of SendStreamingMessageResponse |
| 224 | +
|
| 225 | + Returns: |
| 226 | + A Starlette JSONResponse or EventSourceResponse. |
| 227 | + """ |
| 228 | + if isinstance(handler_result, AsyncGenerator): |
| 229 | + # Result is a stream of SendStreamingMessageResponse objects |
| 230 | + async def event_generator( |
| 231 | + stream: AsyncGenerator[SendStreamingMessageResponse, None], |
| 232 | + ) -> AsyncGenerator[dict[str, str], None]: |
| 233 | + async for item in stream: |
| 234 | + yield {'data': item.root.model_dump_json(exclude_none=True)} |
| 235 | + |
| 236 | + return EventSourceResponse(event_generator(handler_result)) |
| 237 | + if isinstance(handler_result, JSONRPCErrorResponse): |
| 238 | + return JSONResponse( |
| 239 | + handler_result.model_dump( |
| 240 | + mode='json', |
| 241 | + exclude_none=True, |
| 242 | + ) |
| 243 | + ) |
| 244 | + |
| 245 | + return JSONResponse( |
| 246 | + handler_result.root.model_dump(mode='json', exclude_none=True) |
| 247 | + ) |
| 248 | + |
| 249 | + async def _handle_get_agent_card(self, request: Request) -> JSONResponse: |
| 250 | + """Handles GET requests for the agent card.""" |
| 251 | + return JSONResponse( |
| 252 | + self.agent_card.model_dump(mode='json', exclude_none=True) |
| 253 | + ) |
| 254 | + |
| 255 | + @abstractmethod |
| 256 | + def build( |
| 257 | + self, |
| 258 | + agent_card_url: str = '/.well-known/agent.json', |
| 259 | + rpc_url: str = '/', |
| 260 | + **kwargs: Any, |
| 261 | + ) -> Union[Starlette, FastAPI]: |
| 262 | + """Builds and returns the FastAPI application instance. |
| 263 | +
|
| 264 | + Args: |
| 265 | + agent_card_url: The URL for the agent card endpoint. |
| 266 | + rpc_url: The URL for the A2A JSON-RPC endpoint |
| 267 | + **kwargs: Additional keyword arguments to pass to the FastAPI constructor. |
| 268 | +
|
| 269 | + Returns: |
| 270 | + A configured FastAPI application instance. |
| 271 | + """ |
| 272 | + pass |
0 commit comments