diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 033e07a9..f239af3e 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -13,7 +13,6 @@ build_error_response, prepare_response_object, ) -from a2a.server.request_handlers.rest_handler import RESTHandler logger = logging.getLogger(__name__) @@ -42,7 +41,6 @@ def __init__(self, *args, **kwargs): __all__ = [ 'DefaultRequestHandler', 'GrpcHandler', - 'RESTHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py deleted file mode 100644 index af889d9d..00000000 --- a/src/a2a/server/request_handlers/rest_handler.py +++ /dev/null @@ -1,334 +0,0 @@ -import logging - -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING, Any - -from google.protobuf.json_format import ( - MessageToDict, - Parse, -) - - -if TYPE_CHECKING: - from starlette.requests import Request -else: - try: - from starlette.requests import Request - except ImportError: - Request = Any - - -from a2a.server.context import ServerCallContext -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types import a2a_pb2 -from a2a.types.a2a_pb2 import ( - AgentCard, - CancelTaskRequest, - GetTaskPushNotificationConfigRequest, - SubscribeToTaskRequest, -) -from a2a.utils import constants, proto_utils -from a2a.utils.errors import TaskNotFoundError -from a2a.utils.helpers import ( - validate, - validate_version, -) -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -@trace_class(kind=SpanKind.SERVER) -class RESTHandler: - """Maps incoming REST-like (JSON+HTTP) requests to the appropriate request handler method and formats responses. - - This uses the protobuf definitions of the gRPC service as the source of truth. By - doing this, it ensures that this implementation and the gRPC transcoding - (via Envoy) are equivalent. This handler should be used if using the gRPC handler - with Envoy is not feasible for a given deployment solution. Use this handler - and a related application if you desire to ONLY server the RESTful API. - """ - - def __init__( - self, - agent_card: AgentCard, - request_handler: RequestHandler, - ): - """Initializes the RESTHandler. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - request_handler: The underlying `RequestHandler` instance to delegate requests to. - """ - self.agent_card = agent_card - self.request_handler = request_handler - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def on_message_send( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'message/send' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the result (Task or Message) - """ - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - task_or_message = await self.request_handler.on_message_send( - params, context - ) - if isinstance(task_or_message, a2a_pb2.Task): - response = a2a_pb2.SendMessageResponse(task=task_or_message) - else: - response = a2a_pb2.SendMessageResponse(message=task_or_message) - return MessageToDict(response) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_message_send_stream( - self, - request: Request, - context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: - """Handles the 'message/stream' REST method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Yields: - JSON serialized objects containing streaming events - (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) as JSON - """ - body = await request.body() - params = a2a_pb2.SendMessageRequest() - Parse(body, params) - async for event in self.request_handler.on_message_send_stream( - params, context - ): - response = proto_utils.to_stream_response(event) - yield MessageToDict(response) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def on_cancel_task( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/cancel' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the updated Task - """ - task_id = request.path_params['id'] - task = await self.request_handler.on_cancel_task( - CancelTaskRequest(id=task_id), context - ) - if task: - return MessageToDict(task) - raise TaskNotFoundError - - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_subscribe_to_task( - self, - request: Request, - context: ServerCallContext, - ) -> AsyncIterator[dict[str, Any]]: - """Handles the 'SubscribeToTask' REST method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Yields: - JSON serialized objects containing streaming events - """ - task_id = request.path_params['id'] - async for event in self.request_handler.on_subscribe_to_task( - SubscribeToTaskRequest(id=task_id), context - ): - yield MessageToDict(proto_utils.to_stream_response(event)) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def get_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/get' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `dict` containing the config - """ - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = GetTaskPushNotificationConfigRequest( - task_id=task_id, - id=push_id, - ) - config = ( - await self.request_handler.on_get_task_push_notification_config( - params, context - ) - ) - return MessageToDict(config) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) - async def set_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/set' REST method. - - Requires the agent to support push notifications. - - Args: - request: The incoming `TaskPushNotificationConfig` object. - context: Context provided by the server. - - Returns: - A `dict` containing the config object. - - Raises: - UnsupportedOperationError: If push notifications are not supported by the agent - (due to the `@validate` decorator), A2AError if processing error is - found. - """ - body = await request.body() - params = a2a_pb2.TaskPushNotificationConfig() - Parse(body, params) - # Set the parent to the task resource name format - params.task_id = request.path_params['id'] - config = ( - await self.request_handler.on_create_task_push_notification_config( - params, context - ) - ) - return MessageToDict(config) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def on_get_task( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/{id}' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A `Task` object containing the Task. - """ - params = a2a_pb2.GetTaskRequest() - proto_utils.parse_params(request.query_params, params) - params.id = request.path_params['id'] - task = await self.request_handler.on_get_task(params, context) - if task: - return MessageToDict(task) - raise TaskNotFoundError - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def delete_push_notification( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/delete' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - An empty `dict` representing the empty response. - """ - task_id = request.path_params['id'] - push_id = request.path_params['push_id'] - params = a2a_pb2.DeleteTaskPushNotificationConfigRequest( - task_id=task_id, id=push_id - ) - await self.request_handler.on_delete_task_push_notification_config( - params, context - ) - return {} - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def list_tasks( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/list' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A list of `dict` representing the `Task` objects. - """ - params = a2a_pb2.ListTasksRequest() - proto_utils.parse_params(request.query_params, params) - - result = await self.request_handler.on_list_tasks(params, context) - return MessageToDict(result, always_print_fields_with_no_presence=True) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def list_push_notifications( - self, - request: Request, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/list' REST method. - - Args: - request: The incoming `Request` object. - context: Context provided by the server. - - Returns: - A list of `dict` representing the `TaskPushNotificationConfig` objects. - """ - params = a2a_pb2.ListTaskPushNotificationConfigsRequest() - proto_utils.parse_params(request.query_params, params) - params.task_id = request.path_params['id'] - - result = ( - await self.request_handler.on_list_task_push_notification_configs( - params, context - ) - ) - return MessageToDict(result) diff --git a/src/a2a/server/routes/rest_dispatcher.py b/src/a2a/server/routes/rest_dispatcher.py new file mode 100644 index 00000000..76831508 --- /dev/null +++ b/src/a2a/server/routes/rest_dispatcher.py @@ -0,0 +1,388 @@ +import json +import logging + +from collections.abc import AsyncIterator, Awaitable, Callable +from typing import TYPE_CHECKING, Any, TypeVar + +from google.protobuf.json_format import MessageToDict, Parse + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder +from a2a.types import a2a_pb2 +from a2a.types.a2a_pb2 import ( + AgentCard, + CancelTaskRequest, + GetTaskPushNotificationConfigRequest, + SubscribeToTaskRequest, +) +from a2a.utils import constants, proto_utils +from a2a.utils.error_handlers import ( + rest_error_handler, + rest_stream_error_handler, +) +from a2a.utils.errors import ( + ExtendedAgentCardNotConfiguredError, + InvalidRequestError, + TaskNotFoundError, +) +from a2a.utils.helpers import maybe_await, validate, validate_version +from a2a.utils.telemetry import SpanKind, trace_class + + +if TYPE_CHECKING: + from sse_starlette.sse import EventSourceResponse + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + _package_starlette_installed = True +else: + try: + from sse_starlette.sse import EventSourceResponse + from starlette.requests import Request + from starlette.responses import JSONResponse, Response + + _package_starlette_installed = True + except ImportError: + EventSourceResponse = Any + Request = Any + JSONResponse = Any + Response = Any + + _package_starlette_installed = False + +logger = logging.getLogger(__name__) + +TResponse = TypeVar('TResponse') + + +@trace_class(kind=SpanKind.SERVER) +class RestDispatcher: + """Dispatches incoming REST requests to the appropriate handler methods. + + Handles context building, routing to RequestHandler directly, and response formatting (JSON/SSE). + """ + + def __init__( # noqa: PLR0913 + self, + agent_card: AgentCard, + request_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, + context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard + ] + | None = None, + ) -> None: + """Initializes the RestDispatcher. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The underlying `RequestHandler` instance to delegate requests to. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the request_handler. If None, no + ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. + """ + if not _package_starlette_installed: + raise ImportError( + 'Packages `starlette` and `sse-starlette` are required to use the' + ' `RestDispatcher`. They can be added as a part of `a2a-sdk` ' + 'optional dependencies, `a2a-sdk[http-server]`.' + ) + + self.agent_card = agent_card + self.extended_agent_card = extended_agent_card + self.card_modifier = card_modifier + self.extended_card_modifier = extended_card_modifier + self._context_builder = context_builder or DefaultCallContextBuilder() + self.request_handler = request_handler + + def _build_call_context(self, request: Request) -> ServerCallContext: + call_context = self._context_builder.build(request) + if 'tenant' in request.path_params: + call_context.tenant = request.path_params['tenant'] + return call_context + + async def _handle_non_streaming( + self, + request: Request, + handler_func: Callable[[ServerCallContext], Awaitable[TResponse]], + ) -> TResponse: + """Centralized error handling and context management for unary calls.""" + context = self._build_call_context(request) + return await handler_func(context) + + async def _handle_streaming( + self, + request: Request, + handler_func: Callable[[ServerCallContext], AsyncIterator[Any]], + ) -> EventSourceResponse: + """Centralized error handling and context management for streaming calls.""" + # Pre-consume and cache the request body to prevent deadlock in streaming context + # This is required because Starlette's request.body() can only be consumed once, + # and attempting to consume it after EventSourceResponse starts causes deadlock + try: + await request.body() + except (ValueError, RuntimeError, OSError) as e: + raise InvalidRequestError( + message=f'Failed to pre-consume request body: {e}' + ) from e + + context = self._build_call_context(request) + + # Eagerly fetch the first item from the stream so that errors raised + # before any event is yielded (e.g. validation, parsing, or handler + # failures) propagate here and are caught by + # @rest_stream_error_handler, which returns a JSONResponse with + # the correct HTTP status code instead of starting an SSE stream. + # Without this, the error would be raised after SSE headers are + # already sent, and the client would see a broken stream instead + stream = aiter(handler_func(context)) + try: + first_item = await anext(stream) + except StopAsyncIteration: + return EventSourceResponse(iter([])) + + async def event_generator() -> AsyncIterator[str]: + yield json.dumps(first_item) + async for item in stream: + yield json.dumps(item) + + return EventSourceResponse(event_generator()) + + @rest_error_handler + async def on_message_send(self, request: Request) -> Response: + """Handles the 'message/send' REST method.""" + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.SendMessageResponse: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + task_or_message = await self.request_handler.on_message_send( + params, context + ) + if isinstance(task_or_message, a2a_pb2.Task): + return a2a_pb2.SendMessageResponse(task=task_or_message) + return a2a_pb2.SendMessageResponse(message=task_or_message) + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) + + @rest_stream_error_handler + async def on_message_send_stream( + self, request: Request + ) -> EventSourceResponse: + """Handles the 'message/stream' REST method.""" + + @validate_version(constants.PROTOCOL_VERSION_1_0) + @validate( + lambda _: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def _handler( + context: ServerCallContext, + ) -> AsyncIterator[dict[str, Any]]: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + async for event in self.request_handler.on_message_send_stream( + params, context + ): + response = proto_utils.to_stream_response(event) + yield MessageToDict(response) + + return await self._handle_streaming(request, _handler) + + @rest_error_handler + async def on_cancel_task(self, request: Request) -> Response: + """Handles the 'tasks/cancel' REST method.""" + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler(context: ServerCallContext) -> a2a_pb2.Task: + task_id = request.path_params['id'] + task = await self.request_handler.on_cancel_task( + CancelTaskRequest(id=task_id), context + ) + if task: + return task + raise TaskNotFoundError + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) + + @rest_stream_error_handler + async def on_subscribe_to_task( + self, request: Request + ) -> EventSourceResponse: + """Handles the 'SubscribeToTask' REST method.""" + task_id = request.path_params['id'] + + @validate_version(constants.PROTOCOL_VERSION_1_0) + @validate( + lambda _: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def _handler( + context: ServerCallContext, + ) -> AsyncIterator[dict[str, Any]]: + async for event in self.request_handler.on_subscribe_to_task( + SubscribeToTaskRequest(id=task_id), context + ): + response = proto_utils.to_stream_response(event) + yield MessageToDict(response) + + return await self._handle_streaming(request, _handler) + + @rest_error_handler + async def on_get_task(self, request: Request) -> Response: + """Handles the 'tasks/{id}' REST method.""" + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler(context: ServerCallContext) -> a2a_pb2.Task: + params = a2a_pb2.GetTaskRequest() + proto_utils.parse_params(request.query_params, params) + params.id = request.path_params['id'] + task = await self.request_handler.on_get_task(params, context) + if task: + return task + raise TaskNotFoundError + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) + + @rest_error_handler + async def get_push_notification(self, request: Request) -> Response: + """Handles the 'tasks/pushNotificationConfig/get' REST method.""" + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.TaskPushNotificationConfig: + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = GetTaskPushNotificationConfigRequest( + task_id=task_id, id=push_id + ) + return ( + await self.request_handler.on_get_task_push_notification_config( + params, context + ) + ) + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) + + @rest_error_handler + async def delete_push_notification(self, request: Request) -> Response: + """Handles the 'tasks/pushNotificationConfig/delete' REST method.""" + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler(context: ServerCallContext) -> None: + task_id = request.path_params['id'] + push_id = request.path_params['push_id'] + params = a2a_pb2.DeleteTaskPushNotificationConfigRequest( + task_id=task_id, id=push_id + ) + await self.request_handler.on_delete_task_push_notification_config( + params, context + ) + + await self._handle_non_streaming(request, _handler) + return JSONResponse(content={}) + + @rest_error_handler + async def set_push_notification(self, request: Request) -> Response: + """Handles the 'tasks/pushNotificationConfig/set' REST method.""" + + @validate_version(constants.PROTOCOL_VERSION_1_0) + @validate( + lambda _: self.agent_card.capabilities.push_notifications, + 'Push notifications are not supported by the agent', + ) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.TaskPushNotificationConfig: + body = await request.body() + params = a2a_pb2.TaskPushNotificationConfig() + Parse(body, params) + params.task_id = request.path_params['id'] + return await self.request_handler.on_create_task_push_notification_config( + params, context + ) + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) + + @rest_error_handler + async def list_push_notifications(self, request: Request) -> Response: + """Handles the 'tasks/pushNotificationConfig/list' REST method.""" + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.ListTaskPushNotificationConfigsResponse: + params = a2a_pb2.ListTaskPushNotificationConfigsRequest() + proto_utils.parse_params(request.query_params, params) + params.task_id = request.path_params['id'] + return await self.request_handler.on_list_task_push_notification_configs( + params, context + ) + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse(content=MessageToDict(response)) + + @rest_error_handler + async def list_tasks(self, request: Request) -> Response: + """Handles the 'tasks/list' REST method.""" + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _handler( + context: ServerCallContext, + ) -> a2a_pb2.ListTasksResponse: + params = a2a_pb2.ListTasksRequest() + proto_utils.parse_params(request.query_params, params) + return await self.request_handler.on_list_tasks(params, context) + + response = await self._handle_non_streaming(request, _handler) + return JSONResponse( + content=MessageToDict( + response, always_print_fields_with_no_presence=True + ) + ) + + @rest_error_handler + async def handle_authenticated_agent_card( + self, request: Request + ) -> Response: + """Handles the 'extendedAgentCard' REST method.""" + if not self.agent_card.capabilities.extended_agent_card: + raise ExtendedAgentCardNotConfiguredError( + message='Authenticated card not supported' + ) + card_to_serve = self.extended_agent_card or self.agent_card + + if self.extended_card_modifier: + context = self._build_call_context(request) + card_to_serve = await maybe_await( + self.extended_card_modifier(card_to_serve, context) + ) + elif self.card_modifier: + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) + + return JSONResponse( + content=MessageToDict( + card_to_serve, preserving_proto_field_name=True + ) + ) diff --git a/src/a2a/server/routes/rest_routes.py b/src/a2a/server/routes/rest_routes.py index 1923f038..85dd01ff 100644 --- a/src/a2a/server/routes/rest_routes.py +++ b/src/a2a/server/routes/rest_routes.py @@ -1,27 +1,16 @@ -import functools -import json import logging -from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any -from google.protobuf.json_format import MessageToDict - from a2a.compat.v0_3.rest_adapter import REST03Adapter from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.rest_handler import RESTHandler -from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder -from a2a.types.a2a_pb2 import AgentCard -from a2a.utils.error_handlers import ( - rest_error_handler, - rest_stream_error_handler, -) -from a2a.utils.errors import ( - ExtendedAgentCardNotConfiguredError, - InvalidRequestError, +from a2a.server.routes import CallContextBuilder +from a2a.server.routes.rest_dispatcher import RestDispatcher +from a2a.types.a2a_pb2 import ( + AgentCard, ) -from a2a.utils.helpers import maybe_await if TYPE_CHECKING: @@ -94,7 +83,16 @@ def create_rest_routes( # noqa: PLR0913 'optional dependencies, `a2a-sdk[http-server]`.' ) - v03_routes = {} + dispatcher = RestDispatcher( + agent_card=agent_card, + request_handler=request_handler, + extended_agent_card=extended_agent_card, + context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, + ) + + routes: list[BaseRoute] = [] if enable_v0_3_compat: v03_adapter = REST03Adapter( agent_card=agent_card, @@ -105,139 +103,43 @@ def create_rest_routes( # noqa: PLR0913 extended_card_modifier=extended_card_modifier, ) v03_routes = v03_adapter.routes() - - routes: list[BaseRoute] = [] - for (path, method), endpoint in v03_routes.items(): - routes.append( - Route( - path=f'{path_prefix}{path}', - endpoint=endpoint, - methods=[method], + for (path, method), endpoint in v03_routes.items(): + routes.append( + Route( + path=f'{path_prefix}{path}', + endpoint=endpoint, + methods=[method], + ) ) - ) - handler = RESTHandler( - agent_card=agent_card, request_handler=request_handler - ) - _context_builder = context_builder or DefaultCallContextBuilder() - - def _build_call_context(request: 'Request') -> ServerCallContext: - call_context = _context_builder.build(request) - if 'tenant' in request.path_params: - call_context.tenant = request.path_params['tenant'] - return call_context - - @rest_error_handler - async def _handle_request( - method: Callable[['Request', ServerCallContext], Awaitable[Any]], - request: 'Request', - ) -> 'Response': - - call_context = _build_call_context(request) - response = await method(request, call_context) - return JSONResponse(content=response) - - @rest_stream_error_handler - async def _handle_streaming_request( - method: Callable[[Request, ServerCallContext], AsyncIterable[Any]], - request: Request, - ) -> EventSourceResponse: - # Pre-consume and cache the request body to prevent deadlock in streaming context - # This is required because Starlette's request.body() can only be consumed once, - # and attempting to consume it after EventSourceResponse starts causes deadlock - try: - await request.body() - except (ValueError, RuntimeError, OSError) as e: - raise InvalidRequestError( - message=f'Failed to pre-consume request body: {e}' - ) from e - - call_context = _build_call_context(request) - - # Eagerly fetch the first item from the stream so that errors raised - # before any event is yielded (e.g. validation, parsing, or handler - # failures) propagate here and are caught by - # @rest_stream_error_handler, which returns a JSONResponse with - # the correct HTTP status code instead of starting an SSE stream. - # Without this, the error would be raised after SSE headers are - # already sent, and the client would see a broken stream instead - # of a proper error response. - stream = aiter(method(request, call_context)) - try: - first_item = await anext(stream) - except StopAsyncIteration: - return EventSourceResponse(iter([])) - - async def event_generator() -> AsyncIterator[str]: - yield json.dumps(first_item) - async for item in stream: - yield json.dumps(item) - - return EventSourceResponse(event_generator()) - - async def _handle_authenticated_agent_card( - request: 'Request', call_context: ServerCallContext | None = None - ) -> dict[str, Any]: - if not agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='Authenticated card not supported' - ) - card_to_serve = extended_agent_card or agent_card - - if extended_card_modifier: - # Re-generate context if none passed to replicate RESTAdapter exact logic - context = call_context or _build_call_context(request) - card_to_serve = await maybe_await( - extended_card_modifier(card_to_serve, context) - ) - elif card_modifier: - card_to_serve = await maybe_await(card_modifier(card_to_serve)) - - return MessageToDict(card_to_serve, preserving_proto_field_name=True) - - # Dictionary of routes, mapping to bound helper methods - base_routes: dict[tuple[str, str], Callable[[Request], Any]] = { - ('/message:send', 'POST'): functools.partial( - _handle_request, handler.on_message_send - ), - ('/message:stream', 'POST'): functools.partial( - _handle_streaming_request, - handler.on_message_send_stream, - ), - ('/tasks/{id}:cancel', 'POST'): functools.partial( - _handle_request, handler.on_cancel_task - ), - ('/tasks/{id}:subscribe', 'GET'): functools.partial( - _handle_streaming_request, - handler.on_subscribe_to_task, - ), - ('/tasks/{id}:subscribe', 'POST'): functools.partial( - _handle_streaming_request, - handler.on_subscribe_to_task, - ), - ('/tasks/{id}', 'GET'): functools.partial( - _handle_request, handler.on_get_task - ), + base_routes = { + ('/message:send', 'POST'): dispatcher.on_message_send, + ('/message:stream', 'POST'): dispatcher.on_message_send_stream, + ('/tasks/{id}:cancel', 'POST'): dispatcher.on_cancel_task, + ('/tasks/{id}:subscribe', 'GET'): dispatcher.on_subscribe_to_task, + ('/tasks/{id}:subscribe', 'POST'): dispatcher.on_subscribe_to_task, + ('/tasks/{id}', 'GET'): dispatcher.on_get_task, ( '/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET', - ): functools.partial(_handle_request, handler.get_push_notification), + ): dispatcher.get_push_notification, ( '/tasks/{id}/pushNotificationConfigs/{push_id}', 'DELETE', - ): functools.partial(_handle_request, handler.delete_push_notification), - ('/tasks/{id}/pushNotificationConfigs', 'POST'): functools.partial( - _handle_request, handler.set_push_notification - ), - ('/tasks/{id}/pushNotificationConfigs', 'GET'): functools.partial( - _handle_request, handler.list_push_notifications - ), - ('/tasks', 'GET'): functools.partial( - _handle_request, handler.list_tasks - ), - ('/extendedAgentCard', 'GET'): functools.partial( - _handle_request, _handle_authenticated_agent_card - ), + ): dispatcher.delete_push_notification, + ( + '/tasks/{id}/pushNotificationConfigs', + 'POST', + ): dispatcher.set_push_notification, + ( + '/tasks/{id}/pushNotificationConfigs', + 'GET', + ): dispatcher.list_push_notifications, + ('/tasks', 'GET'): dispatcher.list_tasks, + ( + '/extendedAgentCard', + 'GET', + ): dispatcher.handle_authenticated_agent_card, } base_route_objects = [] diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index e5b37e5f..b1f23b40 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -372,7 +372,7 @@ def _is_version_compatible(actual: str) -> bool: @functools.wraps(func) def async_gen_wrapper( - self: Any, *args: Any, **kwargs: Any + *args: Any, **kwargs: Any ) -> AsyncIterator[Any]: actual_version = _get_actual_version(args, kwargs) if not _is_version_compatible(actual_version): @@ -385,12 +385,12 @@ def async_gen_wrapper( message=f"A2A version '{actual_version}' is not supported by this handler. " f"Expected version '{expected_version}'." ) - return func(self, *args, **kwargs) + return func(*args, **kwargs) return cast('F', async_gen_wrapper) @functools.wraps(func) - async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: actual_version = _get_actual_version(args, kwargs) if not _is_version_compatible(actual_version): logger.warning( @@ -402,7 +402,7 @@ async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: message=f"A2A version '{actual_version}' is not supported by this handler. " f"Expected version '{expected_version}'." ) - return await func(self, *args, **kwargs) + return await func(*args, **kwargs) return cast('F', async_wrapper) diff --git a/tests/server/routes/test_rest_dispatcher.py b/tests/server/routes/test_rest_dispatcher.py new file mode 100644 index 00000000..bee9424f --- /dev/null +++ b/tests/server/routes/test_rest_dispatcher.py @@ -0,0 +1,329 @@ +import json +from collections.abc import AsyncIterator +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from starlette.requests import Request +from starlette.responses import JSONResponse + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.routes import rest_dispatcher +from a2a.server.routes.rest_dispatcher import ( + DefaultCallContextBuilder, + RestDispatcher, +) +from a2a.types.a2a_pb2 import ( + AgentCapabilities, + AgentCard, + Message, + SendMessageResponse, + Task, + TaskPushNotificationConfig, + ListTasksResponse, + ListTaskPushNotificationConfigsResponse, +) +from a2a.utils.errors import ( + ExtendedAgentCardNotConfiguredError, + TaskNotFoundError, + UnsupportedOperationError, +) + + +@pytest.fixture +def mock_handler(): + handler = AsyncMock(spec=RequestHandler) + # Default success cases + handler.on_message_send.return_value = Message(message_id='test_msg') + handler.on_cancel_task.return_value = Task(id='test_task') + handler.on_get_task.return_value = Task(id='test_task') + handler.on_list_tasks.return_value = ListTasksResponse() + handler.on_get_task_push_notification_config.return_value = ( + TaskPushNotificationConfig(url='http://test') + ) + handler.on_create_task_push_notification_config.return_value = ( + TaskPushNotificationConfig(url='http://test') + ) + handler.on_list_task_push_notification_configs.return_value = ( + ListTaskPushNotificationConfigsResponse() + ) + + # Streaming mocks + async def mock_stream(*args, **kwargs) -> AsyncIterator[Task]: + yield Task(id='chunk1') + yield Task(id='chunk2') + + handler.on_message_send_stream.side_effect = mock_stream + handler.on_subscribe_to_task.side_effect = mock_stream + return handler + + +@pytest.fixture +def agent_card(): + card = MagicMock(spec=AgentCard) + card.capabilities = AgentCapabilities( + streaming=True, + push_notifications=True, + extended_agent_card=True, + ) + return card + + +@pytest.fixture +def rest_dispatcher_instance(agent_card, mock_handler): + return RestDispatcher(agent_card=agent_card, request_handler=mock_handler) + + +from starlette.datastructures import Headers + + +def make_mock_request( + method: str = 'GET', + path_params: dict | None = None, + query_params: dict | None = None, + headers: dict | None = None, + body: bytes = b'{}', +) -> Request: + mock_req = MagicMock(spec=Request) + mock_req.method = method + mock_req.path_params = path_params or {} + mock_req.query_params = query_params or {} + + # Default valid headers for A2A + default_headers = {'a2a-version': '1.0'} + if headers: + default_headers.update(headers) + + mock_req.headers = Headers(default_headers) + mock_req.body = AsyncMock(return_value=body) + + # Needs to be able to build ServerCallContext, so provide .user and .auth etc. if needed + mock_req.user = MagicMock(is_authenticated=False) + mock_req.auth = None + return mock_req + + +class TestRestDispatcherInitialization: + @pytest.fixture(scope='class') + def mark_pkg_starlette_not_installed(self): + pkg_starlette_installed_flag = ( + rest_dispatcher._package_starlette_installed + ) + rest_dispatcher._package_starlette_installed = False + yield + rest_dispatcher._package_starlette_installed = ( + pkg_starlette_installed_flag + ) + + def test_missing_starlette_raises_importerror( + self, mark_pkg_starlette_not_installed, agent_card, mock_handler + ): + with pytest.raises( + ImportError, + match='Packages `starlette` and `sse-starlette` are required', + ): + RestDispatcher(agent_card=agent_card, request_handler=mock_handler) + + +@pytest.mark.asyncio +class TestRestDispatcherContextManagement: + async def test_build_call_context(self, rest_dispatcher_instance): + req = make_mock_request(path_params={'tenant': 'my-tenant'}) + context = rest_dispatcher_instance._build_call_context(req) + + assert isinstance(context, ServerCallContext) + assert context.tenant == 'my-tenant' + assert context.state['headers']['a2a-version'] == '1.0' + + +@pytest.mark.asyncio +class TestRestDispatcherEndpoints: + async def test_on_message_send_throws_error_for_unsupported_version( + self, rest_dispatcher_instance, mock_handler + ): + # 0.3 is currently not supported for direct message sending on RestDispatcher + req = make_mock_request(method='POST', headers={'a2a-version': '0.3.0'}) + response = await rest_dispatcher_instance.on_message_send(req) + + # VersionNotSupportedError maps to 400 Bad Request + assert response.status_code == 400 + + async def test_on_message_send_returns_message( + self, rest_dispatcher_instance, mock_handler + ): + req = make_mock_request(method='POST') + response = await rest_dispatcher_instance.on_message_send(req) + + assert isinstance(response, JSONResponse) + assert response.status_code == 200 + data = json.loads(response.body) + assert 'message' in data + + async def test_on_message_send_returns_task( + self, rest_dispatcher_instance, mock_handler + ): + mock_handler.on_message_send.return_value = Task(id='new_task') + req = make_mock_request(method='POST') + + response = await rest_dispatcher_instance.on_message_send(req) + assert response.status_code == 200 + data = json.loads(response.body) + assert 'task' in data + assert data['task']['id'] == 'new_task' + + async def test_on_cancel_task_success( + self, rest_dispatcher_instance, mock_handler + ): + req = make_mock_request(method='POST', path_params={'id': 'test_task'}) + response = await rest_dispatcher_instance.on_cancel_task(req) + + assert response.status_code == 200 + data = json.loads(response.body) + assert data['id'] == 'test_task' + + async def test_on_cancel_task_not_found( + self, rest_dispatcher_instance, mock_handler + ): + mock_handler.on_cancel_task.return_value = None + req = make_mock_request(method='POST', path_params={'id': 'test_task'}) + + response = await rest_dispatcher_instance.on_cancel_task(req) + assert response.status_code == 404 # TaskNotFoundError maps to 404 + + async def test_on_get_task_success( + self, rest_dispatcher_instance, mock_handler + ): + req = make_mock_request(method='GET', path_params={'id': 'test_task'}) + response = await rest_dispatcher_instance.on_get_task(req) + + assert response.status_code == 200 + data = json.loads(response.body) + assert data['id'] == 'test_task' + + async def test_on_get_task_not_found( + self, rest_dispatcher_instance, mock_handler + ): + mock_handler.on_get_task.return_value = None + req = make_mock_request( + method='GET', path_params={'id': 'missing_task'} + ) + + response = await rest_dispatcher_instance.on_get_task(req) + assert response.status_code == 404 + + async def test_list_tasks(self, rest_dispatcher_instance, mock_handler): + req = make_mock_request(method='GET') + response = await rest_dispatcher_instance.list_tasks(req) + assert response.status_code == 200 + + async def test_get_push_notification( + self, rest_dispatcher_instance, mock_handler + ): + req = make_mock_request( + method='GET', path_params={'id': 'task1', 'push_id': 'push1'} + ) + response = await rest_dispatcher_instance.get_push_notification(req) + assert response.status_code == 200 + data = json.loads(response.body) + assert data['url'] == 'http://test' + + async def test_delete_push_notification( + self, rest_dispatcher_instance, mock_handler + ): + req = make_mock_request( + method='DELETE', path_params={'id': 'task1', 'push_id': 'push1'} + ) + response = await rest_dispatcher_instance.delete_push_notification(req) + assert response.status_code == 200 + + async def test_set_push_notification_disabled_raises( + self, agent_card, mock_handler + ): + agent_card.capabilities.push_notifications = False + dispatcher = RestDispatcher( + agent_card=agent_card, request_handler=mock_handler + ) + req = make_mock_request(method='POST', path_params={'id': 'task1'}) + + response = await dispatcher.set_push_notification(req) + assert response.status_code == 400 # UnsupportedOperation maps to 400 + + async def test_handle_authenticated_agent_card( + self, rest_dispatcher_instance + ): + req = make_mock_request() + response = ( + await rest_dispatcher_instance.handle_authenticated_agent_card(req) + ) + assert response.status_code == 200 + + async def test_handle_authenticated_agent_card_unsupported( + self, agent_card, mock_handler + ): + agent_card.capabilities.extended_agent_card = False + dispatcher = RestDispatcher( + agent_card=agent_card, request_handler=mock_handler + ) + req = make_mock_request() + + response = await dispatcher.handle_authenticated_agent_card(req) + assert response.status_code == 400 + + +@pytest.mark.asyncio +class TestRestDispatcherStreaming: + async def test_on_message_send_stream_unsupported( + self, agent_card, mock_handler + ): + agent_card.capabilities.streaming = False + dispatcher = RestDispatcher( + agent_card=agent_card, request_handler=mock_handler + ) + req = make_mock_request(method='POST') + + response = await dispatcher.on_message_send_stream(req) + assert response.status_code == 400 + + async def test_on_subscribe_to_task_unsupported( + self, agent_card, mock_handler + ): + agent_card.capabilities.streaming = False + dispatcher = RestDispatcher( + agent_card=agent_card, request_handler=mock_handler + ) + req = make_mock_request(method='GET', path_params={'id': 't1'}) + + response = await dispatcher.on_subscribe_to_task(req) + assert response.status_code == 400 + + async def test_on_message_send_stream_success( + self, rest_dispatcher_instance + ): + req = make_mock_request(method='POST') + response = await rest_dispatcher_instance.on_message_send_stream(req) + + assert response.status_code == 200 + + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + assert len(chunks) == 2 + # sse-starlette yields strings or bytes formatted as Server-Sent Events + assert 'chunk1' in str(chunks[0]) + assert 'chunk2' in str(chunks[1]) + + async def test_on_subscribe_to_task_success(self, rest_dispatcher_instance): + req = make_mock_request(method='GET', path_params={'id': 'test_task'}) + response = await rest_dispatcher_instance.on_subscribe_to_task(req) + + assert response.status_code == 200 + + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk) + + assert len(chunks) == 2 + assert 'chunk1' in str(chunks[0]) + assert 'chunk2' in str(chunks[1])