From 03a0fbe6228d9b0c907820695359640fab0415aa Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 11 Jul 2025 15:17:09 +0000 Subject: [PATCH 1/4] feat: [WIP] Add dedicated RESTful handler --- src/a2a/server/apps/__init__.py | 6 + src/a2a/server/apps/jsonrpc/__init__.py | 4 + src/a2a/server/apps/rest/__init__.py | 9 + src/a2a/server/apps/rest/fastapi_app.py | 83 +++++ src/a2a/server/apps/rest/rest_app.py | 227 +++++++++++++ src/a2a/server/request_handlers/__init__.py | 2 + .../server/request_handlers/rest_handler.py | 318 ++++++++++++++++++ 7 files changed, 649 insertions(+) create mode 100644 src/a2a/server/apps/rest/__init__.py create mode 100644 src/a2a/server/apps/rest/fastapi_app.py create mode 100644 src/a2a/server/apps/rest/rest_app.py create mode 100644 src/a2a/server/request_handlers/rest_handler.py diff --git a/src/a2a/server/apps/__init__.py b/src/a2a/server/apps/__init__.py index a73e05c8..4d42ee8c 100644 --- a/src/a2a/server/apps/__init__.py +++ b/src/a2a/server/apps/__init__.py @@ -6,11 +6,17 @@ CallContextBuilder, JSONRPCApplication, ) +from a2a.server.apps.rest import ( + A2ARESTFastAPIApplication, + RESTApplication, +) __all__ = [ 'A2AFastAPIApplication', + 'A2ARESTFastAPIApplication', 'A2AStarletteApplication', 'CallContextBuilder', 'JSONRPCApplication', + 'RESTApplication', ] diff --git a/src/a2a/server/apps/jsonrpc/__init__.py b/src/a2a/server/apps/jsonrpc/__init__.py index ab803d4e..b322f0ef 100644 --- a/src/a2a/server/apps/jsonrpc/__init__.py +++ b/src/a2a/server/apps/jsonrpc/__init__.py @@ -4,6 +4,8 @@ from a2a.server.apps.jsonrpc.jsonrpc_app import ( CallContextBuilder, JSONRPCApplication, + StarletteUserProxy, + DefaultCallContextBuilder, ) from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication @@ -13,4 +15,6 @@ 'A2AStarletteApplication', 'CallContextBuilder', 'JSONRPCApplication', + 'StarletteUserProxy', + 'DefaultCallContextBuilder', ] diff --git a/src/a2a/server/apps/rest/__init__.py b/src/a2a/server/apps/rest/__init__.py new file mode 100644 index 00000000..81ee7b7a --- /dev/null +++ b/src/a2a/server/apps/rest/__init__.py @@ -0,0 +1,9 @@ +"""A2A REST Applications.""" + +from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication +from a2a.server.apps.rest.rest_app import RESTApplication + +__all__ = [ + 'A2ARESTFastAPIApplication', + 'RESTApplication', +] diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py new file mode 100644 index 00000000..c6356560 --- /dev/null +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -0,0 +1,83 @@ +import logging + +from typing import Any + +from fastapi import FastAPI, Request, Response, APIRouter + +from a2a.server.apps.jsonrpc.jsonrpc_app import ( + CallContextBuilder, +) +from a2a.server.apps.rest.rest_app import ( + RESTApplication, +) +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import AgentCard + + +logger = logging.getLogger(__name__) + + +class A2ARESTFastAPIApplication: + """A FastAPI application implementing the A2A protocol server REST endpoints. + + Handles incoming REST requests, routes them to the appropriate + handler methods, and manages response generation including Server-Sent Events + (SSE). + """ + + def __init__( + self, + agent_card: AgentCard, + http_handler: RequestHandler, + context_builder: CallContextBuilder | None = None, + ): + """Initializes the A2ARESTFastAPIApplication. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + http_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. + """ + self._handler = RESTApplication( + agent_card=agent_card, + http_handler=http_handler, + context_builder=context_builder, + ) + + def build( + self, + agent_card_url: str = '/.well-known/agent.json', + rpc_url: str = '', + **kwargs: Any, + ) -> FastAPI: + """Builds and returns the FastAPI application instance. + + Args: + agent_card_url: The URL for the agent card endpoint. + rpc_url: The URL for the A2A JSON-RPC endpoint. + extended_agent_card_url: The URL for the authenticated extended agent card endpoint. + **kwargs: Additional keyword arguments to pass to the FastAPI constructor. + + Returns: + A configured FastAPI application instance. + """ + app = FastAPI(**kwargs) + router = APIRouter() + for route, callback in self._handler.routes().items(): + router.add_api_route( + f'{rpc_url}{route}', + callback[0], + methods=[callback[1]] + ) + + @router.get(f'{rpc_url}{agent_card_url}') + async def get_agent_card(request: Request) -> Response: + return await self._handle_get_agent_card(request) + + app.include_router(router) + return app diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py new file mode 100644 index 00000000..b4b9d3ad --- /dev/null +++ b/src/a2a/server/apps/rest/rest_app.py @@ -0,0 +1,227 @@ +import contextlib +import json +import logging +import traceback +import functools + +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable +from typing import Any, Tuple, Callable +from fastapi import FastAPI +from pydantic import BaseModel, ValidationError + +from sse_starlette.sse import EventSourceResponse +from starlette.applications import Starlette +from starlette.authentication import BaseUser +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from a2a.auth.user import UnauthenticatedUser +from a2a.auth.user import User as A2AUser +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.rest_handler import ( + RESTHandler, +) +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import ( + A2AError, + AgentCard, + UnsupportedOperationError, + InternalError, +) +from a2a.utils.errors import MethodNotImplementedError +from a2a.server.apps.jsonrpc import ( + CallContextBuilder, + StarletteUserProxy, + DefaultCallContextBuilder +) + + +logger = logging.getLogger(__name__) + + +class RESTApplication: + """Base class for A2A REST applications. + + Defines REST requests processors and the routes to attach them too, as well as + manages response generation including Server-Sent Events (SSE). + """ + + def __init__( + self, + agent_card: AgentCard, + http_handler: RequestHandler, + context_builder: CallContextBuilder | None = None, + ): + """Initializes the RESTApplication. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + http_handler: The handler instance responsible for processing A2A + requests via http. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. + """ + self.agent_card = agent_card + self.handler = RESTHandler( + agent_card=agent_card, request_handler=http_handler + ) + self._context_builder = context_builder or DefaultCallContextBuilder() + + def _generate_error_response(self, error) -> JSONResponse: + """Creates a JSONResponse for a errors. + + Logs the error based on its type. + + Args: + error: The Error object. + + Returns: + A `JSONResponse` object formatted as a JSON error response. + """ + log_level = ( + logging.ERROR + if isinstance(error, InternalError) + else logging.WARNING + ) + logger.log( + log_level, + 'Request Error: ' + f"Code={error.code}, Message='{error.message}'" + f'{", Data=" + str(error.data) if error.data else ""}', + ) + return JSONResponse( + '{"message": ' + error.message + '}', + status_code=404, + ) + + def _handle_error(self, error: Exception) -> JSONResponse: + traceback.print_exc() + if isinstance(error, MethodNotImplementedError): + return self._generate_error_response(UnsupportedOperationError()) + elif isinstance(error, json.decoder.JSONDecodeError): + return self._generate_error_response( + JSONParseError(message=str(error)) + ) + elif isinstance(error, ValidationError): + return self._generate_error_response( + InvalidRequestError(data=json.loads(error.json())), + ) + logger.error(f'Unhandled exception: {error}') + return self._generate_error_response( + InternalError(message=str(error)) + ) + + async def _handle_request( + self, + method: Callable[[Request, ServerCallContext], Awaitable[str]], + request: Request + ) -> JSONResponse: + try: + call_context = self._context_builder.build(request) + response = await method(request, call_context) + return JSONResponse(content=response) + except Exception as e: + return self._handle_error(e) + + async def _handle_streaming_request( + self, + method: Callable[[Request, ServerCallContext], AsyncIterator[str]], + request: Request + ) -> EventSourceResponse: + try: + call_context = self._context_builder.build(request) + async def event_generator( + stream: AsyncGenerator[str], + ) -> AsyncGenerator[dict[str, str]]: + async for item in stream: + yield {'data': item} + return EventSourceResponse(event_generator(method(request, call_context))) + except Exception as e: + return self._handle_error(e) + + + async def _handle_get_agent_card(self, request: Request) -> JSONResponse: + """Handles GET requests for the agent card endpoint. + + Args: + request: The incoming Starlette Request object. + + Returns: + A JSONResponse containing the agent card data. + """ + # The public agent card is a direct serialization of the agent_card + # provided at initialization. + return JSONResponse( + self.agent_card.model_dump(mode='json', exclude_none=True) + ) + + async def handle_authenticated_agent_card(self, request: Request) -> JSONResponse: + """Hook for per credential agent card response. + + If a dynamic card is needed based on the credentials provided in the request + override this method and return the customized content. + + Args: + request: The incoming Starlette Request object. + + Returns: + A JSONResponse containing the authenticated card. + """ + if not self.agent_card.supportsAuthenticatedExtendedCard: + return JSONResponse( + '{"detail": "Authenticated card not supported"}', status_code=404 + ) + return JSONResponse( + self.agent_card.model_dump(mode='json', exclude_none=True) + ) + + def routes(self) -> dict[str, Tuple[Callable[[Request],Any], str]]: + routes = { + '/v1/message:send': ( + functools.partial( + self._handle_request, + self.handler.on_message_send), + 'POST'), + '/v1/message:stream': ( + functools.partial( + self._handle_streaming_request, + self.handler.on_message_send_stream), + 'POST'), + '/v1/tasks/{id}:subscribe': ( + functools.partial( + self._handle_streaming_request, + self.handler.on_resubscribe_to_task), + 'POST'), + '/v1/tasks/{id}': ( + functools.partial( + self._handle_request, + self.handler.on_get_task), + 'GET'), + '/v1/tasks/{id}/pushNotificationConfigs/{push_id}': ( + functools.partial( + self._handle_request, + self.handler.get_push_notification), + 'GET'), + '/v1/tasks/{id}/pushNotificationConfigs': ( + functools.partial( + self._handle_request, + self.handler.set_push_notification), + 'POST'), + '/v1/tasks/{id}/pushNotificationConfigs': ( + functools.partial( + self._handle_request, + self.handler.list_push_notifications), + 'GET'), + '/v1/tasks': ( + functools.partial( + self._handle_request, + self.handler.list_tasks), + 'GET'), + } + if self.agent_card.supportsAuthenticatedExtendedCard: + routes['/v1/card'] = ( + self.handle_authenticated_agent_card, + 'GET') + return routes diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 8cf2fe8c..20cc3815 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -6,6 +6,7 @@ from a2a.server.request_handlers.grpc_handler import GrpcHandler from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_handlers.rest_handler import RESTHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, prepare_response_object, @@ -17,6 +18,7 @@ 'GrpcHandler', 'JSONRPCHandler', 'RequestHandler', + 'RESTHandler', 'build_error_response', 'prepare_response_object', ] diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py new file mode 100644 index 00000000..6f6167a5 --- /dev/null +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -0,0 +1,318 @@ +import logging + +from collections.abc import AsyncIterable +from starlette.requests import Request +from pydantic import BaseModel, Field, RootModel + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import ( + A2AError, + AgentCard, + InternalError, + Message, + Task, + TaskArtifactUpdateEvent, + TaskNotFoundError, + TaskPushNotificationConfig, + TaskStatusUpdateEvent, + GetTaskPushNotificationConfigParams, + MessageSendParams, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, +) +from a2a.utils.errors import ServerError +from a2a.utils.helpers import validate +from a2a.utils.telemetry import SpanKind, trace_class +from a2a.grpc import a2a_pb2 +from a2a.utils import proto_utils +from google.protobuf.json_format import Parse, MessageToJson + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.SERVER) +class RESTHandler: + """Maps incoming REST-like (JSON+HTTP) requests to the appropriate request handler method and formats responses. + + This uses the protobuf definitions of the gRPC service as the source of truth. By + doing this, it ensures that this implementation and the gRPC transcoding + (via Envoy) are equivalent. This handler should be used if using the gRPC handler + with Envoy is not feasible for a given deployment solution. Use this handler + and a related application if you desire to ONLY server the RESTful API. + """ + + def __init__( + self, + agent_card: AgentCard, + request_handler: RequestHandler, + ): + """Initializes the RESTHandler. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The underlying `RequestHandler` instance to delegate requests to. + """ + self.agent_card = agent_card + self.request_handler = request_handler + + async def on_message_send( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'message/send' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `str` containing the JSON result (Task or Message) + Raises: + A2AError if a `ServerError` is raised by the handler. + """ + # TODO: Wrap in error handler to return error states + try: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + params, + ) + task_or_message = await self.request_handler.on_message_send( + a2a_request, context + ) + return MessageToJson(proto_utils.ToProto.task_or_message(task_or_message)) + except ServerError as e: + return A2AError( + error=e.error if e.error else InternalError() + ) + + @validate( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_message_send_stream( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> AsyncIterable[str]: + """Handles the 'message/stream' REST method. + + Yields response objects as they are produced by the underlying handler's stream. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Yields: + `str` objects containing streaming events + (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) as JSON + Raises: + `A2AError` + """ + try: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + params, + ) + async for event in self.request_handler.on_message_send_stream( + a2a_request, context + ): + response = proto_utils.ToProto.stream_response(event) + yield MessageToJson(response) + except ServerError as e: + raise A2AError( + error=e.error if e.error else InternalError() + ) from e + return + + async def on_cancel_task( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'tasks/cancel' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `str` containing the updated Task in JSON format + Raises: + A2AError. + """ + try: + task_id = request.path_params['id'] + task = await self.request_handler.on_cancel_task( + TaskIdParams(id=task_id), context + ) + if task: + return MessageToJson(proto_utils.ToProto.task(task)) + raise ServerError(error=TaskNotFoundError()) + except ServerError as e: + raise A2AError( + error=e.error if e.error else InternalError(), + ) from e + + @validate( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_resubscribe_to_task( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> AsyncIterable[str]: + """Handles the 'tasks/resubscribe' REST method. + + Yields response objects as they are produced by the underlying handler's stream. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Yields: + `str` containing streaming events in JSON format + + Raises: + A A2AError if an error is encountered + """ + try: + task_id = request.path_params['id'] + async for event in self.request_handler.on_resubscribe_to_task( + TaskIdParams(id=task_id), context + ): + yield(MessageToJson(proto_utils.ToProto.stream_response(event))) + except ServerError as e: + raise A2AError( + error=e.error if e.error else InternalError() + ) from e + + async def get_push_notification( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'tasks/pushNotificationConfig/get' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `str` containing the config as JSON + Raises: + A2AError. + """ + try: + task_id = request.path_params['task_id'] + push_id = request.path_params['push_id'] + if push_id: + params = GetTaskPushNotificationConfigParams(id=task_id, push_id=push_id) + else: + params = TaskIdParams['task_id'] + config = await self.request_handler.on_get_task_push_notification_config( + params, context + ) + return MessageToJson( + proto_utils.ToProto.task_push_notification_config(config) + ) + except ServerError as e: + raise A2AError( + error=e.error if e.error else InternalError() + ) + + @validate( + lambda self: self.agent_card.capabilities.pushNotifications, + 'Push notifications are not supported by the agent', + ) + async def set_push_notification( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'tasks/pushNotificationConfig/set' REST method. + + Requires the agent to support push notifications. + + Args: + request: The incoming `TaskPushNotificationConfig` object. + context: Context provided by the server. + + Returns: + A `str` containing the config as JSON object. + + Raises: + ServerError: If push notifications are not supported by the agent + (due to the `@validate` decorator), A2AError if processing error is + found. + """ + try: + task_id = request.path_params['id'] + body = await request.body() + params = TaskPushNotificationConfig.validate_model(body) + config = await self.request_handler.on_set_task_push_notification_config( + params, context + ) + return MessageToJson( + proto_utils.ToProto.task_push_notification_config(config) + ) + except ServerError as e: + raise A2AError( + error=e.error if e.error else InternalError() + ) from e + + async def on_get_task( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'v1/tasks/{id}' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `Task` object containing the Task. + + Raises: + A2AError + """ + try: + task_id = request.path_params['id'] + historyLength = None + if 'historyLength' in request.query_params: + history_length = request.query_params['historyLength'] + params = TaskQueryParams(id=task_id, historyLength=historyLength) + task = await self.request_handler.on_get_task(params, context) + if task: + return MessageToJson(proto_utils.ToProto.task(task)) + raise ServerError(error=TaskNotFoundError()) + except ServerError as e: + raise A2AError( + id=request.id, error=e.error if e.error else InternalError() + ) from e + + async def list_push_notifications( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> list[TaskPushNotificationConfig]: + raise NotImplementedError("list notifications not implemented") + + async def list_tasks( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> list[Task]: + raise NotImplementedError("list tasks not implemented") From 661db107587d8e3b91b7d26fe1f2d7102230dcc1 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 11 Jul 2025 18:02:56 +0000 Subject: [PATCH 2/4] Address gemini comments --- src/a2a/server/apps/rest/fastapi_app.py | 8 ++-- src/a2a/server/apps/rest/rest_app.py | 42 +++++++++++-------- .../server/request_handlers/rest_handler.py | 19 +++++---- 3 files changed, 40 insertions(+), 29 deletions(-) diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index c6356560..28922a33 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -70,14 +70,14 @@ def build( router = APIRouter() for route, callback in self._handler.routes().items(): router.add_api_route( - f'{rpc_url}{route}', - callback[0], - methods=[callback[1]] + f'{rpc_url}{route[0]}', + callback, + methods=[route[1]] ) @router.get(f'{rpc_url}{agent_card_url}') async def get_agent_card(request: Request) -> Response: - return await self._handle_get_agent_card(request) + return await self._handler._handle_get_agent_card(request) app.include_router(router) return app diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index b4b9d3ad..247d3f2e 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -26,8 +26,10 @@ from a2a.types import ( A2AError, AgentCard, + JSONParseError, UnsupportedOperationError, InternalError, + InvalidRequestError, ) from a2a.utils.errors import MethodNotImplementedError from a2a.server.apps.jsonrpc import ( @@ -139,7 +141,11 @@ async def event_generator( yield {'data': item} return EventSourceResponse(event_generator(method(request, call_context))) except Exception as e: - return self._handle_error(e) + # Since the stream has started, we can't return a JSONResponse. + # Instead, we runt the error handling logic (provides logging) + # and reraise the error and let server framework manage + self._handle_error(e) + raise e async def _handle_get_agent_card(self, request: Request) -> JSONResponse: @@ -177,48 +183,48 @@ async def handle_authenticated_agent_card(self, request: Request) -> JSONRespons self.agent_card.model_dump(mode='json', exclude_none=True) ) - def routes(self) -> dict[str, Tuple[Callable[[Request],Any], str]]: + def routes(self) -> dict[Tuple[str, str], Callable[[Request],Any]]: routes = { - '/v1/message:send': ( + ('/v1/message:send', 'POST'): ( functools.partial( self._handle_request, self.handler.on_message_send), - 'POST'), - '/v1/message:stream': ( + ), + ('/v1/message:stream', 'POST'): ( functools.partial( self._handle_streaming_request, self.handler.on_message_send_stream), - 'POST'), - '/v1/tasks/{id}:subscribe': ( + ), + ('/v1/tasks/{id}:subscribe', 'POST'): ( functools.partial( self._handle_streaming_request, self.handler.on_resubscribe_to_task), - 'POST'), - '/v1/tasks/{id}': ( + ), + ('/v1/tasks/{id}', 'GET'): ( functools.partial( self._handle_request, self.handler.on_get_task), - 'GET'), - '/v1/tasks/{id}/pushNotificationConfigs/{push_id}': ( + ), + ('/v1/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): ( functools.partial( self._handle_request, self.handler.get_push_notification), - 'GET'), - '/v1/tasks/{id}/pushNotificationConfigs': ( + ), + ('/v1/tasks/{id}/pushNotificationConfigs', 'POST'): ( functools.partial( self._handle_request, self.handler.set_push_notification), - 'POST'), - '/v1/tasks/{id}/pushNotificationConfigs': ( + ), + ('/v1/tasks/{id}/pushNotificationConfigs', 'GET'): ( functools.partial( self._handle_request, self.handler.list_push_notifications), - 'GET'), - '/v1/tasks': ( + ), + ('/v1/tasks', 'GET'): ( functools.partial( self._handle_request, self.handler.list_tasks), - 'GET'), + ), } if self.agent_card.supportsAuthenticatedExtendedCard: routes['/v1/card'] = ( diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 6f6167a5..6078180f 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -88,9 +88,9 @@ async def on_message_send( ) return MessageToJson(proto_utils.ToProto.task_or_message(task_or_message)) except ServerError as e: - return A2AError( + raise A2AError( error=e.error if e.error else InternalError() - ) + ) from e @validate( lambda self: self.agent_card.capabilities.streaming, @@ -214,12 +214,12 @@ async def get_push_notification( A2AError. """ try: - task_id = request.path_params['task_id'] + task_id = request.path_params['id'] push_id = request.path_params['push_id'] if push_id: params = GetTaskPushNotificationConfigParams(id=task_id, push_id=push_id) else: - params = TaskIdParams['task_id'] + params = TaskIdParams['id'] config = await self.request_handler.on_get_task_push_notification_config( params, context ) @@ -259,9 +259,14 @@ async def set_push_notification( try: task_id = request.path_params['id'] body = await request.body() + params = a2a_pb2.TaskPushNotificationConfig() + Parse(body, params) params = TaskPushNotificationConfig.validate_model(body) + a2a_request = proto_utils.FromProto.task_push_notification_config( + params, + ), config = await self.request_handler.on_set_task_push_notification_config( - params, context + a2a_request, context ) return MessageToJson( proto_utils.ToProto.task_push_notification_config(config) @@ -292,7 +297,7 @@ async def on_get_task( task_id = request.path_params['id'] historyLength = None if 'historyLength' in request.query_params: - history_length = request.query_params['historyLength'] + historyLength = request.query_params['historyLength'] params = TaskQueryParams(id=task_id, historyLength=historyLength) task = await self.request_handler.on_get_task(params, context) if task: @@ -300,7 +305,7 @@ async def on_get_task( raise ServerError(error=TaskNotFoundError()) except ServerError as e: raise A2AError( - id=request.id, error=e.error if e.error else InternalError() + error=e.error if e.error else InternalError() ) from e async def list_push_notifications( From f7bbfc4c1fe3109d9c8255065da90fb3c36d7a55 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 11 Jul 2025 18:06:51 +0000 Subject: [PATCH 3/4] add to allowlist --- .github/actions/spelling/allow.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 6f8229ad..b70e3ad8 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -70,5 +70,6 @@ sse tagwords taskupdate testuuid +Tful typeerror vulnz From 248205934d4b2c670caea46ac8d8f8591dc278a1 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Tue, 15 Jul 2025 17:19:14 +0100 Subject: [PATCH 4/4] Formatting --- .github/actions/spelling/allow.txt | 2 +- src/a2a/server/apps/jsonrpc/__init__.py | 4 +- src/a2a/server/apps/rest/__init__.py | 1 + src/a2a/server/apps/rest/fastapi_app.py | 8 +- src/a2a/server/apps/rest/rest_app.py | 95 +++++++++---------- src/a2a/server/request_handlers/__init__.py | 4 +- .../server/request_handlers/rest_handler.py | 80 ++++++++-------- .../test_default_request_handler.py | 12 +-- 8 files changed, 98 insertions(+), 108 deletions(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index b70e3ad8..f30ee12a 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -20,6 +20,7 @@ JSONRPCt Llm POSTGRES RUF +Tful aconnect adk agentic @@ -70,6 +71,5 @@ sse tagwords taskupdate testuuid -Tful typeerror vulnz diff --git a/src/a2a/server/apps/jsonrpc/__init__.py b/src/a2a/server/apps/jsonrpc/__init__.py index b322f0ef..1121fdbc 100644 --- a/src/a2a/server/apps/jsonrpc/__init__.py +++ b/src/a2a/server/apps/jsonrpc/__init__.py @@ -3,9 +3,9 @@ from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication from a2a.server.apps.jsonrpc.jsonrpc_app import ( CallContextBuilder, + DefaultCallContextBuilder, JSONRPCApplication, StarletteUserProxy, - DefaultCallContextBuilder, ) from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication @@ -14,7 +14,7 @@ 'A2AFastAPIApplication', 'A2AStarletteApplication', 'CallContextBuilder', + 'DefaultCallContextBuilder', 'JSONRPCApplication', 'StarletteUserProxy', - 'DefaultCallContextBuilder', ] diff --git a/src/a2a/server/apps/rest/__init__.py b/src/a2a/server/apps/rest/__init__.py index 81ee7b7a..c57b0f38 100644 --- a/src/a2a/server/apps/rest/__init__.py +++ b/src/a2a/server/apps/rest/__init__.py @@ -3,6 +3,7 @@ from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication from a2a.server.apps.rest.rest_app import RESTApplication + __all__ = [ 'A2ARESTFastAPIApplication', 'RESTApplication', diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index 28922a33..76fa887f 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -2,7 +2,7 @@ from typing import Any -from fastapi import FastAPI, Request, Response, APIRouter +from fastapi import APIRouter, FastAPI, Request, Response from a2a.server.apps.jsonrpc.jsonrpc_app import ( CallContextBuilder, @@ -70,14 +70,12 @@ def build( router = APIRouter() for route, callback in self._handler.routes().items(): router.add_api_route( - f'{rpc_url}{route[0]}', - callback, - methods=[route[1]] + f'{rpc_url}{route[0]}', callback, methods=[route[1]] ) @router.get(f'{rpc_url}{agent_card_url}') async def get_agent_card(request: Request) -> Response: - return await self._handler._handle_get_agent_card(request) + return await self._handler._handle_get_agent_card(request) app.include_router(router) return app diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index 247d3f2e..717c6e9f 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -1,42 +1,33 @@ -import contextlib +import functools import json import logging import traceback -import functools -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, AsyncIterator, Awaitable -from typing import Any, Tuple, Callable -from fastapi import FastAPI -from pydantic import BaseModel, ValidationError +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable +from typing import Any +from pydantic import ValidationError from sse_starlette.sse import EventSourceResponse -from starlette.applications import Starlette -from starlette.authentication import BaseUser from starlette.requests import Request -from starlette.responses import JSONResponse, Response +from starlette.responses import JSONResponse -from a2a.auth.user import UnauthenticatedUser -from a2a.auth.user import User as A2AUser +from a2a.server.apps.jsonrpc import ( + CallContextBuilder, + DefaultCallContextBuilder, +) from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.rest_handler import ( RESTHandler, ) -from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import ( - A2AError, AgentCard, - JSONParseError, - UnsupportedOperationError, InternalError, InvalidRequestError, + JSONParseError, + UnsupportedOperationError, ) from a2a.utils.errors import MethodNotImplementedError -from a2a.server.apps.jsonrpc import ( - CallContextBuilder, - StarletteUserProxy, - DefaultCallContextBuilder -) logger = logging.getLogger(__name__) @@ -102,23 +93,21 @@ def _handle_error(self, error: Exception) -> JSONResponse: traceback.print_exc() if isinstance(error, MethodNotImplementedError): return self._generate_error_response(UnsupportedOperationError()) - elif isinstance(error, json.decoder.JSONDecodeError): + if isinstance(error, json.decoder.JSONDecodeError): return self._generate_error_response( JSONParseError(message=str(error)) ) - elif isinstance(error, ValidationError): + if isinstance(error, ValidationError): return self._generate_error_response( InvalidRequestError(data=json.loads(error.json())), ) logger.error(f'Unhandled exception: {error}') - return self._generate_error_response( - InternalError(message=str(error)) - ) + return self._generate_error_response(InternalError(message=str(error))) async def _handle_request( self, method: Callable[[Request, ServerCallContext], Awaitable[str]], - request: Request + request: Request, ) -> JSONResponse: try: call_context = self._context_builder.build(request) @@ -130,16 +119,20 @@ async def _handle_request( async def _handle_streaming_request( self, method: Callable[[Request, ServerCallContext], AsyncIterator[str]], - request: Request + request: Request, ) -> EventSourceResponse: try: call_context = self._context_builder.build(request) + async def event_generator( stream: AsyncGenerator[str], ) -> AsyncGenerator[dict[str, str]]: async for item in stream: yield {'data': item} - return EventSourceResponse(event_generator(method(request, call_context))) + + return EventSourceResponse( + event_generator(method(request, call_context)) + ) except Exception as e: # Since the stream has started, we can't return a JSONResponse. # Instead, we runt the error handling logic (provides logging) @@ -147,7 +140,6 @@ async def event_generator( self._handle_error(e) raise e - async def _handle_get_agent_card(self, request: Request) -> JSONResponse: """Handles GET requests for the agent card endpoint. @@ -163,7 +155,9 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse: self.agent_card.model_dump(mode='json', exclude_none=True) ) - async def handle_authenticated_agent_card(self, request: Request) -> JSONResponse: + async def handle_authenticated_agent_card( + self, request: Request + ) -> JSONResponse: """Hook for per credential agent card response. If a dynamic card is needed based on the credentials provided in the request @@ -177,57 +171,58 @@ async def handle_authenticated_agent_card(self, request: Request) -> JSONRespons """ if not self.agent_card.supportsAuthenticatedExtendedCard: return JSONResponse( - '{"detail": "Authenticated card not supported"}', status_code=404 + '{"detail": "Authenticated card not supported"}', + status_code=404, ) return JSONResponse( self.agent_card.model_dump(mode='json', exclude_none=True) ) - def routes(self) -> dict[Tuple[str, str], Callable[[Request],Any]]: + def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: routes = { ('/v1/message:send', 'POST'): ( functools.partial( - self._handle_request, - self.handler.on_message_send), + self._handle_request, self.handler.on_message_send + ), ), ('/v1/message:stream', 'POST'): ( functools.partial( self._handle_streaming_request, - self.handler.on_message_send_stream), + self.handler.on_message_send_stream, + ), ), ('/v1/tasks/{id}:subscribe', 'POST'): ( functools.partial( self._handle_streaming_request, - self.handler.on_resubscribe_to_task), + self.handler.on_resubscribe_to_task, + ), ), ('/v1/tasks/{id}', 'GET'): ( functools.partial( - self._handle_request, - self.handler.on_get_task), + self._handle_request, self.handler.on_get_task + ), ), ('/v1/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): ( functools.partial( - self._handle_request, - self.handler.get_push_notification), + self._handle_request, self.handler.get_push_notification + ), ), ('/v1/tasks/{id}/pushNotificationConfigs', 'POST'): ( functools.partial( - self._handle_request, - self.handler.set_push_notification), + self._handle_request, self.handler.set_push_notification + ), ), ('/v1/tasks/{id}/pushNotificationConfigs', 'GET'): ( functools.partial( - self._handle_request, - self.handler.list_push_notifications), + self._handle_request, self.handler.list_push_notifications + ), ), ('/v1/tasks', 'GET'): ( functools.partial( - self._handle_request, - self.handler.list_tasks), + self._handle_request, self.handler.list_tasks + ), ), } if self.agent_card.supportsAuthenticatedExtendedCard: - routes['/v1/card'] = ( - self.handle_authenticated_agent_card, - 'GET') + routes['/v1/card'] = (self.handle_authenticated_agent_card, 'GET') return routes diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 20cc3815..f3f8e020 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -6,19 +6,19 @@ from a2a.server.request_handlers.grpc_handler import GrpcHandler from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.rest_handler import RESTHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, prepare_response_object, ) +from a2a.server.request_handlers.rest_handler import RESTHandler __all__ = [ 'DefaultRequestHandler', 'GrpcHandler', 'JSONRPCHandler', - 'RequestHandler', 'RESTHandler', + 'RequestHandler', 'build_error_response', 'prepare_response_object', ] diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 6078180f..930318d2 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -1,33 +1,28 @@ import logging from collections.abc import AsyncIterable + +from google.protobuf.json_format import MessageToJson, Parse from starlette.requests import Request -from pydantic import BaseModel, Field, RootModel +from a2a.grpc import a2a_pb2 from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import ( A2AError, AgentCard, + GetTaskPushNotificationConfigParams, InternalError, - Message, Task, - TaskArtifactUpdateEvent, - TaskNotFoundError, - TaskPushNotificationConfig, - TaskStatusUpdateEvent, - GetTaskPushNotificationConfigParams, - MessageSendParams, TaskIdParams, + TaskNotFoundError, TaskPushNotificationConfig, TaskQueryParams, ) +from a2a.utils import proto_utils from a2a.utils.errors import ServerError from a2a.utils.helpers import validate from a2a.utils.telemetry import SpanKind, trace_class -from a2a.grpc import a2a_pb2 -from a2a.utils import proto_utils -from google.protobuf.json_format import Parse, MessageToJson logger = logging.getLogger(__name__) @@ -71,6 +66,7 @@ async def on_message_send( Returns: A `str` containing the JSON result (Task or Message) + Raises: A2AError if a `ServerError` is raised by the handler. """ @@ -86,11 +82,11 @@ async def on_message_send( task_or_message = await self.request_handler.on_message_send( a2a_request, context ) - return MessageToJson(proto_utils.ToProto.task_or_message(task_or_message)) + return MessageToJson( + proto_utils.ToProto.task_or_message(task_or_message) + ) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) from e + raise A2AError(error=e.error if e.error else InternalError()) from e @validate( lambda self: self.agent_card.capabilities.streaming, @@ -129,9 +125,7 @@ async def on_message_send_stream( response = proto_utils.ToProto.stream_response(event) yield MessageToJson(response) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) from e + raise A2AError(error=e.error if e.error else InternalError()) from e return async def on_cancel_task( @@ -191,11 +185,11 @@ async def on_resubscribe_to_task( async for event in self.request_handler.on_resubscribe_to_task( TaskIdParams(id=task_id), context ): - yield(MessageToJson(proto_utils.ToProto.stream_response(event))) + yield ( + MessageToJson(proto_utils.ToProto.stream_response(event)) + ) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) from e + raise A2AError(error=e.error if e.error else InternalError()) from e async def get_push_notification( self, @@ -217,19 +211,21 @@ async def get_push_notification( task_id = request.path_params['id'] push_id = request.path_params['push_id'] if push_id: - params = GetTaskPushNotificationConfigParams(id=task_id, push_id=push_id) + params = GetTaskPushNotificationConfigParams( + id=task_id, push_id=push_id + ) else: params = TaskIdParams['id'] - config = await self.request_handler.on_get_task_push_notification_config( - params, context + config = ( + await self.request_handler.on_get_task_push_notification_config( + params, context + ) ) return MessageToJson( proto_utils.ToProto.task_push_notification_config(config) ) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) + raise A2AError(error=e.error if e.error else InternalError()) @validate( lambda self: self.agent_card.capabilities.pushNotifications, @@ -262,19 +258,21 @@ async def set_push_notification( params = a2a_pb2.TaskPushNotificationConfig() Parse(body, params) params = TaskPushNotificationConfig.validate_model(body) - a2a_request = proto_utils.FromProto.task_push_notification_config( - params, - ), - config = await self.request_handler.on_set_task_push_notification_config( - a2a_request, context + a2a_request = ( + proto_utils.FromProto.task_push_notification_config( + params, + ), + ) + config = ( + await self.request_handler.on_set_task_push_notification_config( + a2a_request, context + ) ) return MessageToJson( proto_utils.ToProto.task_push_notification_config(config) ) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) from e + raise A2AError(error=e.error if e.error else InternalError()) from e async def on_get_task( self, @@ -297,27 +295,25 @@ async def on_get_task( task_id = request.path_params['id'] historyLength = None if 'historyLength' in request.query_params: - historyLength = request.query_params['historyLength'] + historyLength = request.query_params['historyLength'] params = TaskQueryParams(id=task_id, historyLength=historyLength) task = await self.request_handler.on_get_task(params, context) if task: return MessageToJson(proto_utils.ToProto.task(task)) raise ServerError(error=TaskNotFoundError()) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) from e + raise A2AError(error=e.error if e.error else InternalError()) from e async def list_push_notifications( self, request: Request, context: ServerCallContext | None = None, ) -> list[TaskPushNotificationConfig]: - raise NotImplementedError("list notifications not implemented") + raise NotImplementedError('list notifications not implemented') async def list_tasks( self, request: Request, context: ServerCallContext | None = None, ) -> list[Task]: - raise NotImplementedError("list tasks not implemented") + raise NotImplementedError('list tasks not implemented') diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index fdf100f7..f8871c46 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -20,17 +20,20 @@ from a2a.server.events import EventQueue, InMemoryQueueManager, QueueManager from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import ( + InMemoryPushNotificationConfigStore, InMemoryTaskStore, + PushNotificationConfigStore, + PushNotificationSender, ResultAggregator, TaskStore, TaskUpdater, - PushNotificationConfigStore, - PushNotificationSender, - InMemoryPushNotificationConfigStore, ) from a2a.types import ( + DeleteTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigParams, InternalError, InvalidParamsError, + ListTaskPushNotificationConfigParams, Message, MessageSendConfiguration, MessageSendParams, @@ -46,9 +49,6 @@ TaskStatus, TextPart, UnsupportedOperationError, - GetTaskPushNotificationConfigParams, - ListTaskPushNotificationConfigParams, - DeleteTaskPushNotificationConfigParams, )