|
| 1 | +import functools |
| 2 | +import json |
| 3 | +import logging |
| 4 | +import traceback |
| 5 | + |
| 6 | +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable |
| 7 | +from typing import Any |
| 8 | + |
| 9 | +from pydantic import ValidationError |
| 10 | +from sse_starlette.sse import EventSourceResponse |
| 11 | +from starlette.requests import Request |
| 12 | +from starlette.responses import JSONResponse |
| 13 | + |
| 14 | +from a2a.server.apps.jsonrpc import ( |
| 15 | + CallContextBuilder, |
| 16 | + DefaultCallContextBuilder, |
| 17 | +) |
| 18 | +from a2a.server.context import ServerCallContext |
| 19 | +from a2a.server.request_handlers.request_handler import RequestHandler |
| 20 | +from a2a.server.request_handlers.rest_handler import ( |
| 21 | + RESTHandler, |
| 22 | +) |
| 23 | +from a2a.types import ( |
| 24 | + AgentCard, |
| 25 | + InternalError, |
| 26 | + InvalidRequestError, |
| 27 | + JSONParseError, |
| 28 | + UnsupportedOperationError, |
| 29 | +) |
| 30 | +from a2a.utils.errors import MethodNotImplementedError |
| 31 | + |
| 32 | + |
| 33 | +logger = logging.getLogger(__name__) |
| 34 | + |
| 35 | + |
| 36 | +class RESTApplication: |
| 37 | + """Base class for A2A REST applications. |
| 38 | +
|
| 39 | + Defines REST requests processors and the routes to attach them too, as well as |
| 40 | + manages response generation including Server-Sent Events (SSE). |
| 41 | + """ |
| 42 | + |
| 43 | + def __init__( |
| 44 | + self, |
| 45 | + agent_card: AgentCard, |
| 46 | + http_handler: RequestHandler, |
| 47 | + context_builder: CallContextBuilder | None = None, |
| 48 | + ): |
| 49 | + """Initializes the RESTApplication. |
| 50 | +
|
| 51 | + Args: |
| 52 | + agent_card: The AgentCard describing the agent's capabilities. |
| 53 | + http_handler: The handler instance responsible for processing A2A |
| 54 | + requests via http. |
| 55 | + context_builder: The CallContextBuilder used to construct the |
| 56 | + ServerCallContext passed to the http_handler. If None, no |
| 57 | + ServerCallContext is passed. |
| 58 | + """ |
| 59 | + self.agent_card = agent_card |
| 60 | + self.handler = RESTHandler( |
| 61 | + agent_card=agent_card, request_handler=http_handler |
| 62 | + ) |
| 63 | + self._context_builder = context_builder or DefaultCallContextBuilder() |
| 64 | + |
| 65 | + def _generate_error_response(self, error) -> JSONResponse: |
| 66 | + """Creates a JSONResponse for a errors. |
| 67 | +
|
| 68 | + Logs the error based on its type. |
| 69 | +
|
| 70 | + Args: |
| 71 | + error: The Error object. |
| 72 | +
|
| 73 | + Returns: |
| 74 | + A `JSONResponse` object formatted as a JSON error response. |
| 75 | + """ |
| 76 | + log_level = ( |
| 77 | + logging.ERROR |
| 78 | + if isinstance(error, InternalError) |
| 79 | + else logging.WARNING |
| 80 | + ) |
| 81 | + logger.log( |
| 82 | + log_level, |
| 83 | + 'Request Error: ' |
| 84 | + f"Code={error.code}, Message='{error.message}'" |
| 85 | + f'{", Data=" + str(error.data) if error.data else ""}', |
| 86 | + ) |
| 87 | + return JSONResponse( |
| 88 | + '{"message": ' + error.message + '}', |
| 89 | + status_code=404, |
| 90 | + ) |
| 91 | + |
| 92 | + def _handle_error(self, error: Exception) -> JSONResponse: |
| 93 | + traceback.print_exc() |
| 94 | + if isinstance(error, MethodNotImplementedError): |
| 95 | + return self._generate_error_response(UnsupportedOperationError()) |
| 96 | + if isinstance(error, json.decoder.JSONDecodeError): |
| 97 | + return self._generate_error_response( |
| 98 | + JSONParseError(message=str(error)) |
| 99 | + ) |
| 100 | + if isinstance(error, ValidationError): |
| 101 | + return self._generate_error_response( |
| 102 | + InvalidRequestError(data=json.loads(error.json())), |
| 103 | + ) |
| 104 | + logger.error(f'Unhandled exception: {error}') |
| 105 | + return self._generate_error_response(InternalError(message=str(error))) |
| 106 | + |
| 107 | + async def _handle_request( |
| 108 | + self, |
| 109 | + method: Callable[[Request, ServerCallContext], Awaitable[str]], |
| 110 | + request: Request, |
| 111 | + ) -> JSONResponse: |
| 112 | + try: |
| 113 | + call_context = self._context_builder.build(request) |
| 114 | + response = await method(request, call_context) |
| 115 | + return JSONResponse(content=response) |
| 116 | + except Exception as e: |
| 117 | + return self._handle_error(e) |
| 118 | + |
| 119 | + async def _handle_streaming_request( |
| 120 | + self, |
| 121 | + method: Callable[[Request, ServerCallContext], AsyncIterator[str]], |
| 122 | + request: Request, |
| 123 | + ) -> EventSourceResponse: |
| 124 | + try: |
| 125 | + call_context = self._context_builder.build(request) |
| 126 | + |
| 127 | + async def event_generator( |
| 128 | + stream: AsyncGenerator[str], |
| 129 | + ) -> AsyncGenerator[dict[str, str]]: |
| 130 | + async for item in stream: |
| 131 | + yield {'data': item} |
| 132 | + |
| 133 | + return EventSourceResponse( |
| 134 | + event_generator(method(request, call_context)) |
| 135 | + ) |
| 136 | + except Exception as e: |
| 137 | + # Since the stream has started, we can't return a JSONResponse. |
| 138 | + # Instead, we runt the error handling logic (provides logging) |
| 139 | + # and reraise the error and let server framework manage |
| 140 | + self._handle_error(e) |
| 141 | + raise e |
| 142 | + |
| 143 | + async def _handle_get_agent_card(self, request: Request) -> JSONResponse: |
| 144 | + """Handles GET requests for the agent card endpoint. |
| 145 | +
|
| 146 | + Args: |
| 147 | + request: The incoming Starlette Request object. |
| 148 | +
|
| 149 | + Returns: |
| 150 | + A JSONResponse containing the agent card data. |
| 151 | + """ |
| 152 | + # The public agent card is a direct serialization of the agent_card |
| 153 | + # provided at initialization. |
| 154 | + return JSONResponse( |
| 155 | + self.agent_card.model_dump(mode='json', exclude_none=True) |
| 156 | + ) |
| 157 | + |
| 158 | + async def handle_authenticated_agent_card( |
| 159 | + self, request: Request |
| 160 | + ) -> JSONResponse: |
| 161 | + """Hook for per credential agent card response. |
| 162 | +
|
| 163 | + If a dynamic card is needed based on the credentials provided in the request |
| 164 | + override this method and return the customized content. |
| 165 | +
|
| 166 | + Args: |
| 167 | + request: The incoming Starlette Request object. |
| 168 | +
|
| 169 | + Returns: |
| 170 | + A JSONResponse containing the authenticated card. |
| 171 | + """ |
| 172 | + if not self.agent_card.supportsAuthenticatedExtendedCard: |
| 173 | + return JSONResponse( |
| 174 | + '{"detail": "Authenticated card not supported"}', |
| 175 | + status_code=404, |
| 176 | + ) |
| 177 | + return JSONResponse( |
| 178 | + self.agent_card.model_dump(mode='json', exclude_none=True) |
| 179 | + ) |
| 180 | + |
| 181 | + def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: |
| 182 | + routes = { |
| 183 | + ('/v1/message:send', 'POST'): ( |
| 184 | + functools.partial( |
| 185 | + self._handle_request, self.handler.on_message_send |
| 186 | + ), |
| 187 | + ), |
| 188 | + ('/v1/message:stream', 'POST'): ( |
| 189 | + functools.partial( |
| 190 | + self._handle_streaming_request, |
| 191 | + self.handler.on_message_send_stream, |
| 192 | + ), |
| 193 | + ), |
| 194 | + ('/v1/tasks/{id}:subscribe', 'POST'): ( |
| 195 | + functools.partial( |
| 196 | + self._handle_streaming_request, |
| 197 | + self.handler.on_resubscribe_to_task, |
| 198 | + ), |
| 199 | + ), |
| 200 | + ('/v1/tasks/{id}', 'GET'): ( |
| 201 | + functools.partial( |
| 202 | + self._handle_request, self.handler.on_get_task |
| 203 | + ), |
| 204 | + ), |
| 205 | + ('/v1/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): ( |
| 206 | + functools.partial( |
| 207 | + self._handle_request, self.handler.get_push_notification |
| 208 | + ), |
| 209 | + ), |
| 210 | + ('/v1/tasks/{id}/pushNotificationConfigs', 'POST'): ( |
| 211 | + functools.partial( |
| 212 | + self._handle_request, self.handler.set_push_notification |
| 213 | + ), |
| 214 | + ), |
| 215 | + ('/v1/tasks/{id}/pushNotificationConfigs', 'GET'): ( |
| 216 | + functools.partial( |
| 217 | + self._handle_request, self.handler.list_push_notifications |
| 218 | + ), |
| 219 | + ), |
| 220 | + ('/v1/tasks', 'GET'): ( |
| 221 | + functools.partial( |
| 222 | + self._handle_request, self.handler.list_tasks |
| 223 | + ), |
| 224 | + ), |
| 225 | + } |
| 226 | + if self.agent_card.supportsAuthenticatedExtendedCard: |
| 227 | + routes['/v1/card'] = (self.handle_authenticated_agent_card, 'GET') |
| 228 | + return routes |
0 commit comments