diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 55758368..97d884a2 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -17,6 +17,23 @@ AServers AService AStarlette AUser +DSNs +EUR +GBP +GVsb +INR +JPY +JSONRPCt +Llm +POSTGRES +RUF +Tful +aconnect +adk +agentic +aio +aiomysql +aproject autouse backticks cla diff --git a/src/a2a/server/apps/__init__.py b/src/a2a/server/apps/__init__.py index a73e05c8..4d42ee8c 100644 --- a/src/a2a/server/apps/__init__.py +++ b/src/a2a/server/apps/__init__.py @@ -6,11 +6,17 @@ CallContextBuilder, JSONRPCApplication, ) +from a2a.server.apps.rest import ( + A2ARESTFastAPIApplication, + RESTApplication, +) __all__ = [ 'A2AFastAPIApplication', + 'A2ARESTFastAPIApplication', 'A2AStarletteApplication', 'CallContextBuilder', 'JSONRPCApplication', + 'RESTApplication', ] diff --git a/src/a2a/server/apps/jsonrpc/__init__.py b/src/a2a/server/apps/jsonrpc/__init__.py index ab803d4e..1121fdbc 100644 --- a/src/a2a/server/apps/jsonrpc/__init__.py +++ b/src/a2a/server/apps/jsonrpc/__init__.py @@ -3,7 +3,9 @@ from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication from a2a.server.apps.jsonrpc.jsonrpc_app import ( CallContextBuilder, + DefaultCallContextBuilder, JSONRPCApplication, + StarletteUserProxy, ) from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication @@ -12,5 +14,7 @@ 'A2AFastAPIApplication', 'A2AStarletteApplication', 'CallContextBuilder', + 'DefaultCallContextBuilder', 'JSONRPCApplication', + 'StarletteUserProxy', ] diff --git a/src/a2a/server/apps/rest/__init__.py b/src/a2a/server/apps/rest/__init__.py new file mode 100644 index 00000000..c57b0f38 --- /dev/null +++ b/src/a2a/server/apps/rest/__init__.py @@ -0,0 +1,10 @@ +"""A2A REST Applications.""" + +from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication +from a2a.server.apps.rest.rest_app import RESTApplication + + +__all__ = [ + 'A2ARESTFastAPIApplication', + 'RESTApplication', +] diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py new file mode 100644 index 00000000..76fa887f --- /dev/null +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -0,0 +1,81 @@ +import logging + +from typing import Any + +from fastapi import APIRouter, FastAPI, Request, Response + +from a2a.server.apps.jsonrpc.jsonrpc_app import ( + CallContextBuilder, +) +from a2a.server.apps.rest.rest_app import ( + RESTApplication, +) +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import AgentCard + + +logger = logging.getLogger(__name__) + + +class A2ARESTFastAPIApplication: + """A FastAPI application implementing the A2A protocol server REST endpoints. + + Handles incoming REST requests, routes them to the appropriate + handler methods, and manages response generation including Server-Sent Events + (SSE). + """ + + def __init__( + self, + agent_card: AgentCard, + http_handler: RequestHandler, + context_builder: CallContextBuilder | None = None, + ): + """Initializes the A2ARESTFastAPIApplication. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + http_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. + """ + self._handler = RESTApplication( + agent_card=agent_card, + http_handler=http_handler, + context_builder=context_builder, + ) + + def build( + self, + agent_card_url: str = '/.well-known/agent.json', + rpc_url: str = '', + **kwargs: Any, + ) -> FastAPI: + """Builds and returns the FastAPI application instance. + + Args: + agent_card_url: The URL for the agent card endpoint. + rpc_url: The URL for the A2A JSON-RPC endpoint. + extended_agent_card_url: The URL for the authenticated extended agent card endpoint. + **kwargs: Additional keyword arguments to pass to the FastAPI constructor. + + Returns: + A configured FastAPI application instance. + """ + app = FastAPI(**kwargs) + router = APIRouter() + for route, callback in self._handler.routes().items(): + router.add_api_route( + f'{rpc_url}{route[0]}', callback, methods=[route[1]] + ) + + @router.get(f'{rpc_url}{agent_card_url}') + async def get_agent_card(request: Request) -> Response: + return await self._handler._handle_get_agent_card(request) + + app.include_router(router) + return app diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py new file mode 100644 index 00000000..717c6e9f --- /dev/null +++ b/src/a2a/server/apps/rest/rest_app.py @@ -0,0 +1,228 @@ +import functools +import json +import logging +import traceback + +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable +from typing import Any + +from pydantic import ValidationError +from sse_starlette.sse import EventSourceResponse +from starlette.requests import Request +from starlette.responses import JSONResponse + +from a2a.server.apps.jsonrpc import ( + CallContextBuilder, + DefaultCallContextBuilder, +) +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_handlers.rest_handler import ( + RESTHandler, +) +from a2a.types import ( + AgentCard, + InternalError, + InvalidRequestError, + JSONParseError, + UnsupportedOperationError, +) +from a2a.utils.errors import MethodNotImplementedError + + +logger = logging.getLogger(__name__) + + +class RESTApplication: + """Base class for A2A REST applications. + + Defines REST requests processors and the routes to attach them too, as well as + manages response generation including Server-Sent Events (SSE). + """ + + def __init__( + self, + agent_card: AgentCard, + http_handler: RequestHandler, + context_builder: CallContextBuilder | None = None, + ): + """Initializes the RESTApplication. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + http_handler: The handler instance responsible for processing A2A + requests via http. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. + """ + self.agent_card = agent_card + self.handler = RESTHandler( + agent_card=agent_card, request_handler=http_handler + ) + self._context_builder = context_builder or DefaultCallContextBuilder() + + def _generate_error_response(self, error) -> JSONResponse: + """Creates a JSONResponse for a errors. + + Logs the error based on its type. + + Args: + error: The Error object. + + Returns: + A `JSONResponse` object formatted as a JSON error response. + """ + log_level = ( + logging.ERROR + if isinstance(error, InternalError) + else logging.WARNING + ) + logger.log( + log_level, + 'Request Error: ' + f"Code={error.code}, Message='{error.message}'" + f'{", Data=" + str(error.data) if error.data else ""}', + ) + return JSONResponse( + '{"message": ' + error.message + '}', + status_code=404, + ) + + def _handle_error(self, error: Exception) -> JSONResponse: + traceback.print_exc() + if isinstance(error, MethodNotImplementedError): + return self._generate_error_response(UnsupportedOperationError()) + if isinstance(error, json.decoder.JSONDecodeError): + return self._generate_error_response( + JSONParseError(message=str(error)) + ) + if isinstance(error, ValidationError): + return self._generate_error_response( + InvalidRequestError(data=json.loads(error.json())), + ) + logger.error(f'Unhandled exception: {error}') + return self._generate_error_response(InternalError(message=str(error))) + + async def _handle_request( + self, + method: Callable[[Request, ServerCallContext], Awaitable[str]], + request: Request, + ) -> JSONResponse: + try: + call_context = self._context_builder.build(request) + response = await method(request, call_context) + return JSONResponse(content=response) + except Exception as e: + return self._handle_error(e) + + async def _handle_streaming_request( + self, + method: Callable[[Request, ServerCallContext], AsyncIterator[str]], + request: Request, + ) -> EventSourceResponse: + try: + call_context = self._context_builder.build(request) + + async def event_generator( + stream: AsyncGenerator[str], + ) -> AsyncGenerator[dict[str, str]]: + async for item in stream: + yield {'data': item} + + return EventSourceResponse( + event_generator(method(request, call_context)) + ) + except Exception as e: + # Since the stream has started, we can't return a JSONResponse. + # Instead, we runt the error handling logic (provides logging) + # and reraise the error and let server framework manage + self._handle_error(e) + raise e + + async def _handle_get_agent_card(self, request: Request) -> JSONResponse: + """Handles GET requests for the agent card endpoint. + + Args: + request: The incoming Starlette Request object. + + Returns: + A JSONResponse containing the agent card data. + """ + # The public agent card is a direct serialization of the agent_card + # provided at initialization. + return JSONResponse( + self.agent_card.model_dump(mode='json', exclude_none=True) + ) + + async def handle_authenticated_agent_card( + self, request: Request + ) -> JSONResponse: + """Hook for per credential agent card response. + + If a dynamic card is needed based on the credentials provided in the request + override this method and return the customized content. + + Args: + request: The incoming Starlette Request object. + + Returns: + A JSONResponse containing the authenticated card. + """ + if not self.agent_card.supportsAuthenticatedExtendedCard: + return JSONResponse( + '{"detail": "Authenticated card not supported"}', + status_code=404, + ) + return JSONResponse( + self.agent_card.model_dump(mode='json', exclude_none=True) + ) + + def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: + routes = { + ('/v1/message:send', 'POST'): ( + functools.partial( + self._handle_request, self.handler.on_message_send + ), + ), + ('/v1/message:stream', 'POST'): ( + functools.partial( + self._handle_streaming_request, + self.handler.on_message_send_stream, + ), + ), + ('/v1/tasks/{id}:subscribe', 'POST'): ( + functools.partial( + self._handle_streaming_request, + self.handler.on_resubscribe_to_task, + ), + ), + ('/v1/tasks/{id}', 'GET'): ( + functools.partial( + self._handle_request, self.handler.on_get_task + ), + ), + ('/v1/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): ( + functools.partial( + self._handle_request, self.handler.get_push_notification + ), + ), + ('/v1/tasks/{id}/pushNotificationConfigs', 'POST'): ( + functools.partial( + self._handle_request, self.handler.set_push_notification + ), + ), + ('/v1/tasks/{id}/pushNotificationConfigs', 'GET'): ( + functools.partial( + self._handle_request, self.handler.list_push_notifications + ), + ), + ('/v1/tasks', 'GET'): ( + functools.partial( + self._handle_request, self.handler.list_tasks + ), + ), + } + if self.agent_card.supportsAuthenticatedExtendedCard: + routes['/v1/card'] = (self.handle_authenticated_agent_card, 'GET') + return routes diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 9882dc2a..43ebc8e2 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -11,6 +11,7 @@ build_error_response, prepare_response_object, ) +from a2a.server.request_handlers.rest_handler import RESTHandler logger = logging.getLogger(__name__) @@ -40,6 +41,7 @@ def __init__(self, *args, **kwargs): 'DefaultRequestHandler', 'GrpcHandler', 'JSONRPCHandler', + 'RESTHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py new file mode 100644 index 00000000..930318d2 --- /dev/null +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -0,0 +1,319 @@ +import logging + +from collections.abc import AsyncIterable + +from google.protobuf.json_format import MessageToJson, Parse +from starlette.requests import Request + +from a2a.grpc import a2a_pb2 +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import ( + A2AError, + AgentCard, + GetTaskPushNotificationConfigParams, + InternalError, + Task, + TaskIdParams, + TaskNotFoundError, + TaskPushNotificationConfig, + TaskQueryParams, +) +from a2a.utils import proto_utils +from a2a.utils.errors import ServerError +from a2a.utils.helpers import validate +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.SERVER) +class RESTHandler: + """Maps incoming REST-like (JSON+HTTP) requests to the appropriate request handler method and formats responses. + + This uses the protobuf definitions of the gRPC service as the source of truth. By + doing this, it ensures that this implementation and the gRPC transcoding + (via Envoy) are equivalent. This handler should be used if using the gRPC handler + with Envoy is not feasible for a given deployment solution. Use this handler + and a related application if you desire to ONLY server the RESTful API. + """ + + def __init__( + self, + agent_card: AgentCard, + request_handler: RequestHandler, + ): + """Initializes the RESTHandler. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The underlying `RequestHandler` instance to delegate requests to. + """ + self.agent_card = agent_card + self.request_handler = request_handler + + async def on_message_send( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'message/send' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `str` containing the JSON result (Task or Message) + + Raises: + A2AError if a `ServerError` is raised by the handler. + """ + # TODO: Wrap in error handler to return error states + try: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + params, + ) + task_or_message = await self.request_handler.on_message_send( + a2a_request, context + ) + return MessageToJson( + proto_utils.ToProto.task_or_message(task_or_message) + ) + except ServerError as e: + raise A2AError(error=e.error if e.error else InternalError()) from e + + @validate( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_message_send_stream( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> AsyncIterable[str]: + """Handles the 'message/stream' REST method. + + Yields response objects as they are produced by the underlying handler's stream. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Yields: + `str` objects containing streaming events + (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) as JSON + Raises: + `A2AError` + """ + try: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + params, + ) + async for event in self.request_handler.on_message_send_stream( + a2a_request, context + ): + response = proto_utils.ToProto.stream_response(event) + yield MessageToJson(response) + except ServerError as e: + raise A2AError(error=e.error if e.error else InternalError()) from e + return + + async def on_cancel_task( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'tasks/cancel' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `str` containing the updated Task in JSON format + Raises: + A2AError. + """ + try: + task_id = request.path_params['id'] + task = await self.request_handler.on_cancel_task( + TaskIdParams(id=task_id), context + ) + if task: + return MessageToJson(proto_utils.ToProto.task(task)) + raise ServerError(error=TaskNotFoundError()) + except ServerError as e: + raise A2AError( + error=e.error if e.error else InternalError(), + ) from e + + @validate( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_resubscribe_to_task( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> AsyncIterable[str]: + """Handles the 'tasks/resubscribe' REST method. + + Yields response objects as they are produced by the underlying handler's stream. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Yields: + `str` containing streaming events in JSON format + + Raises: + A A2AError if an error is encountered + """ + try: + task_id = request.path_params['id'] + async for event in self.request_handler.on_resubscribe_to_task( + TaskIdParams(id=task_id), context + ): + yield ( + MessageToJson(proto_utils.ToProto.stream_response(event)) + ) + except ServerError as e: + raise A2AError(error=e.error if e.error else InternalError()) from e + + async def get_push_notification( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'tasks/pushNotificationConfig/get' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `str` containing the config as JSON + Raises: + A2AError. + """ + try: + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + if push_id: + params = GetTaskPushNotificationConfigParams( + id=task_id, push_id=push_id + ) + else: + params = TaskIdParams['id'] + config = ( + await self.request_handler.on_get_task_push_notification_config( + params, context + ) + ) + return MessageToJson( + proto_utils.ToProto.task_push_notification_config(config) + ) + except ServerError as e: + raise A2AError(error=e.error if e.error else InternalError()) + + @validate( + lambda self: self.agent_card.capabilities.pushNotifications, + 'Push notifications are not supported by the agent', + ) + async def set_push_notification( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'tasks/pushNotificationConfig/set' REST method. + + Requires the agent to support push notifications. + + Args: + request: The incoming `TaskPushNotificationConfig` object. + context: Context provided by the server. + + Returns: + A `str` containing the config as JSON object. + + Raises: + ServerError: If push notifications are not supported by the agent + (due to the `@validate` decorator), A2AError if processing error is + found. + """ + try: + task_id = request.path_params['id'] + body = await request.body() + params = a2a_pb2.TaskPushNotificationConfig() + Parse(body, params) + params = TaskPushNotificationConfig.validate_model(body) + a2a_request = ( + proto_utils.FromProto.task_push_notification_config( + params, + ), + ) + config = ( + await self.request_handler.on_set_task_push_notification_config( + a2a_request, context + ) + ) + return MessageToJson( + proto_utils.ToProto.task_push_notification_config(config) + ) + except ServerError as e: + raise A2AError(error=e.error if e.error else InternalError()) from e + + async def on_get_task( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'v1/tasks/{id}' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `Task` object containing the Task. + + Raises: + A2AError + """ + try: + task_id = request.path_params['id'] + historyLength = None + if 'historyLength' in request.query_params: + historyLength = request.query_params['historyLength'] + params = TaskQueryParams(id=task_id, historyLength=historyLength) + task = await self.request_handler.on_get_task(params, context) + if task: + return MessageToJson(proto_utils.ToProto.task(task)) + raise ServerError(error=TaskNotFoundError()) + except ServerError as e: + raise A2AError(error=e.error if e.error else InternalError()) from e + + async def list_push_notifications( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> list[TaskPushNotificationConfig]: + raise NotImplementedError('list notifications not implemented') + + async def list_tasks( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> list[Task]: + raise NotImplementedError('list tasks not implemented')