From 59e3f5fd1509b14f7f8676aa4fd45289874ed084 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 11 Jul 2025 15:17:09 +0000 Subject: [PATCH 01/17] feat: [WIP] Add dedicated RESTful handler --- src/a2a/server/apps/__init__.py | 6 + src/a2a/server/apps/jsonrpc/__init__.py | 4 + src/a2a/server/apps/rest/__init__.py | 9 + src/a2a/server/apps/rest/fastapi_app.py | 83 +++++ src/a2a/server/apps/rest/rest_app.py | 227 +++++++++++++ src/a2a/server/request_handlers/__init__.py | 2 + .../server/request_handlers/rest_handler.py | 318 ++++++++++++++++++ 7 files changed, 649 insertions(+) create mode 100644 src/a2a/server/apps/rest/__init__.py create mode 100644 src/a2a/server/apps/rest/fastapi_app.py create mode 100644 src/a2a/server/apps/rest/rest_app.py create mode 100644 src/a2a/server/request_handlers/rest_handler.py diff --git a/src/a2a/server/apps/__init__.py b/src/a2a/server/apps/__init__.py index a73e05c8..4d42ee8c 100644 --- a/src/a2a/server/apps/__init__.py +++ b/src/a2a/server/apps/__init__.py @@ -6,11 +6,17 @@ CallContextBuilder, JSONRPCApplication, ) +from a2a.server.apps.rest import ( + A2ARESTFastAPIApplication, + RESTApplication, +) __all__ = [ 'A2AFastAPIApplication', + 'A2ARESTFastAPIApplication', 'A2AStarletteApplication', 'CallContextBuilder', 'JSONRPCApplication', + 'RESTApplication', ] diff --git a/src/a2a/server/apps/jsonrpc/__init__.py b/src/a2a/server/apps/jsonrpc/__init__.py index ab803d4e..b322f0ef 100644 --- a/src/a2a/server/apps/jsonrpc/__init__.py +++ b/src/a2a/server/apps/jsonrpc/__init__.py @@ -4,6 +4,8 @@ from a2a.server.apps.jsonrpc.jsonrpc_app import ( CallContextBuilder, JSONRPCApplication, + StarletteUserProxy, + DefaultCallContextBuilder, ) from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication @@ -13,4 +15,6 @@ 'A2AStarletteApplication', 'CallContextBuilder', 'JSONRPCApplication', + 'StarletteUserProxy', + 'DefaultCallContextBuilder', ] diff --git a/src/a2a/server/apps/rest/__init__.py b/src/a2a/server/apps/rest/__init__.py new file mode 100644 index 00000000..81ee7b7a --- /dev/null +++ b/src/a2a/server/apps/rest/__init__.py @@ -0,0 +1,9 @@ +"""A2A REST Applications.""" + +from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication +from a2a.server.apps.rest.rest_app import RESTApplication + +__all__ = [ + 'A2ARESTFastAPIApplication', + 'RESTApplication', +] diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py new file mode 100644 index 00000000..c6356560 --- /dev/null +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -0,0 +1,83 @@ +import logging + +from typing import Any + +from fastapi import FastAPI, Request, Response, APIRouter + +from a2a.server.apps.jsonrpc.jsonrpc_app import ( + CallContextBuilder, +) +from a2a.server.apps.rest.rest_app import ( + RESTApplication, +) +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import AgentCard + + +logger = logging.getLogger(__name__) + + +class A2ARESTFastAPIApplication: + """A FastAPI application implementing the A2A protocol server REST endpoints. + + Handles incoming REST requests, routes them to the appropriate + handler methods, and manages response generation including Server-Sent Events + (SSE). + """ + + def __init__( + self, + agent_card: AgentCard, + http_handler: RequestHandler, + context_builder: CallContextBuilder | None = None, + ): + """Initializes the A2ARESTFastAPIApplication. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + http_handler: The handler instance responsible for processing A2A + requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. + """ + self._handler = RESTApplication( + agent_card=agent_card, + http_handler=http_handler, + context_builder=context_builder, + ) + + def build( + self, + agent_card_url: str = '/.well-known/agent.json', + rpc_url: str = '', + **kwargs: Any, + ) -> FastAPI: + """Builds and returns the FastAPI application instance. + + Args: + agent_card_url: The URL for the agent card endpoint. + rpc_url: The URL for the A2A JSON-RPC endpoint. + extended_agent_card_url: The URL for the authenticated extended agent card endpoint. + **kwargs: Additional keyword arguments to pass to the FastAPI constructor. + + Returns: + A configured FastAPI application instance. + """ + app = FastAPI(**kwargs) + router = APIRouter() + for route, callback in self._handler.routes().items(): + router.add_api_route( + f'{rpc_url}{route}', + callback[0], + methods=[callback[1]] + ) + + @router.get(f'{rpc_url}{agent_card_url}') + async def get_agent_card(request: Request) -> Response: + return await self._handle_get_agent_card(request) + + app.include_router(router) + return app diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py new file mode 100644 index 00000000..b4b9d3ad --- /dev/null +++ b/src/a2a/server/apps/rest/rest_app.py @@ -0,0 +1,227 @@ +import contextlib +import json +import logging +import traceback +import functools + +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable +from typing import Any, Tuple, Callable +from fastapi import FastAPI +from pydantic import BaseModel, ValidationError + +from sse_starlette.sse import EventSourceResponse +from starlette.applications import Starlette +from starlette.authentication import BaseUser +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from a2a.auth.user import UnauthenticatedUser +from a2a.auth.user import User as A2AUser +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.rest_handler import ( + RESTHandler, +) +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import ( + A2AError, + AgentCard, + UnsupportedOperationError, + InternalError, +) +from a2a.utils.errors import MethodNotImplementedError +from a2a.server.apps.jsonrpc import ( + CallContextBuilder, + StarletteUserProxy, + DefaultCallContextBuilder +) + + +logger = logging.getLogger(__name__) + + +class RESTApplication: + """Base class for A2A REST applications. + + Defines REST requests processors and the routes to attach them too, as well as + manages response generation including Server-Sent Events (SSE). + """ + + def __init__( + self, + agent_card: AgentCard, + http_handler: RequestHandler, + context_builder: CallContextBuilder | None = None, + ): + """Initializes the RESTApplication. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + http_handler: The handler instance responsible for processing A2A + requests via http. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. + """ + self.agent_card = agent_card + self.handler = RESTHandler( + agent_card=agent_card, request_handler=http_handler + ) + self._context_builder = context_builder or DefaultCallContextBuilder() + + def _generate_error_response(self, error) -> JSONResponse: + """Creates a JSONResponse for a errors. + + Logs the error based on its type. + + Args: + error: The Error object. + + Returns: + A `JSONResponse` object formatted as a JSON error response. + """ + log_level = ( + logging.ERROR + if isinstance(error, InternalError) + else logging.WARNING + ) + logger.log( + log_level, + 'Request Error: ' + f"Code={error.code}, Message='{error.message}'" + f'{", Data=" + str(error.data) if error.data else ""}', + ) + return JSONResponse( + '{"message": ' + error.message + '}', + status_code=404, + ) + + def _handle_error(self, error: Exception) -> JSONResponse: + traceback.print_exc() + if isinstance(error, MethodNotImplementedError): + return self._generate_error_response(UnsupportedOperationError()) + elif isinstance(error, json.decoder.JSONDecodeError): + return self._generate_error_response( + JSONParseError(message=str(error)) + ) + elif isinstance(error, ValidationError): + return self._generate_error_response( + InvalidRequestError(data=json.loads(error.json())), + ) + logger.error(f'Unhandled exception: {error}') + return self._generate_error_response( + InternalError(message=str(error)) + ) + + async def _handle_request( + self, + method: Callable[[Request, ServerCallContext], Awaitable[str]], + request: Request + ) -> JSONResponse: + try: + call_context = self._context_builder.build(request) + response = await method(request, call_context) + return JSONResponse(content=response) + except Exception as e: + return self._handle_error(e) + + async def _handle_streaming_request( + self, + method: Callable[[Request, ServerCallContext], AsyncIterator[str]], + request: Request + ) -> EventSourceResponse: + try: + call_context = self._context_builder.build(request) + async def event_generator( + stream: AsyncGenerator[str], + ) -> AsyncGenerator[dict[str, str]]: + async for item in stream: + yield {'data': item} + return EventSourceResponse(event_generator(method(request, call_context))) + except Exception as e: + return self._handle_error(e) + + + async def _handle_get_agent_card(self, request: Request) -> JSONResponse: + """Handles GET requests for the agent card endpoint. + + Args: + request: The incoming Starlette Request object. + + Returns: + A JSONResponse containing the agent card data. + """ + # The public agent card is a direct serialization of the agent_card + # provided at initialization. + return JSONResponse( + self.agent_card.model_dump(mode='json', exclude_none=True) + ) + + async def handle_authenticated_agent_card(self, request: Request) -> JSONResponse: + """Hook for per credential agent card response. + + If a dynamic card is needed based on the credentials provided in the request + override this method and return the customized content. + + Args: + request: The incoming Starlette Request object. + + Returns: + A JSONResponse containing the authenticated card. + """ + if not self.agent_card.supportsAuthenticatedExtendedCard: + return JSONResponse( + '{"detail": "Authenticated card not supported"}', status_code=404 + ) + return JSONResponse( + self.agent_card.model_dump(mode='json', exclude_none=True) + ) + + def routes(self) -> dict[str, Tuple[Callable[[Request],Any], str]]: + routes = { + '/v1/message:send': ( + functools.partial( + self._handle_request, + self.handler.on_message_send), + 'POST'), + '/v1/message:stream': ( + functools.partial( + self._handle_streaming_request, + self.handler.on_message_send_stream), + 'POST'), + '/v1/tasks/{id}:subscribe': ( + functools.partial( + self._handle_streaming_request, + self.handler.on_resubscribe_to_task), + 'POST'), + '/v1/tasks/{id}': ( + functools.partial( + self._handle_request, + self.handler.on_get_task), + 'GET'), + '/v1/tasks/{id}/pushNotificationConfigs/{push_id}': ( + functools.partial( + self._handle_request, + self.handler.get_push_notification), + 'GET'), + '/v1/tasks/{id}/pushNotificationConfigs': ( + functools.partial( + self._handle_request, + self.handler.set_push_notification), + 'POST'), + '/v1/tasks/{id}/pushNotificationConfigs': ( + functools.partial( + self._handle_request, + self.handler.list_push_notifications), + 'GET'), + '/v1/tasks': ( + functools.partial( + self._handle_request, + self.handler.list_tasks), + 'GET'), + } + if self.agent_card.supportsAuthenticatedExtendedCard: + routes['/v1/card'] = ( + self.handle_authenticated_agent_card, + 'GET') + return routes diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 6e5603de..20854087 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -7,6 +7,7 @@ ) from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_handlers.rest_handler import RESTHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, prepare_response_object, @@ -40,6 +41,7 @@ def __init__(self, *args, **kwargs): 'GrpcHandler', 'JSONRPCHandler', 'RequestHandler', + 'RESTHandler', 'build_error_response', 'prepare_response_object', ] diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py new file mode 100644 index 00000000..6f6167a5 --- /dev/null +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -0,0 +1,318 @@ +import logging + +from collections.abc import AsyncIterable +from starlette.requests import Request +from pydantic import BaseModel, Field, RootModel + +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import ( + A2AError, + AgentCard, + InternalError, + Message, + Task, + TaskArtifactUpdateEvent, + TaskNotFoundError, + TaskPushNotificationConfig, + TaskStatusUpdateEvent, + GetTaskPushNotificationConfigParams, + MessageSendParams, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, +) +from a2a.utils.errors import ServerError +from a2a.utils.helpers import validate +from a2a.utils.telemetry import SpanKind, trace_class +from a2a.grpc import a2a_pb2 +from a2a.utils import proto_utils +from google.protobuf.json_format import Parse, MessageToJson + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.SERVER) +class RESTHandler: + """Maps incoming REST-like (JSON+HTTP) requests to the appropriate request handler method and formats responses. + + This uses the protobuf definitions of the gRPC service as the source of truth. By + doing this, it ensures that this implementation and the gRPC transcoding + (via Envoy) are equivalent. This handler should be used if using the gRPC handler + with Envoy is not feasible for a given deployment solution. Use this handler + and a related application if you desire to ONLY server the RESTful API. + """ + + def __init__( + self, + agent_card: AgentCard, + request_handler: RequestHandler, + ): + """Initializes the RESTHandler. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The underlying `RequestHandler` instance to delegate requests to. + """ + self.agent_card = agent_card + self.request_handler = request_handler + + async def on_message_send( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'message/send' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `str` containing the JSON result (Task or Message) + Raises: + A2AError if a `ServerError` is raised by the handler. + """ + # TODO: Wrap in error handler to return error states + try: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + params, + ) + task_or_message = await self.request_handler.on_message_send( + a2a_request, context + ) + return MessageToJson(proto_utils.ToProto.task_or_message(task_or_message)) + except ServerError as e: + return A2AError( + error=e.error if e.error else InternalError() + ) + + @validate( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_message_send_stream( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> AsyncIterable[str]: + """Handles the 'message/stream' REST method. + + Yields response objects as they are produced by the underlying handler's stream. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Yields: + `str` objects containing streaming events + (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) as JSON + Raises: + `A2AError` + """ + try: + body = await request.body() + params = a2a_pb2.SendMessageRequest() + Parse(body, params) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + params, + ) + async for event in self.request_handler.on_message_send_stream( + a2a_request, context + ): + response = proto_utils.ToProto.stream_response(event) + yield MessageToJson(response) + except ServerError as e: + raise A2AError( + error=e.error if e.error else InternalError() + ) from e + return + + async def on_cancel_task( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'tasks/cancel' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `str` containing the updated Task in JSON format + Raises: + A2AError. + """ + try: + task_id = request.path_params['id'] + task = await self.request_handler.on_cancel_task( + TaskIdParams(id=task_id), context + ) + if task: + return MessageToJson(proto_utils.ToProto.task(task)) + raise ServerError(error=TaskNotFoundError()) + except ServerError as e: + raise A2AError( + error=e.error if e.error else InternalError(), + ) from e + + @validate( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def on_resubscribe_to_task( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> AsyncIterable[str]: + """Handles the 'tasks/resubscribe' REST method. + + Yields response objects as they are produced by the underlying handler's stream. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Yields: + `str` containing streaming events in JSON format + + Raises: + A A2AError if an error is encountered + """ + try: + task_id = request.path_params['id'] + async for event in self.request_handler.on_resubscribe_to_task( + TaskIdParams(id=task_id), context + ): + yield(MessageToJson(proto_utils.ToProto.stream_response(event))) + except ServerError as e: + raise A2AError( + error=e.error if e.error else InternalError() + ) from e + + async def get_push_notification( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'tasks/pushNotificationConfig/get' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `str` containing the config as JSON + Raises: + A2AError. + """ + try: + task_id = request.path_params['task_id'] + push_id = request.path_params['push_id'] + if push_id: + params = GetTaskPushNotificationConfigParams(id=task_id, push_id=push_id) + else: + params = TaskIdParams['task_id'] + config = await self.request_handler.on_get_task_push_notification_config( + params, context + ) + return MessageToJson( + proto_utils.ToProto.task_push_notification_config(config) + ) + except ServerError as e: + raise A2AError( + error=e.error if e.error else InternalError() + ) + + @validate( + lambda self: self.agent_card.capabilities.pushNotifications, + 'Push notifications are not supported by the agent', + ) + async def set_push_notification( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'tasks/pushNotificationConfig/set' REST method. + + Requires the agent to support push notifications. + + Args: + request: The incoming `TaskPushNotificationConfig` object. + context: Context provided by the server. + + Returns: + A `str` containing the config as JSON object. + + Raises: + ServerError: If push notifications are not supported by the agent + (due to the `@validate` decorator), A2AError if processing error is + found. + """ + try: + task_id = request.path_params['id'] + body = await request.body() + params = TaskPushNotificationConfig.validate_model(body) + config = await self.request_handler.on_set_task_push_notification_config( + params, context + ) + return MessageToJson( + proto_utils.ToProto.task_push_notification_config(config) + ) + except ServerError as e: + raise A2AError( + error=e.error if e.error else InternalError() + ) from e + + async def on_get_task( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> str: + """Handles the 'v1/tasks/{id}' REST method. + + Args: + request: The incoming `Request` object. + context: Context provided by the server. + + Returns: + A `Task` object containing the Task. + + Raises: + A2AError + """ + try: + task_id = request.path_params['id'] + historyLength = None + if 'historyLength' in request.query_params: + history_length = request.query_params['historyLength'] + params = TaskQueryParams(id=task_id, historyLength=historyLength) + task = await self.request_handler.on_get_task(params, context) + if task: + return MessageToJson(proto_utils.ToProto.task(task)) + raise ServerError(error=TaskNotFoundError()) + except ServerError as e: + raise A2AError( + id=request.id, error=e.error if e.error else InternalError() + ) from e + + async def list_push_notifications( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> list[TaskPushNotificationConfig]: + raise NotImplementedError("list notifications not implemented") + + async def list_tasks( + self, + request: Request, + context: ServerCallContext | None = None, + ) -> list[Task]: + raise NotImplementedError("list tasks not implemented") From 065dd38d2669f03175bb369dcad6981411f90797 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 11 Jul 2025 18:02:56 +0000 Subject: [PATCH 02/17] Address gemini comments --- src/a2a/server/apps/rest/fastapi_app.py | 8 ++-- src/a2a/server/apps/rest/rest_app.py | 42 +++++++++++-------- .../server/request_handlers/rest_handler.py | 19 +++++---- 3 files changed, 40 insertions(+), 29 deletions(-) diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index c6356560..28922a33 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -70,14 +70,14 @@ def build( router = APIRouter() for route, callback in self._handler.routes().items(): router.add_api_route( - f'{rpc_url}{route}', - callback[0], - methods=[callback[1]] + f'{rpc_url}{route[0]}', + callback, + methods=[route[1]] ) @router.get(f'{rpc_url}{agent_card_url}') async def get_agent_card(request: Request) -> Response: - return await self._handle_get_agent_card(request) + return await self._handler._handle_get_agent_card(request) app.include_router(router) return app diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index b4b9d3ad..247d3f2e 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -26,8 +26,10 @@ from a2a.types import ( A2AError, AgentCard, + JSONParseError, UnsupportedOperationError, InternalError, + InvalidRequestError, ) from a2a.utils.errors import MethodNotImplementedError from a2a.server.apps.jsonrpc import ( @@ -139,7 +141,11 @@ async def event_generator( yield {'data': item} return EventSourceResponse(event_generator(method(request, call_context))) except Exception as e: - return self._handle_error(e) + # Since the stream has started, we can't return a JSONResponse. + # Instead, we runt the error handling logic (provides logging) + # and reraise the error and let server framework manage + self._handle_error(e) + raise e async def _handle_get_agent_card(self, request: Request) -> JSONResponse: @@ -177,48 +183,48 @@ async def handle_authenticated_agent_card(self, request: Request) -> JSONRespons self.agent_card.model_dump(mode='json', exclude_none=True) ) - def routes(self) -> dict[str, Tuple[Callable[[Request],Any], str]]: + def routes(self) -> dict[Tuple[str, str], Callable[[Request],Any]]: routes = { - '/v1/message:send': ( + ('/v1/message:send', 'POST'): ( functools.partial( self._handle_request, self.handler.on_message_send), - 'POST'), - '/v1/message:stream': ( + ), + ('/v1/message:stream', 'POST'): ( functools.partial( self._handle_streaming_request, self.handler.on_message_send_stream), - 'POST'), - '/v1/tasks/{id}:subscribe': ( + ), + ('/v1/tasks/{id}:subscribe', 'POST'): ( functools.partial( self._handle_streaming_request, self.handler.on_resubscribe_to_task), - 'POST'), - '/v1/tasks/{id}': ( + ), + ('/v1/tasks/{id}', 'GET'): ( functools.partial( self._handle_request, self.handler.on_get_task), - 'GET'), - '/v1/tasks/{id}/pushNotificationConfigs/{push_id}': ( + ), + ('/v1/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): ( functools.partial( self._handle_request, self.handler.get_push_notification), - 'GET'), - '/v1/tasks/{id}/pushNotificationConfigs': ( + ), + ('/v1/tasks/{id}/pushNotificationConfigs', 'POST'): ( functools.partial( self._handle_request, self.handler.set_push_notification), - 'POST'), - '/v1/tasks/{id}/pushNotificationConfigs': ( + ), + ('/v1/tasks/{id}/pushNotificationConfigs', 'GET'): ( functools.partial( self._handle_request, self.handler.list_push_notifications), - 'GET'), - '/v1/tasks': ( + ), + ('/v1/tasks', 'GET'): ( functools.partial( self._handle_request, self.handler.list_tasks), - 'GET'), + ), } if self.agent_card.supportsAuthenticatedExtendedCard: routes['/v1/card'] = ( diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 6f6167a5..6078180f 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -88,9 +88,9 @@ async def on_message_send( ) return MessageToJson(proto_utils.ToProto.task_or_message(task_or_message)) except ServerError as e: - return A2AError( + raise A2AError( error=e.error if e.error else InternalError() - ) + ) from e @validate( lambda self: self.agent_card.capabilities.streaming, @@ -214,12 +214,12 @@ async def get_push_notification( A2AError. """ try: - task_id = request.path_params['task_id'] + task_id = request.path_params['id'] push_id = request.path_params['push_id'] if push_id: params = GetTaskPushNotificationConfigParams(id=task_id, push_id=push_id) else: - params = TaskIdParams['task_id'] + params = TaskIdParams['id'] config = await self.request_handler.on_get_task_push_notification_config( params, context ) @@ -259,9 +259,14 @@ async def set_push_notification( try: task_id = request.path_params['id'] body = await request.body() + params = a2a_pb2.TaskPushNotificationConfig() + Parse(body, params) params = TaskPushNotificationConfig.validate_model(body) + a2a_request = proto_utils.FromProto.task_push_notification_config( + params, + ), config = await self.request_handler.on_set_task_push_notification_config( - params, context + a2a_request, context ) return MessageToJson( proto_utils.ToProto.task_push_notification_config(config) @@ -292,7 +297,7 @@ async def on_get_task( task_id = request.path_params['id'] historyLength = None if 'historyLength' in request.query_params: - history_length = request.query_params['historyLength'] + historyLength = request.query_params['historyLength'] params = TaskQueryParams(id=task_id, historyLength=historyLength) task = await self.request_handler.on_get_task(params, context) if task: @@ -300,7 +305,7 @@ async def on_get_task( raise ServerError(error=TaskNotFoundError()) except ServerError as e: raise A2AError( - id=request.id, error=e.error if e.error else InternalError() + error=e.error if e.error else InternalError() ) from e async def list_push_notifications( From 88dccabe03fc1c67ef0dc0054d9243a66da68623 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Fri, 11 Jul 2025 18:06:51 +0000 Subject: [PATCH 03/17] add to allowlist --- .github/actions/spelling/allow.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index d7595c1a..6979b7b7 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -71,5 +71,6 @@ sse tagwords taskupdate testuuid +Tful typeerror vulnz From 7e7eba17efb580a2bb9f14b65a758649bfb4a4d1 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 24 Jul 2025 04:32:28 +0000 Subject: [PATCH 04/17] Initial client refactor changes --- src/a2a/client/__init__.py | 45 +- src/a2a/client/client.py | 465 +++-------- src/a2a/client/client_factory.py | 113 +++ src/a2a/client/client_task_manager.py | 181 +++++ src/a2a/client/grpc_client.py | 226 +++++- src/a2a/client/jsonrpc_client.py | 719 +++++++++++++++++ src/a2a/client/rest_client.py | 762 ++++++++++++++++++ src/a2a/grpc/a2a_pb2.py | 4 +- src/a2a/server/apps/rest/fastapi_app.py | 4 + src/a2a/server/apps/rest/rest_app.py | 49 ++ .../server/request_handlers/rest_handler.py | 4 + src/a2a/utils/__init__.py | 3 +- src/a2a/utils/proto_utils.py | 48 +- src/a2a/utils/transports.py | 7 + 14 files changed, 2268 insertions(+), 362 deletions(-) create mode 100644 src/a2a/client/client_factory.py create mode 100644 src/a2a/client/client_task_manager.py create mode 100644 src/a2a/client/jsonrpc_client.py create mode 100644 src/a2a/client/rest_client.py create mode 100644 src/a2a/utils/transports.py diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 393d85ec..f0cfcca7 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -5,30 +5,67 @@ CredentialService, InMemoryContextCredentialStore, ) -from a2a.client.client import A2ACardResolver, A2AClient from a2a.client.errors import ( A2AClientError, A2AClientHTTPError, A2AClientJSONError, A2AClientTimeoutError, ) -from a2a.client.grpc_client import A2AGrpcClient +from a2a.client.jsonrpc_client import ( + JsonRpcClient, + JsonRpcTransportClient, + NewJsonRpcClient, +) +from a2a.client.grpc_client import ( + GrpcTransportClient, + GrpcClient, + NewGrpcClient, +) from a2a.client.helpers import create_text_message_object from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.client import ( + A2ACardResolver, + Client, + ClientConfig, + Consumer, + ClientEvent, +) +from a2a.client.client_factory import ( + ClientFactory, + ClientProducer, + minimal_agent_card +) +# For backward compatability define this alias. This will be deprecated in +# a future release. +A2AClient = JsonRpcTransportClient +A2AGrpcClient = GrpcTransportClient __all__ = [ 'A2ACardResolver', - 'A2AClient', 'A2AClientError', 'A2AClientHTTPError', 'A2AClientJSONError', 'A2AClientTimeoutError', - 'A2AGrpcClient', 'AuthInterceptor', 'ClientCallContext', 'ClientCallInterceptor', + 'Consumer', 'CredentialService', 'InMemoryContextCredentialStore', 'create_text_message_object', + 'A2AClient', # for backward compatability + 'A2AGrpcClient', # for backward compatability + 'Client', + 'ClientEvent', + 'ClientFactory', + 'ClientConfig', + 'ClientProducer', + 'GrpcTransportClient', + 'GrpcClient', + 'NewGrpcClient', + 'JsonRpcClient', + 'JsonRpcTransportClient', + 'NewJsonRpcClient', + 'minimal_agent_card', ] diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 66dfe0a4..5dfe0906 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -1,8 +1,10 @@ +import dataclasses import json import logging -from collections.abc import AsyncGenerator -from typing import Any +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from typing import Any, Callable, Coroutine from uuid import uuid4 import httpx @@ -10,6 +12,18 @@ from httpx_sse import SSEError, aconnect_sse from pydantic import ValidationError +# Attempt to import the optional module +try: + from grpc.aio import Channel +except ImportError: + # If grpc.aio is not available, define a dummy type for type checking. + # This dummy type will only be used by type checkers. + if TYPE_CHECKING: + class Channel: # type: ignore[no-redef] + pass + else: + Channel = None # At runtime, pd will be None if the import failed. + from a2a.client.errors import ( A2AClientHTTPError, A2AClientJSONError, @@ -18,22 +32,17 @@ from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.types import ( AgentCard, - CancelTaskRequest, - CancelTaskResponse, - GetTaskPushNotificationConfigRequest, - GetTaskPushNotificationConfigResponse, - GetTaskRequest, - GetTaskResponse, - SendMessageRequest, - SendMessageResponse, - SendStreamingMessageRequest, - SendStreamingMessageResponse, - SetTaskPushNotificationConfigRequest, - SetTaskPushNotificationConfigResponse, -) -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, + GetTaskPushNotificationConfigParams, + Message, + PushNotificationConfig, + Task, + TaskIdParams, + TaskQueryParams, + TaskPushNotificationConfig, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, ) +from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH from a2a.utils.telemetry import SpanKind, trace_class @@ -127,374 +136,136 @@ async def get_agent_card( return agent_card +@dataclasses.dataclass +class ClientConfig: + """Configuration class for the A2A Client Factory""" -@trace_class(kind=SpanKind.CLIENT) -class A2AClient: - """A2A Client for interacting with an A2A agent.""" + streaming: bool = True + """Whether client supports streaming""" - def __init__( - self, - httpx_client: httpx.AsyncClient, - agent_card: AgentCard | None = None, - url: str | None = None, - interceptors: list[ClientCallInterceptor] | None = None, - ): - """Initializes the A2AClient. + polling: bool = False + """Whether client prefers to poll for updates from message:send""" - Requires either an `AgentCard` or a direct `url` to the agent's RPC endpoint. + httpx_client: httpx.AsyncClient | None = None + """Http client to use to connect to agent.""" - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. - url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. - interceptors: An optional list of client call interceptors to apply to requests. + grpc_channel_factory: Callable[[str], Channel] | None = None + """Generates a grpc connection channel for a given url.""" - Raises: - ValueError: If neither `agent_card` nor `url` is provided. - """ - if agent_card: - self.url = agent_card.url - elif url: - self.url = url - else: - raise ValueError('Must provide either agent_card or url') + supported_transports: list[str] = dataclasses.field(default_factory=list) + """Ordered list of transports for connecting to agent + (in order of preference). Empty implies JSONRPC only. - self.httpx_client = httpx_client - self.agent_card = agent_card - self.interceptors = interceptors or [] + This is a string type and not a Transports enum type to allow custom + transports to exist in closed ecosystems. + """ - async def _apply_interceptors( - self, - method_name: str, - request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None, - context: ClientCallContext | None, - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Applies all registered interceptors to the request.""" - final_http_kwargs = http_kwargs or {} - final_request_payload = request_payload - - for interceptor in self.interceptors: - ( - final_request_payload, - final_http_kwargs, - ) = await interceptor.intercept( - method_name, - final_request_payload, - final_http_kwargs, - self.agent_card, - context, - ) - return final_request_payload, final_http_kwargs + use_client_preference: bool = False + """Whether to use client transport preferences over server preferences. + Recommended to use server preferences in most situations.""" - @staticmethod - async def get_client_from_agent_card_url( - httpx_client: httpx.AsyncClient, - base_url: str, - agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, - http_kwargs: dict[str, Any] | None = None, - ) -> 'A2AClient': - """Fetches the public AgentCard and initializes an A2A client. + acceptedOutputModes: list[str] = dataclasses.field(default_factory=list) + """The set of accepted output modes for the client.""" - This method will always fetch the public agent card. If an authenticated - or extended agent card is required, the A2ACardResolver should be used - directly to fetch the specific card, and then the A2AClient should be - instantiated with it. + pushNotificationConfigs: list[PushNotificationConfig] = dataclasses.field(default_factory=list) + """Push notification callbacks to use for every request.""" - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - base_url: The base URL of the agent's host. - agent_card_path: The path to the agent card endpoint, relative to the base URL. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.get request when fetching the agent card. +UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None +# Alias for emitted events from client +ClientEvent = tuple[Task, UpdateEvent] +# Alias for an event consuming callback. It takes either a (task, update) pair +# or a message as well as the agent card for the agent this came from. +Consumer = Callable[ + [ClientEvent | Message, AgentCard], Coroutine[None, Any, Any] +] - Returns: - An initialized `A2AClient` instance. - Raises: - A2AClientHTTPError: If an HTTP error occurs fetching the agent card. - A2AClientJSONError: If the agent card response is invalid. - """ - agent_card: AgentCard = await A2ACardResolver( - httpx_client, base_url=base_url, agent_card_path=agent_card_path - ).get_agent_card( - http_kwargs=http_kwargs - ) # Fetches public card by default - return A2AClient(httpx_client=httpx_client, agent_card=agent_card) +class Client(ABC): - async def send_message( + def __init__( self, - request: SendMessageRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> SendMessageResponse: - """Sends a non-streaming message request to the agent. - - Args: - request: The `SendMessageRequest` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `SendMessageResponse` object containing the agent's response (Task or Message) or an error. + consumers: list[Consumer] = [], + middleware: list[ClientCallInterceptor] = [], + ): + self._consumers = consumers or [] + self._middleware = middleware or [] - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'message/send', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return SendMessageResponse.model_validate(response_data) - - async def send_message_streaming( + @abstractmethod + async def send_message( self, - request: SendStreamingMessageRequest, + request: Message, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> AsyncGenerator[SendStreamingMessageResponse]: - """Sends a streaming message request to the agent and yields responses as they arrive. - - This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. - - Args: - request: The `SendStreamingMessageRequest` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. A default `timeout=None` is set but can be overridden. - context: The client call context. - - Yields: - `SendStreamingMessageResponse` objects as they are received in the SSE stream. - These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. - A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'message/stream', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - - modified_kwargs.setdefault('timeout', None) - - async with aconnect_sse( - self.httpx_client, - 'POST', - self.url, - json=payload, - **modified_kwargs, - ) as event_source: - try: - async for sse in event_source.aiter_sse(): - yield SendStreamingMessageResponse.model_validate( - json.loads(sse.data) - ) - except SSEError as e: - raise A2AClientHTTPError( - 400, - f'Invalid SSE response or protocol error: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def _send_request( - self, - rpc_request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Sends a non-streaming JSON-RPC request to the agent. - - Args: - rpc_request_payload: JSON RPC payload for sending the request. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - - Returns: - The JSON response payload as a dictionary. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON. - """ - try: - response = await self.httpx_client.post( - self.url, json=rpc_request_payload, **(http_kwargs or {}) - ) - response.raise_for_status() - return response.json() - except httpx.ReadTimeout as e: - raise A2AClientTimeoutError('Client Request timed out') from e - except httpx.HTTPStatusError as e: - raise A2AClientHTTPError(e.response.status_code, str(e)) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e + ) -> AsyncIterator[ClientEvent | Message]: + pass + yield + @abstractmethod async def get_task( self, - request: GetTaskRequest, + request: TaskQueryParams, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> GetTaskResponse: - """Retrieves the current state and history of a specific task. - - Args: - request: The `GetTaskRequest` object specifying the task ID and history length. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `GetTaskResponse` object containing the Task or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/get', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return GetTaskResponse.model_validate(response_data) + ) -> Task: + pass + @abstractmethod async def cancel_task( self, - request: CancelTaskRequest, + request: TaskIdParams, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> CancelTaskResponse: - """Requests the agent to cancel a specific task. - - Args: - request: The `CancelTaskRequest` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `CancelTaskResponse` object containing the updated Task with canceled status or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/cancel', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return CancelTaskResponse.model_validate(response_data) + ) -> Task: + pass + @abstractmethod async def set_task_callback( self, - request: SetTaskPushNotificationConfigRequest, + request: TaskPushNotificationConfig, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> SetTaskPushNotificationConfigResponse: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/pushNotificationConfig/set', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return SetTaskPushNotificationConfigResponse.model_validate( - response_data - ) + ) -> TaskPushNotificationConfig: + pass + @abstractmethod async def get_task_callback( self, - request: GetTaskPushNotificationConfigRequest, + request: GetTaskPushNotificationConfigParams, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> GetTaskPushNotificationConfigResponse: - """Retrieves the push notification configuration for a specific task. + ) -> TaskPushNotificationConfig: + pass - Args: - request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. + @abstractmethod + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + pass + yield - Returns: - A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error. + @abstractmethod + async def get_card( + self, + *, + context: ClientCallContext | None = None + ) -> AgentCard: + pass - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/pushNotificationConfig/get', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return GetTaskPushNotificationConfigResponse.model_validate( - response_data - ) + async def add_event_consumer(self, consumer: Consumer): + self._consumers.append(consumer) + + async def add_request_middleware(self, middleware: ClientCallInterceptor): + self._middleware.append(middleware) + + async def consume( + self, + event: tuple[Task, UpdateEvent] | Message | None, + card: AgentCard, + ): + if not event: + return + for c in self._consumers: + await c(event, card) diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py new file mode 100644 index 00000000..f149d5d0 --- /dev/null +++ b/src/a2a/client/client_factory.py @@ -0,0 +1,113 @@ +from __future__ import annotations +import json +import logging + +from collections.abc import AsyncGenerator +from typing import Any, TYPE_CHECKING, Callable + +import httpx + +from httpx_sse import SSEError, aconnect_sse +from pydantic import ValidationError + +from a2a.utils import Transports + +from a2a.client.client import Client, ClientConfig, Consumer +from a2a.client.jsonrpc_client import NewJsonRpcClient +from a2a.client.grpc_client import NewGrpcClient +from a2a.client.rest_client import NewRestfulClient +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.types import ( + AgentCapabilities, + AgentCard, + Message, + Task, + TaskIdParams, + TaskQueryParams, + GetTaskPushNotificationConfigParams, + TaskPushNotificationConfig, +) + +logger = logging.getLogger(__name__) + +ClientProducer = Callable[ + [ + AgentCard | str, + ClientConfig, + list[Consumer], + list[ClientCallInterceptor] + ], + Client +] + +class ClientFactory: + + def __init__( + self, + config: ClientConfig, + consumers: list[Consumer], + ): + self._config = config + self._consumers = consumers + self._registry: dict[str, ClientProducer] = {} + if Transports.JSONRPC in self._config.supported_transports: + self._registry[Transports.JSONRPC] = NewJsonRpcClient + if Transports.RESTful in self._config.supported_transports: + self._registry[Transports.RESTful] = NewRestfulClient + if Transports.GRPC in self._config.supported_transports: + self._registry[Transports.GRPC] = NewGrpcClient + + def register(self, label: str, generator: ClientProducer): + self._registry[label] = generator + + def create( + self, + card: AgentCard, + consumers: list[Consumer] | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ) -> Client: + # Determine preferential transport + server_set = [card.preferredTransport or 'JSONRPC'] + if card.additionalInterfaces: + server_set.extend( + [x.transport for x in card.additionalInterfaces] + ) + client_set = self._config.supported_transports or ['JSONRPC'] + transport = None + if self._config.use_client_preference: + for x in client_set: + if x in server_set: + transport = x + break + else: + for x in server_set: + if x in client_set: + transport = x + break + if not transport: + raise Exception('no compatible transports found.') + if transport not in self._registry: + raise Exception(f'no client available for {transport}') + all_consumers = self._consumers + if consumers: + all_consumers.extend(consumers) + return self._registry[transport]( + card, self._config, all_consumers, interceptors + ) + +def minimal_agent_card(url: str, transports: list[str] = []) -> AgentCard: + """Generates a minimal card to simplify bootstrapping client creation.""" + return AgentCard( + url=url, + preferredTransport=transports[0] if transports else None, + additionalInterfaces=transports[1:] if len(transports) > 1 else [], + supportsAuthenticatedExtendedCard=True, + capabilities=AgentCapabilities(), + defaultInputModes=[], + defaultOutputModes=[], + description='', + skills=[], + version='', + name='', + ) diff --git a/src/a2a/client/client_task_manager.py b/src/a2a/client/client_task_manager.py new file mode 100644 index 00000000..733e51bd --- /dev/null +++ b/src/a2a/client/client_task_manager.py @@ -0,0 +1,181 @@ +import logging + +from a2a.server.events.event_queue import Event +from a2a.types import ( + InvalidParamsError, + Message, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from a2a.utils import append_artifact_to_task +from a2a.utils.errors import ServerError + + +logger = logging.getLogger(__name__) + + +class ClientTaskManager: + """Helps manage a task's lifecycle during execution of a request. + + Responsible for retrieving, saving, and updating the `Task` object based on + events received from the agent. + """ + + def __init__( + self, + ): + """Initializes the TaskManager. + + Args: + task_id: The ID of the task, if known from the request. + context_id: The ID of the context, if known from the request. + task_store: The `TaskStore` instance for persistence. + initial_message: The `Message` that initiated the task, if any. + Used when creating a new task object. + """ + self._current_task: Task | None = None + self._task_id: str | None = None + self._context_id: str | None = None + + def get_task(self) -> Task | None: + """Retrieves the current task object, either from memory or the store. + + If `task_id` is set, it first checks the in-memory `_current_task`, + then attempts to load it from the `task_store`. + + Returns: + The `Task` object if found, otherwise `None`. + """ + if not self._task_id: + logger.debug('task_id is not set, cannot get task.') + return None + + return self._current_task + + async def save_task_event( + self, event: Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ) -> Task | None: + """Processes a task-related event (Task, Status, Artifact) and saves the updated task state. + + Ensures task and context IDs match or are set from the event. + + Args: + event: The task-related event (`Task`, `TaskStatusUpdateEvent`, or `TaskArtifactUpdateEvent`). + + Returns: + The updated `Task` object after processing the event. + + Raises: + ServerError: If the task ID in the event conflicts with the TaskManager's ID + when the TaskManager's ID is already set. + """ + if isinstance(event, Task): + if self._current_task: + raise ClientError( + error=InvalidParamsError( + message="Task is already set, create new manager for new tasks." + ) + ) + await self._save_task(event) + return event + task_id_from_event = ( + event.id if isinstance(event, Task) else event.taskId + ) + if not self._task_id: + self._task_id = task_id_from_event + if not self._context_id: + self._context_id = event.contextId + + logger.debug( + 'Processing save of task event of type %s for task_id: %s', + type(event).__name__, + task_id_from_event, + ) + + task = self._current_task + if not task: + task = Task( + status=TaskStatus(state=TaskState.unknown), + id=task_id_from_event, + contextId=self._context_id if self._context_id else '', + ) + if isinstance(event, TaskStatusUpdateEvent): + logger.debug( + 'Updating task %s status to: %s', event.taskId, event.status.state + ) + if event.status.message: + if not task.history: + task.history = [event.status.message] + else: + task.history.append(event.status.message) + if event.metadata: + if not task.metadata: + task.metadata = {} + task.metadata.update(event.metadata) + task.status = event.status + else: + logger.debug('Appending artifact to task %s', task.id) + append_artifact_to_task(task, event) + self._current_task = task + return task + + async def process(self, event: Event) -> Event: + """Processes an event, updates the task state if applicable, stores it, and returns the event. + + If the event is task-related (`Task`, `TaskStatusUpdateEvent`, `TaskArtifactUpdateEvent`), + the internal task state is updated and persisted. + + Args: + event: The event object received from the agent. + + Returns: + The same event object that was processed. + """ + if isinstance( + event, Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ): + await self.save_task_event(event) + + return event + + async def _save_task(self, task: Task) -> None: + """Saves the given task to the task store and updates the in-memory `_current_task`. + + Args: + task: The `Task` object to save. + """ + logger.debug('Saving task with id: %s', task.id) + self._current_task = task + if not self._task_id: + logger.info('New task created with id: %s', task.id) + self._task_id = task.id + self._context_id = task.contextId + + def update_with_message(self, message: Message, task: Task) -> Task: + """Updates a task object in memory by adding a new message to its history. + + If the task has a message in its current status, that message is moved + to the history first. + + Args: + message: The new `Message` to add to the history. + task: The `Task` object to update. + + Returns: + The updated `Task` object (updated in-place). + """ + if task.status.message: + if task.history: + task.history.append(task.status.message) + else: + task.history = [task.status.message] + task.status.message = None + if task.history: + task.history.append(message) + else: + task.history = [message] + self._current_task = task + return task diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/grpc_client.py index 5fc7cc99..4dabdc71 100644 --- a/src/a2a/client/grpc_client.py +++ b/src/a2a/client/grpc_client.py @@ -1,12 +1,22 @@ import logging -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, AsyncIterator import grpc +from a2a.client.client import ( + Client, + ClientCallContext, + ClientConfig, + Consumer, + ClientEvent, +) +from a2a.client.middleware import ClientCallInterceptor +from a2a.client.client_task_manager import ClientTaskManager from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCard, + GetTaskPushNotificationConfigParams, Message, MessageSendParams, Task, @@ -23,18 +33,18 @@ logger = logging.getLogger(__name__) -@trace_class(kind=SpanKind.CLIENT) -class A2AGrpcClient: - """A2A Client for interacting with an A2A agent via gRPC.""" +#@trace_class(kind=SpanKind.CLIENT) +class GrpcTransportClient: + """Transport specific details for interacting with an A2A agent via gRPC.""" def __init__( self, grpc_stub: a2a_pb2_grpc.A2AServiceStub, - agent_card: AgentCard, + agent_card: AgentCard | None, ): - """Initializes the A2AGrpcClient. + """Initializes the GrpcTransportClient. - Requires an `AgentCard` + Requires an `AgentCard` and a grpc `A2AServiceStub`. Args: grpc_stub: A grpc client stub. @@ -42,10 +52,18 @@ def __init__( """ self.agent_card = agent_card self.stub = grpc_stub + # If they don't provide an agent card, but do have a stub, lookup the + # card from the stub. + self._needs_extended_card = ( + agent_card.supportsAuthenticatedExtendedCard + if agent_card else True + ) async def send_message( self, request: MessageSendParams, + *, + context: ClientCallContext | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent. @@ -71,6 +89,8 @@ async def send_message( async def send_message_streaming( self, request: MessageSendParams, + *, + context: ClientCallContext | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -116,6 +136,8 @@ async def send_message_streaming( async def get_task( self, request: TaskQueryParams, + *, + context: ClientCallContext | None = None, ) -> Task: """Retrieves the current state and history of a specific task. @@ -133,6 +155,8 @@ async def get_task( async def cancel_task( self, request: TaskIdParams, + *, + context: ClientCallContext | None = None, ) -> Task: """Requests the agent to cancel a specific task. @@ -150,6 +174,8 @@ async def cancel_task( async def set_task_callback( self, request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task. @@ -173,6 +199,8 @@ async def set_task_callback( async def get_task_callback( self, request: TaskIdParams, # TODO: Update to a push id params + *, + context: ClientCallContext | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task. @@ -188,3 +216,187 @@ async def get_task_callback( ) ) return proto_utils.FromProto.task_push_notification_config(config) + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the authenticated card (if necessary) or the public one. + + Args: + context: The client call context. + + Returns: + A `AgentCard` object containing the card. + + Raises: + grpc.RpcError: If a gRPC error occurs during the request. + """ + # If we don't have the public card, try to get that first. + card = self.agent_card + + if not self._needs_extended_card: + return card + + card_pb = await self.stub.GetAgentCard( + a2a_pb2.GetAgentCardRequest(), + ) + card = proto_utils.FromProto.agent_card(card_pb) + self.agent_card = card + self._needs_extended_card = False + return card + + +#@trace_class(kind=SpanKind.CLIENT) +class GrpcClient(Client): + """GrpcClient provides the Client interface for the gRPC transport.""" + + def __init__( + self, + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], + ): + super().__init__(consumers, middleware) + if not config.grpc_channel_factory: + raise Exception('GRPC client requires channel factory.') + self._card = card + self._config = config + # Defer init to first use. + self._transport_client = None + channel = self._config.grpc_channel_factory(self._card.url) + stub = a2a_pb2_grpc.A2AServiceStub(channel) + self._transport_client = GrpcTransportClient(stub, self._card) + + async def send_message( + self, + request: Message, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[ClientEvent | Message]: + # TODO: Set the request params from config + if not self._config.streaming or not self._card.capabilities.streaming: + print("Using blocking interaction") + response = await self._transport_client.send_message( + MessageSendParams( + message=request, + # TODO: set params + ), + context=context, + ) + result = ( + (response, None) if isinstance(response, Task) else response + ) + # Spin off consumers - in thread, out of thread, etc? + await self.consume(result, self._card) + yield result + return + # Get Task tracker + print("Using streaming interactions") + tracker = ClientTaskManager() + async for event in self._transport_client.send_message_streaming( + MessageSendParams( + message=request, + # TODO: set params + ), + context=context, + ): + # Update task, check for errors, etc. + if isinstance(event, Message): + await self.consume(event, self._card) + yield event + return + await tracker.process(event) + result = ( + tracker.get_task(), + None if isinstance(event, Task) else event + ) + await self.consume(result, self._card) + yield result + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.get_task( + request, + context=context, + ) + return response + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.cancel_task( + request, + context=context, + ) + return response + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.set_task_callback( + request, + context=context, + ) + return response + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.get_task_callback( + request, + context=context, + ) + return response + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + if not self._config.streaming or not self._card.capabilities.streaming: + raise Exception( + 'client and/or server do not support resubscription.' + ) + async for event in self._transport_client.resubscribe( + request, + context=context, + ): + # Update task, check for errors, etc. + yield event + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + card = await self._transport_client.get_card( + context=context, + ) + self._card = card + return card + + +def NewGrpcClient( + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor] +) -> Client: + return GrpcClient(card, config, consumers, middleware) diff --git a/src/a2a/client/jsonrpc_client.py b/src/a2a/client/jsonrpc_client.py new file mode 100644 index 00000000..64035650 --- /dev/null +++ b/src/a2a/client/jsonrpc_client.py @@ -0,0 +1,719 @@ +import json +import logging + +from collections.abc import AsyncGenerator, AsyncIterator +from typing import Any +from uuid import uuid4 + +import httpx + +from httpx_sse import SSEError, aconnect_sse +from pydantic import ValidationError + +from a2a.client.client import Client, ClientConfig, A2ACardResolver, Consumer +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.client_task_manager import ClientTaskManager +from a2a.types import ( + AgentCard, + CancelTaskRequest, + CancelTaskResponse, + GetTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigResponse, + GetTaskRequest, + GetTaskResponse, + JSONRPCErrorResponse, + Message, + MessageSendParams, + SendMessageRequest, + SendMessageResponse, + SendStreamingMessageRequest, + SendStreamingMessageResponse, + SetTaskPushNotificationConfigRequest, + SetTaskPushNotificationConfigResponse, + Task, + TaskIdParams, + TaskQueryParams, + TaskPushNotificationConfig, + TaskResubscriptionRequest, +) +from a2a.utils.constants import ( + AGENT_CARD_WELL_KNOWN_PATH, +) +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class JsonRpcTransportClient: + """A2A Client for interacting with an A2A agent.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + """Initializes the A2AClient. + + Requires either an `AgentCard` or a direct `url` to the agent's RPC endpoint. + + Args: + httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). + agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. + url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. + interceptors: An optional list of client call interceptors to apply to requests. + + Raises: + ValueError: If neither `agent_card` nor `url` is provided. + """ + if agent_card: + self.url = agent_card.url + elif url: + self.url = url + else: + raise ValueError('Must provide either agent_card or url') + + self.httpx_client = httpx_client + self.agent_card = agent_card + self.interceptors = interceptors or [] + # Indicate if we have captured an extended card details so we can update + # on first call if needed. It is done this way so the caller can setup + # their auth credentials based on the public card and get the updated + # card. + self._needs_extended_card = ( + not agent_card.supportsAuthenticatedExtendedCard + if agent_card else True) + + async def _apply_interceptors( + self, + method_name: str, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Applies all registered interceptors to the request.""" + final_http_kwargs = http_kwargs or {} + final_request_payload = request_payload + + for interceptor in self.interceptors: + ( + final_request_payload, + final_http_kwargs, + ) = await interceptor.intercept( + method_name, + final_request_payload, + final_http_kwargs, + self.agent_card, + context, + ) + return final_request_payload, final_http_kwargs + + @staticmethod + async def get_client_from_agent_card_url( + httpx_client: httpx.AsyncClient, + base_url: str, + agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, + http_kwargs: dict[str, Any] | None = None, + ) -> 'A2AClient': + """[deprecated] Fetches the public AgentCard and initializes an A2A client. + + This method will always fetch the public agent card. If an authenticated + or extended agent card is required, the A2ACardResolver should be used + directly to fetch the specific card, and then the A2AClient should be + instantiated with it. + + Args: + httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). + base_url: The base URL of the agent's host. + agent_card_path: The path to the agent card endpoint, relative to the base URL. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.get request when fetching the agent card. + + Returns: + An initialized `A2AClient` instance. + + Raises: + A2AClientHTTPError: If an HTTP error occurs fetching the agent card. + A2AClientJSONError: If the agent card response is invalid. + """ + agent_card: AgentCard = await A2ACardResolver( + httpx_client, base_url=base_url, agent_card_path=agent_card_path + ).get_agent_card( + http_kwargs=http_kwargs + ) # Fetches public card by default + return A2AClient(httpx_client=httpx_client, agent_card=agent_card) + + async def send_message( + self, + request: SendMessageRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> SendMessageResponse: + """Sends a non-streaming message request to the agent. + + Args: + request: The `SendMessageRequest` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `SendMessageResponse` object containing the agent's response (Task or Message) or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'message/send', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + return SendMessageResponse.model_validate(response_data) + + async def send_message_streaming( + self, + request: SendStreamingMessageRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[SendStreamingMessageResponse]: + """Sends a streaming message request to the agent and yields responses as they arrive. + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `SendStreamingMessageRequest` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + `SendStreamingMessageResponse` objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'message/stream', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + self.url, + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + yield SendStreamingMessageResponse.model_validate( + json.loads(sse.data) + ) + except SSEError as e: + raise A2AClientHTTPError( + 400, + f'Invalid SSE response or protocol error: {e}', + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_request( + self, + rpc_request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Sends a non-streaming JSON-RPC request to the agent. + + Args: + rpc_request_payload: JSON RPC payload for sending the request. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + + Returns: + The JSON response payload as a dictionary. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON. + """ + try: + response = await self.httpx_client.post( + self.url, json=rpc_request_payload, **(http_kwargs or {}) + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_task( + self, + request: GetTaskRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> GetTaskResponse: + """Retrieves the current state and history of a specific task. + + Args: + request: The `GetTaskRequest` object specifying the task ID and history length. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `GetTaskResponse` object containing the Task or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/get', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + return GetTaskResponse.model_validate(response_data) + + async def cancel_task( + self, + request: CancelTaskRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> CancelTaskResponse: + """Requests the agent to cancel a specific task. + + Args: + request: The `CancelTaskRequest` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `CancelTaskResponse` object containing the updated Task with canceled status or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/cancel', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + return CancelTaskResponse.model_validate(response_data) + + async def set_task_callback( + self, + request: SetTaskPushNotificationConfigRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> SetTaskPushNotificationConfigResponse: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/pushNotificationConfig/set', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context + ) + response_data = await self._send_request(payload, modified_kwargs) + return SetTaskPushNotificationConfigResponse.model_validate( + response_data + ) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> GetTaskPushNotificationConfigResponse: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/pushNotificationConfig/get', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + return GetTaskPushNotificationConfigResponse.model_validate( + response_data + ) + + async def resubscribe( + self, + request: TaskResubscriptionRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[SendStreamingMessageResponse]: + """Reconnects to get task updates + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `TaskResubscriptionRequest` object containing the task information to reconnect to. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + `SendStreamingMessageResponse` objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/resubscribe', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + self.url, + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + yield SendStreamingMessageResponse.model_validate( + json.loads(sse.data) + ) + except SSEError as e: + raise A2AClientHTTPError( + 400, + f'Invalid SSE response or protocol error: {e}', + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_card( + self, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the authenticated card (if necessary) or the public one. + + Args: + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `AgentCard` object containing the card or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + # If we don't have the public card, try to get that first. + card = self.card + if not card: + resolver = A2ACardResolver(self.httpx_client, self.url) + card = resolver.get_agent_card(http_kwargs=http_kwargs) + self._needs_extended_card = card.supportsAuthenticatedExtendedCard + self.card = card + + if not self._needs_extended_card: + return card + + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'card/getAuthenticated', + '', + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + card = AgentCard.model_validate(response_data) + self.card = card + self._needs_extended_card = False + return card + + +@trace_class(kind=SpanKind.CLIENT) +class JsonRpcClient(Client): + """JsonRpcClient is the implementation of the JSONRPC A2A client. + + This client proxies requests to the JsonRpcTransportClient implementation + and manages the JSONRPC specific details. If passing additional arguements + in the http.post command, these should be attached to the ClientCallContext + under the dictionary key 'http_kwargs'. + """ + + def __init__( + self, + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], + ): + super().__init__(consumers, middleware) + if not config.httpx_client: + raise Exception('JsonRpc client requires httpx client.') + self._card = card + url = card.url + self._config = config + self._transport_client = JsonRpcTransportClient( + config.httpx_client, self._card, url, middleware + ) + + def get_http_args( + self, context: ClientCallContext + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs', None) if context else None + + async def send_message( + self, + request: Message, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + # TODO: Set the request params from config + if not self._config.streaming or not self._card.capabilities.streaming: + response = await self._transport_client.send_message( + SendMessageRequest( + params=MessageSendParams( + message=request, + ), + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise response.root.error + result = response.root.result + result = result if isinstance(result, Message) else (result, None) + await self.consume(result, self._card) + yield result + return + tracker = ClientTaskManager() + async for event in self._transport_client.send_message_streaming( + SendStreamingMessageRequest( + params=MessageSendParams( + message=request, + ), + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ): + if isinstance(event.root, JSONRPCErrorResponse): + raise event.root.error + result = event.root.result + # Update task, check for errors, etc. + if isinstance(result, Message): + yield result + return + await tracker.process(result) + result = ( + tracker.get_task(), + None if isinstance(result, Task) + else result + ) + await self.consume(result, self._card) + yield result + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.get_task( + GetTaskRequest( + params=request, + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + return response.result + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.cancel_task( + CancelTaskRequest( + params=request, + id=srt(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + return response.result + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.set_task_callback( + SetTaskPushNotificationConfigRequest( + params=request, + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + return response.result + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.get_task_callback( + GetTaskPushNotificationConfigRequest( + params=request, + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + return response.result + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + if not self._config.streaming or not self._card.capabilities.streaming: + raise Exception( + 'client and/or server do not support resubscription.' + ) + async for event in self._transport_client.resubscribe( + TaskResubscriptionRequest( + params=TaskIdParams, + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ): + # Update task, check for errors, etc. + yield event + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + return await self._transport_client.get_card( + http_kwargs=self.get_http_args(context), + context=context, + ) + +def NewJsonRpcClient( + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor] +) -> Client: + return JsonRpcClient(card, config, consumers, middleware) diff --git a/src/a2a/client/rest_client.py b/src/a2a/client/rest_client.py new file mode 100644 index 00000000..51c29383 --- /dev/null +++ b/src/a2a/client/rest_client.py @@ -0,0 +1,762 @@ +import json +import logging + +from collections.abc import AsyncGenerator, AsyncIterator +from typing import Any +from uuid import uuid4 + +import httpx + +from httpx_sse import SSEError, aconnect_sse +from pydantic import ValidationError + +from a2a.client.client import Client, ClientConfig, A2ACardResolver, Consumer +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.client_task_manager import ClientTaskManager +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskQueryParams, + TaskPushNotificationConfig, + TaskStatusUpdateEvent, +) +from a2a.utils.constants import ( + AGENT_CARD_WELL_KNOWN_PATH, +) +from a2a.utils.telemetry import SpanKind, trace_class +from a2a.grpc import a2a_pb2 +from a2a.utils import proto_utils +from google.protobuf.json_format import Parse, MessageToDict + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class RestTransportClient: + """A2A Client for interacting with an A2A agent.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + """Initializes the A2AClient. + + Requires either an `AgentCard` or a direct `url` to the agent's RPC endpoint. + + Args: + httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). + agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. + url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. + interceptors: An optional list of client call interceptors to apply to requests. + + Raises: + ValueError: If neither `agent_card` nor `url` is provided. + """ + if agent_card: + self.url = agent_card.url + elif url: + self.url = url + else: + raise ValueError('Must provide either agent_card or url') + # If the url ends in / remove it as this is added by the routes + if self.url.endswith("/"): + self.url = self.url[:-1] + self.httpx_client = httpx_client + self.agent_card = agent_card + self.interceptors = interceptors or [] + # Indicate if we have captured an extended card details so we can update + # on first call if needed. It is done this way so the caller can setup + # their auth credentials based on the public card and get the updated + # card. + self._needs_extended_card = ( + not agent_card.supportsAuthenticatedExtendedCard + if agent_card else True) + + async def _apply_interceptors( + self, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Applies all registered interceptors to the request.""" + final_http_kwargs = http_kwargs or {} + final_request_payload = request_payload + # TODO: Implement interceptors for other transports + return final_request_payload, final_http_kwargs + + @staticmethod + async def get_client_from_agent_card_url( + httpx_client: httpx.AsyncClient, + base_url: str, + agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, + http_kwargs: dict[str, Any] | None = None, + ) -> 'A2AClient': + """[deprecated] Fetches the public AgentCard and initializes an A2A client. + + This method will always fetch the public agent card. If an authenticated + or extended agent card is required, the A2ACardResolver should be used + directly to fetch the specific card, and then the A2AClient should be + instantiated with it. + + Args: + httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). + base_url: The base URL of the agent's host. + agent_card_path: The path to the agent card endpoint, relative to the base URL. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.get request when fetching the agent card. + + Returns: + An initialized `A2AClient` instance. + + Raises: + A2AClientHTTPError: If an HTTP error occurs fetching the agent card. + A2AClientJSONError: If the agent card response is invalid. + """ + agent_card: AgentCard = await A2ACardResolver( + httpx_client, base_url=base_url, agent_card_path=agent_card_path + ).get_agent_card( + http_kwargs=http_kwargs + ) # Fetches public card by default + return A2AClient(httpx_client=httpx_client, agent_card=agent_card) + + async def send_message( + self, + request: MessageSendParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent. + + Args: + request: The `MessageSendParams` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `Task` or `Message` object containing the agent's response. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.send_message_config( + request.config + ), + metadata=( + proto_utils.ToProto.metadata(request.metadata) + if request.metadata else None + ), + ) + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, + http_kwargs, + context, + ) + response_data = await self._send_post_request( + '/v1/message:send', + payload, + modified_kwargs + ) + response_pb = a2a_pb2.SendMessageResponse() + Parse(response_data, response_pb) + return proto_utils.FromProto.task_or_message(response_pb) + + async def send_message_streaming( + self, + request: MessageSendParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message]: + """Sends a streaming message request to the agent and yields responses as they arrive. + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `MessageSendParams` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + Objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=( + proto_utils.ToProto.metadata(request.metadata) + if request.metadata else None + ), + ) + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, + http_kwargs, + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + f'{self.url}/v1/message:stream', + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + event = a2a_pb2.StreamResponse() + Parse(sse.data, event) + yield proto_utils.FromProto.stream_response(event) + except SSEError as e: + raise A2AClientHTTPError( + 400, + f'Invalid SSE response or protocol error: {e}', + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_post_request( + self, + target: str, + rpc_request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Sends a non-streaming JSON-RPC request to the agent. + + Args: + target: url path + rpc_request_payload: JSON payload for sending the request. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + + Returns: + The JSON response payload as a dictionary. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON. + """ + try: + response = await self.httpx_client.post( + f'{self.url}{target}', + json=rpc_request_payload, + **(http_kwargs or {}) + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_get_request( + self, + target: str, + query_params: dict[str, str], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Sends a non-streaming JSON-RPC request to the agent. + + Args: + target: url path + query_params: HTTP query params for the request. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + + Returns: + The JSON response payload as a dictionary. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON. + """ + try: + response = await self.httpx_client.get( + f'{self.url}{target}', + params=query_params, + **(http_kwargs or {}) + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_task( + self, + request: TaskQueryParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task. + + Args: + request: The `TaskQueryParams` object specifying the task ID and history length. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `Task` object containing the Task. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + # Apply interceptors before sending - only for the http kwargs + payload, modified_kwargs = await self._apply_interceptors( + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_get_request( + f'/v1/tasks/{request.taskId}', + { + 'historyLength': request.historyLength + } if request.historyLength else {}, + modified_kwargs + ) + task = a2a_pb2.Task() + Parse(response_data, task) + return proto_utils.FromProto.task(task) + + async def cancel_task( + self, + request: TaskIdParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task. + + Args: + request: The `TaskIdParams` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `Task` object containing the updated Task with canceled status + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.CancelTaskRequest( + name=f'tasks/{request.id}' + ) + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, + http_kwargs, + context, + ) + response_data = await self._send_post_request( + f'/v1/tasks/{request.taskId}:cancel', + payload, + modified_kwargs + ) + task = a2a_pb2.Task() + Parse(response_data, task) + return proto_utils.FromProto.task(task) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `TaskPushNotificationConfig` object specifying the task ID and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `TaskPushNotificationConfig` object containing the confirmation. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.CreateTaskPushNotificationConfigRequest( + parent=f'tasks/{request.taskId}', + config_id=request.pushNotificationConfig.id, + config=proto_utils.ToProto.push_notification_config( + request.pushNotificationConfig + ), + ) + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, + http_kwargs, + context + ) + response_data = await self._send_post_request( + f'/v1/tasks/{request.taskId}/pushNotificationConfigs/', + payload, + modified_kwargs + ) + config = a2a_pb2.TaskPushNotificationConfig() + Parse(response_data, config) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `GetTaskPushNotificationConfigParams` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `TaskPushNotificationConfig` object containing the configuration. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.GetTaskPushNotificationConfigRequest( + name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', + ) + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, + http_kwargs, + context, + ) + response_data = await self._send_get_request( + f'/v1/tasks/{request.taskId}/pushNotificationConfigs/{request.pushNotificationId}', + {}, + modified_kwargs + ) + config = a2a_pb2.TaskPushNotificationConfig() + Parse(response_data, config) + return proto_utils.FromProto.task_push_notification_config(config) + + async def resubscribe( + self, + request: TaskIdParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message]: + """Reconnects to get task updates + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `TaskIdParams` object containing the task information to reconnect to. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + Objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.TaskSubscriptionRequest( + name=f'tasks/{request.id}', + ) + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, + http_kwargs, + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + f'{self.url}/v1/tasks/{request.taskId}:subscribe', + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + event = a2a_pb2.StreamResponse() + Parse(sse.data, event) + yield proto_utils.FromProto.stream_response(event) + except SSEError as e: + raise A2AClientHTTPError( + 400, + f'Invalid SSE response or protocol error: {e}', + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_card( + self, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the authenticated card (if necessary) or the public one. + + Args: + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `AgentCard` object containing the card or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + # If we don't have the public card, try to get that first. + card = self.card + if not card: + resolver = A2ACardResolver(self.httpx_client, self.url) + card = resolver.get_agent_card(http_kwargs=http_kwargs) + self._needs_extended_card = card.supportsAuthenticatedExtendedCard + self.card = card + + if not self._needs_extended_card: + return card + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + '', + http_kwargs, + context, + ) + response_data = await self._send_get_request( + '/v1/card/get', + {}, + modified_kwargs) + card = AgentCard.model_validate(response_data) + self.card = card + self._needs_extended_card = False + return card + + +@trace_class(kind=SpanKind.CLIENT) +class RestClient(Client): + """RestClient is the implementation of the RESTful A2A client. + + This client proxies requests to the RestTransportClient implementation + and manages the REST specific details. If passing additional arguments + in the http.post command, these should be attached to the ClientCallContext + under the dictionary key 'http_kwargs'. + """ + + def __init__( + self, + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], + ): + super().__init__(consumers, middleware) + if not config.httpx_client: + raise Exception('JsonRpc client requires httpx client.') + self._card = card + url = card.url + self._config = config + self._transport_client = RestTransportClient( + config.httpx_client, self._card, url, middleware + ) + + def get_http_args( + self, context: ClientCallContext + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs', None) if context else None + + async def send_message( + self, + request: Message, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + # TODO: Set the request params from config + if not self._config.streaming or not self._card.capabilities.streaming: + response = await self._transport_client.send_message( + MessageSendParams( + message=request, + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + result = ( + response + if isinstance(response, Message) + else (response, None) + ) + await self.consume(result, self._card) + yield result + return + tracker = ClientTaskManager() + async for event in self._transport_client.send_message_streaming( + MessageSendParams( + message=request, + ), + http_kwargs=self.get_http_args(context), + context=context, + ): + # Update task, check for errors, etc. + if isinstance(event, Message): + yield result + return + await tracker.process(event) + result = ( + tracker.get_task(), + None if isinstance(event, Task) + else event + ) + await self.consume(result, self._card) + yield result + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.get_task( + request, + http_kwargs=self.get_http_args(context), + context=context, + ) + return response + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.cancel_task( + request, + http_kwargs=self.get_http_args(context), + context=context, + ) + return response + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.set_task_callback( + request, + http_kwargs=self.get_http_args(context), + context=context, + ) + return response + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.get_task_callback( + request, + http_kwargs=self.get_http_args(context), + context=context, + ) + return response.result + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + if not self._config.streaming or not self._card.capabilities.streaming: + raise Exception( + 'client and/or server do not support resubscription.' + ) + async for event in self._transport_client.resubscribe( + TaskIdParams, + http_kwargs=self.get_http_args(context), + context=context, + ): + # Update task, check for errors, etc. + yield event + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + return await self._transport_client.get_card( + http_kwargs=self.get_http_args(context), + context=context, + ) + +def NewRestfulClient( + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor] +) -> Client: + return RestClient(card, config, consumers, middleware) diff --git a/src/a2a/grpc/a2a_pb2.py b/src/a2a/grpc/a2a_pb2.py index e11d6ebf..e8304632 100644 --- a/src/a2a/grpc/a2a_pb2.py +++ b/src/a2a/grpc/a2a_pb2.py @@ -30,14 +30,14 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ta2a.proto\x12\x06\x61\x32\x61.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xde\x01\n\x18SendMessageConfiguration\x12\x32\n\x15\x61\x63\x63\x65pted_output_modes\x18\x01 \x03(\tR\x13\x61\x63\x63\x65ptedOutputModes\x12K\n\x11push_notification\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x10pushNotification\x12%\n\x0ehistory_length\x18\x03 \x01(\x05R\rhistoryLength\x12\x1a\n\x08\x62locking\x18\x04 \x01(\x08R\x08\x62locking\"\xf1\x01\n\x04Task\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12.\n\tartifacts\x18\x04 \x03(\x0b\x32\x10.a2a.v1.ArtifactR\tartifacts\x12)\n\x07history\x18\x05 \x03(\x0b\x32\x0f.a2a.v1.MessageR\x07history\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x99\x01\n\nTaskStatus\x12\'\n\x05state\x18\x01 \x01(\x0e\x32\x11.a2a.v1.TaskStateR\x05state\x12(\n\x06update\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageR\x07message\x12\x38\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\"t\n\x04Part\x12\x14\n\x04text\x18\x01 \x01(\tH\x00R\x04text\x12&\n\x04\x66ile\x18\x02 \x01(\x0b\x32\x10.a2a.v1.FilePartH\x00R\x04\x66ile\x12&\n\x04\x64\x61ta\x18\x03 \x01(\x0b\x32\x10.a2a.v1.DataPartH\x00R\x04\x64\x61taB\x06\n\x04part\"\x7f\n\x08\x46ilePart\x12$\n\rfile_with_uri\x18\x01 \x01(\tH\x00R\x0b\x66ileWithUri\x12(\n\x0f\x66ile_with_bytes\x18\x02 \x01(\x0cH\x00R\rfileWithBytes\x12\x1b\n\tmime_type\x18\x03 \x01(\tR\x08mimeTypeB\x06\n\x04\x66ile\"7\n\x08\x44\x61taPart\x12+\n\x04\x64\x61ta\x18\x01 \x01(\x0b\x32\x17.google.protobuf.StructR\x04\x64\x61ta\"\xff\x01\n\x07Message\x12\x1d\n\nmessage_id\x18\x01 \x01(\tR\tmessageId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12\x17\n\x07task_id\x18\x03 \x01(\tR\x06taskId\x12 \n\x04role\x18\x04 \x01(\x0e\x32\x0c.a2a.v1.RoleR\x04role\x12&\n\x07\x63ontent\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x07\x63ontent\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xda\x01\n\x08\x41rtifact\x12\x1f\n\x0b\x61rtifact_id\x18\x01 \x01(\tR\nartifactId\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x04 \x01(\tR\x0b\x64\x65scription\x12\"\n\x05parts\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x05parts\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xc6\x01\n\x15TaskStatusUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12\x14\n\x05\x66inal\x18\x04 \x01(\x08R\x05\x66inal\x12\x33\n\x08metadata\x18\x05 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\xeb\x01\n\x17TaskArtifactUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12,\n\x08\x61rtifact\x18\x03 \x01(\x0b\x32\x10.a2a.v1.ArtifactR\x08\x61rtifact\x12\x16\n\x06\x61ppend\x18\x04 \x01(\x08R\x06\x61ppend\x12\x1d\n\nlast_chunk\x18\x05 \x01(\x08R\tlastChunk\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x94\x01\n\x16PushNotificationConfig\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x10\n\x03url\x18\x02 \x01(\tR\x03url\x12\x14\n\x05token\x18\x03 \x01(\tR\x05token\x12\x42\n\x0e\x61uthentication\x18\x04 \x01(\x0b\x32\x1a.a2a.v1.AuthenticationInfoR\x0e\x61uthentication\"P\n\x12\x41uthenticationInfo\x12\x18\n\x07schemes\x18\x01 \x03(\tR\x07schemes\x12 \n\x0b\x63redentials\x18\x02 \x01(\tR\x0b\x63redentials\"@\n\x0e\x41gentInterface\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\x1c\n\ttransport\x18\x02 \x01(\tR\ttransport\"\xf1\x06\n\tAgentCard\x12)\n\x10protocol_version\x18\x10 \x01(\tR\x0fprotocolVersion\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x10\n\x03url\x18\x03 \x01(\tR\x03url\x12/\n\x13preferred_transport\x18\x0e \x01(\tR\x12preferredTransport\x12K\n\x15\x61\x64\x64itional_interfaces\x18\x0f \x03(\x0b\x32\x16.a2a.v1.AgentInterfaceR\x14\x61\x64\x64itionalInterfaces\x12\x31\n\x08provider\x18\x04 \x01(\x0b\x32\x15.a2a.v1.AgentProviderR\x08provider\x12\x18\n\x07version\x18\x05 \x01(\tR\x07version\x12+\n\x11\x64ocumentation_url\x18\x06 \x01(\tR\x10\x64ocumentationUrl\x12=\n\x0c\x63\x61pabilities\x18\x07 \x01(\x0b\x32\x19.a2a.v1.AgentCapabilitiesR\x0c\x63\x61pabilities\x12Q\n\x10security_schemes\x18\x08 \x03(\x0b\x32&.a2a.v1.AgentCard.SecuritySchemesEntryR\x0fsecuritySchemes\x12,\n\x08security\x18\t \x03(\x0b\x32\x10.a2a.v1.SecurityR\x08security\x12.\n\x13\x64\x65\x66\x61ult_input_modes\x18\n \x03(\tR\x11\x64\x65\x66\x61ultInputModes\x12\x30\n\x14\x64\x65\x66\x61ult_output_modes\x18\x0b \x03(\tR\x12\x64\x65\x66\x61ultOutputModes\x12*\n\x06skills\x18\x0c \x03(\x0b\x32\x12.a2a.v1.AgentSkillR\x06skills\x12O\n$supports_authenticated_extended_card\x18\r \x01(\x08R!supportsAuthenticatedExtendedCard\x1aZ\n\x14SecuritySchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x16.a2a.v1.SecuritySchemeR\x05value:\x02\x38\x01\"E\n\rAgentProvider\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\"\n\x0corganization\x18\x02 \x01(\tR\x0corganization\"\x98\x01\n\x11\x41gentCapabilities\x12\x1c\n\tstreaming\x18\x01 \x01(\x08R\tstreaming\x12-\n\x12push_notifications\x18\x02 \x01(\x08R\x11pushNotifications\x12\x36\n\nextensions\x18\x03 \x03(\x0b\x32\x16.a2a.v1.AgentExtensionR\nextensions\"\x91\x01\n\x0e\x41gentExtension\x12\x10\n\x03uri\x18\x01 \x01(\tR\x03uri\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08required\x18\x03 \x01(\x08R\x08required\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.StructR\x06params\"\xc6\x01\n\nAgentSkill\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x03 \x01(\tR\x0b\x64\x65scription\x12\x12\n\x04tags\x18\x04 \x03(\tR\x04tags\x12\x1a\n\x08\x65xamples\x18\x05 \x03(\tR\x08\x65xamples\x12\x1f\n\x0binput_modes\x18\x06 \x03(\tR\ninputModes\x12!\n\x0coutput_modes\x18\x07 \x03(\tR\x0boutputModes\"\x8a\x01\n\x1aTaskPushNotificationConfig\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12X\n\x18push_notification_config\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x16pushNotificationConfig\" \n\nStringList\x12\x12\n\x04list\x18\x01 \x03(\tR\x04list\"\x93\x01\n\x08Security\x12\x37\n\x07schemes\x18\x01 \x03(\x0b\x32\x1d.a2a.v1.Security.SchemesEntryR\x07schemes\x1aN\n\x0cSchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x12.a2a.v1.StringListR\x05value:\x02\x38\x01\"\x91\x03\n\x0eSecurityScheme\x12U\n\x17\x61pi_key_security_scheme\x18\x01 \x01(\x0b\x32\x1c.a2a.v1.APIKeySecuritySchemeH\x00R\x14\x61piKeySecurityScheme\x12[\n\x19http_auth_security_scheme\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.HTTPAuthSecuritySchemeH\x00R\x16httpAuthSecurityScheme\x12T\n\x16oauth2_security_scheme\x18\x03 \x01(\x0b\x32\x1c.a2a.v1.OAuth2SecuritySchemeH\x00R\x14oauth2SecurityScheme\x12k\n\x1fopen_id_connect_security_scheme\x18\x04 \x01(\x0b\x32#.a2a.v1.OpenIdConnectSecuritySchemeH\x00R\x1bopenIdConnectSecuritySchemeB\x08\n\x06scheme\"h\n\x14\x41PIKeySecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08location\x18\x02 \x01(\tR\x08location\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\"w\n\x16HTTPAuthSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x16\n\x06scheme\x18\x02 \x01(\tR\x06scheme\x12#\n\rbearer_format\x18\x03 \x01(\tR\x0c\x62\x65\x61rerFormat\"b\n\x14OAuth2SecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12(\n\x05\x66lows\x18\x02 \x01(\x0b\x32\x12.a2a.v1.OAuthFlowsR\x05\x66lows\"n\n\x1bOpenIdConnectSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12-\n\x13open_id_connect_url\x18\x02 \x01(\tR\x10openIdConnectUrl\"\xb0\x02\n\nOAuthFlows\x12S\n\x12\x61uthorization_code\x18\x01 \x01(\x0b\x32\".a2a.v1.AuthorizationCodeOAuthFlowH\x00R\x11\x61uthorizationCode\x12S\n\x12\x63lient_credentials\x18\x02 \x01(\x0b\x32\".a2a.v1.ClientCredentialsOAuthFlowH\x00R\x11\x63lientCredentials\x12\x37\n\x08implicit\x18\x03 \x01(\x0b\x32\x19.a2a.v1.ImplicitOAuthFlowH\x00R\x08implicit\x12\x37\n\x08password\x18\x04 \x01(\x0b\x32\x19.a2a.v1.PasswordOAuthFlowH\x00R\x08passwordB\x06\n\x04\x66low\"\x8a\x02\n\x1a\x41uthorizationCodeOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1b\n\ttoken_url\x18\x02 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x03 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x04 \x03(\x0b\x32..a2a.v1.AuthorizationCodeOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdd\x01\n\x1a\x43lientCredentialsOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x03 \x03(\x0b\x32..a2a.v1.ClientCredentialsOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdb\x01\n\x11ImplicitOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.ImplicitOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xcb\x01\n\x11PasswordOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.PasswordOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xc1\x01\n\x12SendMessageRequest\x12.\n\x07request\x18\x01 \x01(\x0b\x32\x0f.a2a.v1.MessageB\x03\xe0\x41\x02R\x07request\x12\x46\n\rconfiguration\x18\x02 \x01(\x0b\x32 .a2a.v1.SendMessageConfigurationR\rconfiguration\x12\x33\n\x08metadata\x18\x03 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"P\n\x0eGetTaskRequest\x12\x17\n\x04name\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x04name\x12%\n\x0ehistory_length\x18\x02 \x01(\x05R\rhistoryLength\"\'\n\x11\x43\x61ncelTaskRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\":\n$GetTaskPushNotificationConfigRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"=\n\'DeleteTaskPushNotificationConfigRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"\xa9\x01\n\'CreateTaskPushNotificationConfigRequest\x12\x1b\n\x06parent\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x06parent\x12 \n\tconfig_id\x18\x02 \x01(\tB\x03\xe0\x41\x02R\x08\x63onfigId\x12?\n\x06\x63onfig\x18\x03 \x01(\x0b\x32\".a2a.v1.TaskPushNotificationConfigB\x03\xe0\x41\x02R\x06\x63onfig\"-\n\x17TaskSubscriptionRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"{\n%ListTaskPushNotificationConfigRequest\x12\x16\n\x06parent\x18\x01 \x01(\tR\x06parent\x12\x1b\n\tpage_size\x18\x02 \x01(\x05R\x08pageSize\x12\x1d\n\npage_token\x18\x03 \x01(\tR\tpageToken\"\x15\n\x13GetAgentCardRequest\"m\n\x13SendMessageResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12\'\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x07messageB\t\n\x07payload\"\xfa\x01\n\x0eStreamResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12\'\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x07message\x12\x44\n\rstatus_update\x18\x03 \x01(\x0b\x32\x1d.a2a.v1.TaskStatusUpdateEventH\x00R\x0cstatusUpdate\x12J\n\x0f\x61rtifact_update\x18\x04 \x01(\x0b\x32\x1f.a2a.v1.TaskArtifactUpdateEventH\x00R\x0e\x61rtifactUpdateB\t\n\x07payload\"\x8e\x01\n&ListTaskPushNotificationConfigResponse\x12<\n\x07\x63onfigs\x18\x01 \x03(\x0b\x32\".a2a.v1.TaskPushNotificationConfigR\x07\x63onfigs\x12&\n\x0fnext_page_token\x18\x02 \x01(\tR\rnextPageToken*\xfa\x01\n\tTaskState\x12\x1a\n\x16TASK_STATE_UNSPECIFIED\x10\x00\x12\x18\n\x14TASK_STATE_SUBMITTED\x10\x01\x12\x16\n\x12TASK_STATE_WORKING\x10\x02\x12\x18\n\x14TASK_STATE_COMPLETED\x10\x03\x12\x15\n\x11TASK_STATE_FAILED\x10\x04\x12\x18\n\x14TASK_STATE_CANCELLED\x10\x05\x12\x1d\n\x19TASK_STATE_INPUT_REQUIRED\x10\x06\x12\x17\n\x13TASK_STATE_REJECTED\x10\x07\x12\x1c\n\x18TASK_STATE_AUTH_REQUIRED\x10\x08*;\n\x04Role\x12\x14\n\x10ROLE_UNSPECIFIED\x10\x00\x12\r\n\tROLE_USER\x10\x01\x12\x0e\n\nROLE_AGENT\x10\x02\x32\xba\n\n\nA2AService\x12\x63\n\x0bSendMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x1b.a2a.v1.SendMessageResponse\"\x1b\x82\xd3\xe4\x93\x02\x15\"\x10/v1/message:send:\x01*\x12k\n\x14SendStreamingMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x16.a2a.v1.StreamResponse\"\x1d\x82\xd3\xe4\x93\x02\x17\"\x12/v1/message:stream:\x01*0\x01\x12R\n\x07GetTask\x12\x16.a2a.v1.GetTaskRequest\x1a\x0c.a2a.v1.Task\"!\xda\x41\x04name\x82\xd3\xe4\x93\x02\x14\x12\x12/v1/{name=tasks/*}\x12[\n\nCancelTask\x12\x19.a2a.v1.CancelTaskRequest\x1a\x0c.a2a.v1.Task\"$\x82\xd3\xe4\x93\x02\x1e\"\x19/v1/{name=tasks/*}:cancel:\x01*\x12s\n\x10TaskSubscription\x12\x1f.a2a.v1.TaskSubscriptionRequest\x1a\x16.a2a.v1.StreamResponse\"$\x82\xd3\xe4\x93\x02\x1e\x12\x1c/v1/{name=tasks/*}:subscribe0\x01\x12\xc4\x01\n CreateTaskPushNotificationConfig\x12/.a2a.v1.CreateTaskPushNotificationConfigRequest\x1a\".a2a.v1.TaskPushNotificationConfig\"K\xda\x41\rparent,config\x82\xd3\xe4\x93\x02\x35\"+/v1/{parent=task/*/pushNotificationConfigs}:\x06\x63onfig\x12\xae\x01\n\x1dGetTaskPushNotificationConfig\x12,.a2a.v1.GetTaskPushNotificationConfigRequest\x1a\".a2a.v1.TaskPushNotificationConfig\";\xda\x41\x04name\x82\xd3\xe4\x93\x02.\x12,/v1/{name=tasks/*/pushNotificationConfigs/*}\x12\xbe\x01\n\x1eListTaskPushNotificationConfig\x12-.a2a.v1.ListTaskPushNotificationConfigRequest\x1a..a2a.v1.ListTaskPushNotificationConfigResponse\"=\xda\x41\x06parent\x82\xd3\xe4\x93\x02.\x12,/v1/{parent=tasks/*}/pushNotificationConfigs\x12P\n\x0cGetAgentCard\x12\x1b.a2a.v1.GetAgentCardRequest\x1a\x11.a2a.v1.AgentCard\"\x10\x82\xd3\xe4\x93\x02\n\x12\x08/v1/card\x12\xa8\x01\n DeleteTaskPushNotificationConfig\x12/.a2a.v1.DeleteTaskPushNotificationConfigRequest\x1a\x16.google.protobuf.Empty\";\xda\x41\x04name\x82\xd3\xe4\x93\x02.*,/v1/{name=tasks/*/pushNotificationConfigs/*}Bi\n\ncom.a2a.v1B\x08\x41\x32\x61ProtoP\x01Z\x18google.golang.org/a2a/v1\xa2\x02\x03\x41XX\xaa\x02\x06\x41\x32\x61.V1\xca\x02\x06\x41\x32\x61\\V1\xe2\x02\x12\x41\x32\x61\\V1\\GPBMetadata\xea\x02\x07\x41\x32\x61::V1b\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ta2a.proto\x12\x06\x61\x32\x61.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xde\x01\n\x18SendMessageConfiguration\x12\x32\n\x15\x61\x63\x63\x65pted_output_modes\x18\x01 \x03(\tR\x13\x61\x63\x63\x65ptedOutputModes\x12K\n\x11push_notification\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x10pushNotification\x12%\n\x0ehistory_length\x18\x03 \x01(\x05R\rhistoryLength\x12\x1a\n\x08\x62locking\x18\x04 \x01(\x08R\x08\x62locking\"\xf1\x01\n\x04Task\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12.\n\tartifacts\x18\x04 \x03(\x0b\x32\x10.a2a.v1.ArtifactR\tartifacts\x12)\n\x07history\x18\x05 \x03(\x0b\x32\x0f.a2a.v1.MessageR\x07history\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x99\x01\n\nTaskStatus\x12\'\n\x05state\x18\x01 \x01(\x0e\x32\x11.a2a.v1.TaskStateR\x05state\x12(\n\x06update\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageR\x07message\x12\x38\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\"t\n\x04Part\x12\x14\n\x04text\x18\x01 \x01(\tH\x00R\x04text\x12&\n\x04\x66ile\x18\x02 \x01(\x0b\x32\x10.a2a.v1.FilePartH\x00R\x04\x66ile\x12&\n\x04\x64\x61ta\x18\x03 \x01(\x0b\x32\x10.a2a.v1.DataPartH\x00R\x04\x64\x61taB\x06\n\x04part\"\x7f\n\x08\x46ilePart\x12$\n\rfile_with_uri\x18\x01 \x01(\tH\x00R\x0b\x66ileWithUri\x12(\n\x0f\x66ile_with_bytes\x18\x02 \x01(\x0cH\x00R\rfileWithBytes\x12\x1b\n\tmime_type\x18\x03 \x01(\tR\x08mimeTypeB\x06\n\x04\x66ile\"7\n\x08\x44\x61taPart\x12+\n\x04\x64\x61ta\x18\x01 \x01(\x0b\x32\x17.google.protobuf.StructR\x04\x64\x61ta\"\xff\x01\n\x07Message\x12\x1d\n\nmessage_id\x18\x01 \x01(\tR\tmessageId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12\x17\n\x07task_id\x18\x03 \x01(\tR\x06taskId\x12 \n\x04role\x18\x04 \x01(\x0e\x32\x0c.a2a.v1.RoleR\x04role\x12&\n\x07\x63ontent\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x07\x63ontent\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xda\x01\n\x08\x41rtifact\x12\x1f\n\x0b\x61rtifact_id\x18\x01 \x01(\tR\nartifactId\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x04 \x01(\tR\x0b\x64\x65scription\x12\"\n\x05parts\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x05parts\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xc6\x01\n\x15TaskStatusUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12\x14\n\x05\x66inal\x18\x04 \x01(\x08R\x05\x66inal\x12\x33\n\x08metadata\x18\x05 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\xeb\x01\n\x17TaskArtifactUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12,\n\x08\x61rtifact\x18\x03 \x01(\x0b\x32\x10.a2a.v1.ArtifactR\x08\x61rtifact\x12\x16\n\x06\x61ppend\x18\x04 \x01(\x08R\x06\x61ppend\x12\x1d\n\nlast_chunk\x18\x05 \x01(\x08R\tlastChunk\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x94\x01\n\x16PushNotificationConfig\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x10\n\x03url\x18\x02 \x01(\tR\x03url\x12\x14\n\x05token\x18\x03 \x01(\tR\x05token\x12\x42\n\x0e\x61uthentication\x18\x04 \x01(\x0b\x32\x1a.a2a.v1.AuthenticationInfoR\x0e\x61uthentication\"P\n\x12\x41uthenticationInfo\x12\x18\n\x07schemes\x18\x01 \x03(\tR\x07schemes\x12 \n\x0b\x63redentials\x18\x02 \x01(\tR\x0b\x63redentials\"@\n\x0e\x41gentInterface\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\x1c\n\ttransport\x18\x02 \x01(\tR\ttransport\"\xf1\x06\n\tAgentCard\x12)\n\x10protocol_version\x18\x10 \x01(\tR\x0fprotocolVersion\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x10\n\x03url\x18\x03 \x01(\tR\x03url\x12/\n\x13preferred_transport\x18\x0e \x01(\tR\x12preferredTransport\x12K\n\x15\x61\x64\x64itional_interfaces\x18\x0f \x03(\x0b\x32\x16.a2a.v1.AgentInterfaceR\x14\x61\x64\x64itionalInterfaces\x12\x31\n\x08provider\x18\x04 \x01(\x0b\x32\x15.a2a.v1.AgentProviderR\x08provider\x12\x18\n\x07version\x18\x05 \x01(\tR\x07version\x12+\n\x11\x64ocumentation_url\x18\x06 \x01(\tR\x10\x64ocumentationUrl\x12=\n\x0c\x63\x61pabilities\x18\x07 \x01(\x0b\x32\x19.a2a.v1.AgentCapabilitiesR\x0c\x63\x61pabilities\x12Q\n\x10security_schemes\x18\x08 \x03(\x0b\x32&.a2a.v1.AgentCard.SecuritySchemesEntryR\x0fsecuritySchemes\x12,\n\x08security\x18\t \x03(\x0b\x32\x10.a2a.v1.SecurityR\x08security\x12.\n\x13\x64\x65\x66\x61ult_input_modes\x18\n \x03(\tR\x11\x64\x65\x66\x61ultInputModes\x12\x30\n\x14\x64\x65\x66\x61ult_output_modes\x18\x0b \x03(\tR\x12\x64\x65\x66\x61ultOutputModes\x12*\n\x06skills\x18\x0c \x03(\x0b\x32\x12.a2a.v1.AgentSkillR\x06skills\x12O\n$supports_authenticated_extended_card\x18\r \x01(\x08R!supportsAuthenticatedExtendedCard\x1aZ\n\x14SecuritySchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x16.a2a.v1.SecuritySchemeR\x05value:\x02\x38\x01\"E\n\rAgentProvider\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\"\n\x0corganization\x18\x02 \x01(\tR\x0corganization\"\x98\x01\n\x11\x41gentCapabilities\x12\x1c\n\tstreaming\x18\x01 \x01(\x08R\tstreaming\x12-\n\x12push_notifications\x18\x02 \x01(\x08R\x11pushNotifications\x12\x36\n\nextensions\x18\x03 \x03(\x0b\x32\x16.a2a.v1.AgentExtensionR\nextensions\"\x91\x01\n\x0e\x41gentExtension\x12\x10\n\x03uri\x18\x01 \x01(\tR\x03uri\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08required\x18\x03 \x01(\x08R\x08required\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.StructR\x06params\"\xc6\x01\n\nAgentSkill\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x03 \x01(\tR\x0b\x64\x65scription\x12\x12\n\x04tags\x18\x04 \x03(\tR\x04tags\x12\x1a\n\x08\x65xamples\x18\x05 \x03(\tR\x08\x65xamples\x12\x1f\n\x0binput_modes\x18\x06 \x03(\tR\ninputModes\x12!\n\x0coutput_modes\x18\x07 \x03(\tR\x0boutputModes\"\x8a\x01\n\x1aTaskPushNotificationConfig\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12X\n\x18push_notification_config\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x16pushNotificationConfig\" \n\nStringList\x12\x12\n\x04list\x18\x01 \x03(\tR\x04list\"\x93\x01\n\x08Security\x12\x37\n\x07schemes\x18\x01 \x03(\x0b\x32\x1d.a2a.v1.Security.SchemesEntryR\x07schemes\x1aN\n\x0cSchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x12.a2a.v1.StringListR\x05value:\x02\x38\x01\"\x91\x03\n\x0eSecurityScheme\x12U\n\x17\x61pi_key_security_scheme\x18\x01 \x01(\x0b\x32\x1c.a2a.v1.APIKeySecuritySchemeH\x00R\x14\x61piKeySecurityScheme\x12[\n\x19http_auth_security_scheme\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.HTTPAuthSecuritySchemeH\x00R\x16httpAuthSecurityScheme\x12T\n\x16oauth2_security_scheme\x18\x03 \x01(\x0b\x32\x1c.a2a.v1.OAuth2SecuritySchemeH\x00R\x14oauth2SecurityScheme\x12k\n\x1fopen_id_connect_security_scheme\x18\x04 \x01(\x0b\x32#.a2a.v1.OpenIdConnectSecuritySchemeH\x00R\x1bopenIdConnectSecuritySchemeB\x08\n\x06scheme\"h\n\x14\x41PIKeySecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08location\x18\x02 \x01(\tR\x08location\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\"w\n\x16HTTPAuthSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x16\n\x06scheme\x18\x02 \x01(\tR\x06scheme\x12#\n\rbearer_format\x18\x03 \x01(\tR\x0c\x62\x65\x61rerFormat\"b\n\x14OAuth2SecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12(\n\x05\x66lows\x18\x02 \x01(\x0b\x32\x12.a2a.v1.OAuthFlowsR\x05\x66lows\"n\n\x1bOpenIdConnectSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12-\n\x13open_id_connect_url\x18\x02 \x01(\tR\x10openIdConnectUrl\"\xb0\x02\n\nOAuthFlows\x12S\n\x12\x61uthorization_code\x18\x01 \x01(\x0b\x32\".a2a.v1.AuthorizationCodeOAuthFlowH\x00R\x11\x61uthorizationCode\x12S\n\x12\x63lient_credentials\x18\x02 \x01(\x0b\x32\".a2a.v1.ClientCredentialsOAuthFlowH\x00R\x11\x63lientCredentials\x12\x37\n\x08implicit\x18\x03 \x01(\x0b\x32\x19.a2a.v1.ImplicitOAuthFlowH\x00R\x08implicit\x12\x37\n\x08password\x18\x04 \x01(\x0b\x32\x19.a2a.v1.PasswordOAuthFlowH\x00R\x08passwordB\x06\n\x04\x66low\"\x8a\x02\n\x1a\x41uthorizationCodeOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1b\n\ttoken_url\x18\x02 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x03 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x04 \x03(\x0b\x32..a2a.v1.AuthorizationCodeOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdd\x01\n\x1a\x43lientCredentialsOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x03 \x03(\x0b\x32..a2a.v1.ClientCredentialsOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdb\x01\n\x11ImplicitOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.ImplicitOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xcb\x01\n\x11PasswordOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.PasswordOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xc1\x01\n\x12SendMessageRequest\x12.\n\x07request\x18\x01 \x01(\x0b\x32\x0f.a2a.v1.MessageB\x03\xe0\x41\x02R\x07message\x12\x46\n\rconfiguration\x18\x02 \x01(\x0b\x32 .a2a.v1.SendMessageConfigurationR\rconfiguration\x12\x33\n\x08metadata\x18\x03 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"P\n\x0eGetTaskRequest\x12\x17\n\x04name\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x04name\x12%\n\x0ehistory_length\x18\x02 \x01(\x05R\rhistoryLength\"\'\n\x11\x43\x61ncelTaskRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\":\n$GetTaskPushNotificationConfigRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"=\n\'DeleteTaskPushNotificationConfigRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"\xa9\x01\n\'CreateTaskPushNotificationConfigRequest\x12\x1b\n\x06parent\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x06parent\x12 \n\tconfig_id\x18\x02 \x01(\tB\x03\xe0\x41\x02R\x08\x63onfigId\x12?\n\x06\x63onfig\x18\x03 \x01(\x0b\x32\".a2a.v1.TaskPushNotificationConfigB\x03\xe0\x41\x02R\x06\x63onfig\"-\n\x17TaskSubscriptionRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"{\n%ListTaskPushNotificationConfigRequest\x12\x16\n\x06parent\x18\x01 \x01(\tR\x06parent\x12\x1b\n\tpage_size\x18\x02 \x01(\x05R\x08pageSize\x12\x1d\n\npage_token\x18\x03 \x01(\tR\tpageToken\"\x15\n\x13GetAgentCardRequest\"m\n\x13SendMessageResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12\'\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x07messageB\t\n\x07payload\"\xfa\x01\n\x0eStreamResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12\'\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x07message\x12\x44\n\rstatus_update\x18\x03 \x01(\x0b\x32\x1d.a2a.v1.TaskStatusUpdateEventH\x00R\x0cstatusUpdate\x12J\n\x0f\x61rtifact_update\x18\x04 \x01(\x0b\x32\x1f.a2a.v1.TaskArtifactUpdateEventH\x00R\x0e\x61rtifactUpdateB\t\n\x07payload\"\x8e\x01\n&ListTaskPushNotificationConfigResponse\x12<\n\x07\x63onfigs\x18\x01 \x03(\x0b\x32\".a2a.v1.TaskPushNotificationConfigR\x07\x63onfigs\x12&\n\x0fnext_page_token\x18\x02 \x01(\tR\rnextPageToken*\xfa\x01\n\tTaskState\x12\x1a\n\x16TASK_STATE_UNSPECIFIED\x10\x00\x12\x18\n\x14TASK_STATE_SUBMITTED\x10\x01\x12\x16\n\x12TASK_STATE_WORKING\x10\x02\x12\x18\n\x14TASK_STATE_COMPLETED\x10\x03\x12\x15\n\x11TASK_STATE_FAILED\x10\x04\x12\x18\n\x14TASK_STATE_CANCELLED\x10\x05\x12\x1d\n\x19TASK_STATE_INPUT_REQUIRED\x10\x06\x12\x17\n\x13TASK_STATE_REJECTED\x10\x07\x12\x1c\n\x18TASK_STATE_AUTH_REQUIRED\x10\x08*;\n\x04Role\x12\x14\n\x10ROLE_UNSPECIFIED\x10\x00\x12\r\n\tROLE_USER\x10\x01\x12\x0e\n\nROLE_AGENT\x10\x02\x32\xba\n\n\nA2AService\x12\x63\n\x0bSendMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x1b.a2a.v1.SendMessageResponse\"\x1b\x82\xd3\xe4\x93\x02\x15\"\x10/v1/message:send:\x01*\x12k\n\x14SendStreamingMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x16.a2a.v1.StreamResponse\"\x1d\x82\xd3\xe4\x93\x02\x17\"\x12/v1/message:stream:\x01*0\x01\x12R\n\x07GetTask\x12\x16.a2a.v1.GetTaskRequest\x1a\x0c.a2a.v1.Task\"!\xda\x41\x04name\x82\xd3\xe4\x93\x02\x14\x12\x12/v1/{name=tasks/*}\x12[\n\nCancelTask\x12\x19.a2a.v1.CancelTaskRequest\x1a\x0c.a2a.v1.Task\"$\x82\xd3\xe4\x93\x02\x1e\"\x19/v1/{name=tasks/*}:cancel:\x01*\x12s\n\x10TaskSubscription\x12\x1f.a2a.v1.TaskSubscriptionRequest\x1a\x16.a2a.v1.StreamResponse\"$\x82\xd3\xe4\x93\x02\x1e\x12\x1c/v1/{name=tasks/*}:subscribe0\x01\x12\xc4\x01\n CreateTaskPushNotificationConfig\x12/.a2a.v1.CreateTaskPushNotificationConfigRequest\x1a\".a2a.v1.TaskPushNotificationConfig\"K\xda\x41\rparent,config\x82\xd3\xe4\x93\x02\x35\"+/v1/{parent=task/*/pushNotificationConfigs}:\x06\x63onfig\x12\xae\x01\n\x1dGetTaskPushNotificationConfig\x12,.a2a.v1.GetTaskPushNotificationConfigRequest\x1a\".a2a.v1.TaskPushNotificationConfig\";\xda\x41\x04name\x82\xd3\xe4\x93\x02.\x12,/v1/{name=tasks/*/pushNotificationConfigs/*}\x12\xbe\x01\n\x1eListTaskPushNotificationConfig\x12-.a2a.v1.ListTaskPushNotificationConfigRequest\x1a..a2a.v1.ListTaskPushNotificationConfigResponse\"=\xda\x41\x06parent\x82\xd3\xe4\x93\x02.\x12,/v1/{parent=tasks/*}/pushNotificationConfigs\x12P\n\x0cGetAgentCard\x12\x1b.a2a.v1.GetAgentCardRequest\x1a\x11.a2a.v1.AgentCard\"\x10\x82\xd3\xe4\x93\x02\n\x12\x08/v1/card\x12\xa8\x01\n DeleteTaskPushNotificationConfig\x12/.a2a.v1.DeleteTaskPushNotificationConfigRequest\x1a\x16.google.protobuf.Empty\";\xda\x41\x04name\x82\xd3\xe4\x93\x02.*,/v1/{name=tasks/*/pushNotificationConfigs/*}B=\n\x11\x63om.google.a2a.v1B\x03\x41\x32\x41P\x01Z\x18google.golang.org/a2a/v1\xaa\x02\x06\x41\x32\x61.V1b\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'a2a_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None - _globals['DESCRIPTOR']._serialized_options = b'\n\ncom.a2a.v1B\010A2aProtoP\001Z\030google.golang.org/a2a/v1\242\002\003AXX\252\002\006A2a.V1\312\002\006A2a\\V1\342\002\022A2a\\V1\\GPBMetadata\352\002\007A2a::V1' + _globals['DESCRIPTOR']._serialized_options = b'\n\021com.google.a2a.v1B\003A2AP\001Z\030google.golang.org/a2a/v1\252\002\006A2a.V1' _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._loaded_options = None _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._serialized_options = b'8\001' _globals['_SECURITY_SCHEMESENTRY']._loaded_options = None diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index 28922a33..ed2579bd 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -69,6 +69,10 @@ def build( app = FastAPI(**kwargs) router = APIRouter() for route, callback in self._handler.routes().items(): +<<<<<<< Updated upstream +======= + print("Attaching ", route) +>>>>>>> Stashed changes router.add_api_route( f'{rpc_url}{route[0]}', callback, diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index 247d3f2e..afc82038 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -139,10 +139,18 @@ async def event_generator( ) -> AsyncGenerator[dict[str, str]]: async for item in stream: yield {'data': item} +<<<<<<< Updated upstream return EventSourceResponse(event_generator(method(request, call_context))) except Exception as e: # Since the stream has started, we can't return a JSONResponse. # Instead, we runt the error handling logic (provides logging) +======= + return EventSourceResponse( + event_generator(method(request, call_context))) + except Exception as e: + # Since the stream has started, we can't return a JSONResponse. + # Instead, we run the error handling logic (provides logging) +>>>>>>> Stashed changes # and reraise the error and let server framework manage self._handle_error(e) raise e @@ -185,6 +193,7 @@ async def handle_authenticated_agent_card(self, request: Request) -> JSONRespons def routes(self) -> dict[Tuple[str, str], Callable[[Request],Any]]: routes = { +<<<<<<< Updated upstream ('/v1/message:send', 'POST'): ( functools.partial( self._handle_request, @@ -230,4 +239,44 @@ def routes(self) -> dict[Tuple[str, str], Callable[[Request],Any]]: routes['/v1/card'] = ( self.handle_authenticated_agent_card, 'GET') +======= + ('/v1/message:send', 'POST'): functools.partial( + self._handle_request, + self.handler.on_message_send + ), + ('/v1/message:stream', 'POST'): functools.partial( + self._handle_streaming_request, + self.handler.on_message_send_stream + ), + ('/v1/tasks/{id}:subscribe', 'POST'): functools.partial( + self._handle_streaming_request, + self.handler.on_resubscribe_to_task + ), + ('/v1/tasks/{id}', 'GET'): functools.partial( + self._handle_request, + self.handler.on_get_task + ), + ('/v1/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): + functools.partial( + self._handle_request, + self.handler.get_push_notification + ), + ('/v1/tasks/{id}/pushNotificationConfigs', 'POST'): + functools.partial( + self._handle_request, + self.handler.set_push_notification + ), + ('/v1/tasks/{id}/pushNotificationConfigs', 'GET'): + functools.partial( + self._handle_request, + self.handler.list_push_notifications + ), + ('/v1/tasks', 'GET'): functools.partial( + self._handle_request, + self.handler.list_tasks + ), + } + if self.agent_card.supportsAuthenticatedExtendedCard: + routes[('/v1/card', 'GET')] = self.handle_authenticated_agent_card +>>>>>>> Stashed changes return routes diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 6078180f..71a1ba80 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -117,6 +117,10 @@ async def on_message_send_stream( """ try: body = await request.body() +<<<<<<< Updated upstream +======= + print('Request body', body) +>>>>>>> Stashed changes params = a2a_pb2.SendMessageRequest() Parse(body, params) # Transform the proto object to the python internal objects diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index 06ac1123..e3d6fb6a 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -28,7 +28,7 @@ completed_task, new_task, ) - +from a2a.utils.transports import Transports __all__ = [ 'AGENT_CARD_WELL_KNOWN_PATH', @@ -49,4 +49,5 @@ 'new_data_artifact', 'new_task', 'new_text_artifact', + 'Transports', ] diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 423a8388..23a65e0d 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -286,6 +286,21 @@ def agent_card( supports_authenticated_extended_card=bool( card.supportsAuthenticatedExtendedCard ), + preferred_transport=card.preferred_transport, + protocol_version=card.protocol_version, + additional_interfaces=[ + cls.agent_interface(x) for x in card.additional_interfaces + ] if card.additional_interfaces else None, + ) + + @classmethod + def agent_interface( + cls, + interface: types.AgentInterface, + ) -> a2a_pb2.AgentInterface: + return a2a_pb2.AgentInterface( + transport=interface.transport, + url=interface.url, ) @classmethod @@ -662,7 +677,22 @@ def agent_card( skills=[cls.skill(x) for x in card.skills] if card.skills else [], url=card.url, version=card.version, - supportsAuthenticatedExtendedCard=card.supports_authenticated_extended_card, + supports_authenticated_extended_card=card.supports_authenticated_extended_card, + preferred_transport=card.preferred_transport, + protocol_version=card.protocol_version, + additional_interfaces=[ + cls.agent_interface(x) for x in card.additional_interfaces + ] if card.additional_interfaces else None, + ) + + @classmethod + def agent_interface( + cls, + interface: a2a_pb2.AgentInterface, + ) -> types.AgentInterface: + return types.AgentInterface( + transport=interface.transport, + url=interface.url, ) @classmethod @@ -793,6 +823,22 @@ def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: ), ) + @classmethod + def stream_response( + cls, + response: a2a_pb2.StreamResponse, + ) -> (types.Message + | types.Task + | types.TaskStatusUpdateEvent + | types.TaskArtifactUpdateEvent): + if response.HasField('msg'): + return cls.message(response.msg) + if response.HasField('task'): + return cls.task(response.task) + if response.HasField('status_update'): + return cls.task_status_update_event(response.status_update) + return cls.task_artifact_update_event(response.artifact_update) + @classmethod def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill: return types.AgentSkill( diff --git a/src/a2a/utils/transports.py b/src/a2a/utils/transports.py new file mode 100644 index 00000000..33a8f9ed --- /dev/null +++ b/src/a2a/utils/transports.py @@ -0,0 +1,7 @@ +"""Defines standard protocol transport labels.""" +from enum import Enum + +class Transports(str, Enum): + GRPC = "GRPC" + JSONRPC = "JSONRPC" + RESTful = "HTTP+JSON" From 89c319f385d6d3848070bf9e1974567f4c93b035 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 24 Jul 2025 04:41:01 +0000 Subject: [PATCH 05/17] Fixes to initial client pr --- src/a2a/client/__init__.py | 8 ++++++++ src/a2a/grpc/a2a_pb2.py | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index f0cfcca7..8b33bc75 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -21,6 +21,11 @@ GrpcClient, NewGrpcClient, ) +from a2a.client.rest_client import ( + RestTransportClient, + RestClient, + NewRestfulClient, +) from a2a.client.helpers import create_text_message_object from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.client import ( @@ -68,4 +73,7 @@ 'JsonRpcTransportClient', 'NewJsonRpcClient', 'minimal_agent_card', + 'RestTransportClient', + 'RestClient', + 'NewRestfulClient', ] diff --git a/src/a2a/grpc/a2a_pb2.py b/src/a2a/grpc/a2a_pb2.py index e8304632..e11d6ebf 100644 --- a/src/a2a/grpc/a2a_pb2.py +++ b/src/a2a/grpc/a2a_pb2.py @@ -30,14 +30,14 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ta2a.proto\x12\x06\x61\x32\x61.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xde\x01\n\x18SendMessageConfiguration\x12\x32\n\x15\x61\x63\x63\x65pted_output_modes\x18\x01 \x03(\tR\x13\x61\x63\x63\x65ptedOutputModes\x12K\n\x11push_notification\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x10pushNotification\x12%\n\x0ehistory_length\x18\x03 \x01(\x05R\rhistoryLength\x12\x1a\n\x08\x62locking\x18\x04 \x01(\x08R\x08\x62locking\"\xf1\x01\n\x04Task\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12.\n\tartifacts\x18\x04 \x03(\x0b\x32\x10.a2a.v1.ArtifactR\tartifacts\x12)\n\x07history\x18\x05 \x03(\x0b\x32\x0f.a2a.v1.MessageR\x07history\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x99\x01\n\nTaskStatus\x12\'\n\x05state\x18\x01 \x01(\x0e\x32\x11.a2a.v1.TaskStateR\x05state\x12(\n\x06update\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageR\x07message\x12\x38\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\"t\n\x04Part\x12\x14\n\x04text\x18\x01 \x01(\tH\x00R\x04text\x12&\n\x04\x66ile\x18\x02 \x01(\x0b\x32\x10.a2a.v1.FilePartH\x00R\x04\x66ile\x12&\n\x04\x64\x61ta\x18\x03 \x01(\x0b\x32\x10.a2a.v1.DataPartH\x00R\x04\x64\x61taB\x06\n\x04part\"\x7f\n\x08\x46ilePart\x12$\n\rfile_with_uri\x18\x01 \x01(\tH\x00R\x0b\x66ileWithUri\x12(\n\x0f\x66ile_with_bytes\x18\x02 \x01(\x0cH\x00R\rfileWithBytes\x12\x1b\n\tmime_type\x18\x03 \x01(\tR\x08mimeTypeB\x06\n\x04\x66ile\"7\n\x08\x44\x61taPart\x12+\n\x04\x64\x61ta\x18\x01 \x01(\x0b\x32\x17.google.protobuf.StructR\x04\x64\x61ta\"\xff\x01\n\x07Message\x12\x1d\n\nmessage_id\x18\x01 \x01(\tR\tmessageId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12\x17\n\x07task_id\x18\x03 \x01(\tR\x06taskId\x12 \n\x04role\x18\x04 \x01(\x0e\x32\x0c.a2a.v1.RoleR\x04role\x12&\n\x07\x63ontent\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x07\x63ontent\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xda\x01\n\x08\x41rtifact\x12\x1f\n\x0b\x61rtifact_id\x18\x01 \x01(\tR\nartifactId\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x04 \x01(\tR\x0b\x64\x65scription\x12\"\n\x05parts\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x05parts\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xc6\x01\n\x15TaskStatusUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12\x14\n\x05\x66inal\x18\x04 \x01(\x08R\x05\x66inal\x12\x33\n\x08metadata\x18\x05 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\xeb\x01\n\x17TaskArtifactUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12,\n\x08\x61rtifact\x18\x03 \x01(\x0b\x32\x10.a2a.v1.ArtifactR\x08\x61rtifact\x12\x16\n\x06\x61ppend\x18\x04 \x01(\x08R\x06\x61ppend\x12\x1d\n\nlast_chunk\x18\x05 \x01(\x08R\tlastChunk\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x94\x01\n\x16PushNotificationConfig\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x10\n\x03url\x18\x02 \x01(\tR\x03url\x12\x14\n\x05token\x18\x03 \x01(\tR\x05token\x12\x42\n\x0e\x61uthentication\x18\x04 \x01(\x0b\x32\x1a.a2a.v1.AuthenticationInfoR\x0e\x61uthentication\"P\n\x12\x41uthenticationInfo\x12\x18\n\x07schemes\x18\x01 \x03(\tR\x07schemes\x12 \n\x0b\x63redentials\x18\x02 \x01(\tR\x0b\x63redentials\"@\n\x0e\x41gentInterface\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\x1c\n\ttransport\x18\x02 \x01(\tR\ttransport\"\xf1\x06\n\tAgentCard\x12)\n\x10protocol_version\x18\x10 \x01(\tR\x0fprotocolVersion\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x10\n\x03url\x18\x03 \x01(\tR\x03url\x12/\n\x13preferred_transport\x18\x0e \x01(\tR\x12preferredTransport\x12K\n\x15\x61\x64\x64itional_interfaces\x18\x0f \x03(\x0b\x32\x16.a2a.v1.AgentInterfaceR\x14\x61\x64\x64itionalInterfaces\x12\x31\n\x08provider\x18\x04 \x01(\x0b\x32\x15.a2a.v1.AgentProviderR\x08provider\x12\x18\n\x07version\x18\x05 \x01(\tR\x07version\x12+\n\x11\x64ocumentation_url\x18\x06 \x01(\tR\x10\x64ocumentationUrl\x12=\n\x0c\x63\x61pabilities\x18\x07 \x01(\x0b\x32\x19.a2a.v1.AgentCapabilitiesR\x0c\x63\x61pabilities\x12Q\n\x10security_schemes\x18\x08 \x03(\x0b\x32&.a2a.v1.AgentCard.SecuritySchemesEntryR\x0fsecuritySchemes\x12,\n\x08security\x18\t \x03(\x0b\x32\x10.a2a.v1.SecurityR\x08security\x12.\n\x13\x64\x65\x66\x61ult_input_modes\x18\n \x03(\tR\x11\x64\x65\x66\x61ultInputModes\x12\x30\n\x14\x64\x65\x66\x61ult_output_modes\x18\x0b \x03(\tR\x12\x64\x65\x66\x61ultOutputModes\x12*\n\x06skills\x18\x0c \x03(\x0b\x32\x12.a2a.v1.AgentSkillR\x06skills\x12O\n$supports_authenticated_extended_card\x18\r \x01(\x08R!supportsAuthenticatedExtendedCard\x1aZ\n\x14SecuritySchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x16.a2a.v1.SecuritySchemeR\x05value:\x02\x38\x01\"E\n\rAgentProvider\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\"\n\x0corganization\x18\x02 \x01(\tR\x0corganization\"\x98\x01\n\x11\x41gentCapabilities\x12\x1c\n\tstreaming\x18\x01 \x01(\x08R\tstreaming\x12-\n\x12push_notifications\x18\x02 \x01(\x08R\x11pushNotifications\x12\x36\n\nextensions\x18\x03 \x03(\x0b\x32\x16.a2a.v1.AgentExtensionR\nextensions\"\x91\x01\n\x0e\x41gentExtension\x12\x10\n\x03uri\x18\x01 \x01(\tR\x03uri\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08required\x18\x03 \x01(\x08R\x08required\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.StructR\x06params\"\xc6\x01\n\nAgentSkill\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x03 \x01(\tR\x0b\x64\x65scription\x12\x12\n\x04tags\x18\x04 \x03(\tR\x04tags\x12\x1a\n\x08\x65xamples\x18\x05 \x03(\tR\x08\x65xamples\x12\x1f\n\x0binput_modes\x18\x06 \x03(\tR\ninputModes\x12!\n\x0coutput_modes\x18\x07 \x03(\tR\x0boutputModes\"\x8a\x01\n\x1aTaskPushNotificationConfig\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12X\n\x18push_notification_config\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x16pushNotificationConfig\" \n\nStringList\x12\x12\n\x04list\x18\x01 \x03(\tR\x04list\"\x93\x01\n\x08Security\x12\x37\n\x07schemes\x18\x01 \x03(\x0b\x32\x1d.a2a.v1.Security.SchemesEntryR\x07schemes\x1aN\n\x0cSchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x12.a2a.v1.StringListR\x05value:\x02\x38\x01\"\x91\x03\n\x0eSecurityScheme\x12U\n\x17\x61pi_key_security_scheme\x18\x01 \x01(\x0b\x32\x1c.a2a.v1.APIKeySecuritySchemeH\x00R\x14\x61piKeySecurityScheme\x12[\n\x19http_auth_security_scheme\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.HTTPAuthSecuritySchemeH\x00R\x16httpAuthSecurityScheme\x12T\n\x16oauth2_security_scheme\x18\x03 \x01(\x0b\x32\x1c.a2a.v1.OAuth2SecuritySchemeH\x00R\x14oauth2SecurityScheme\x12k\n\x1fopen_id_connect_security_scheme\x18\x04 \x01(\x0b\x32#.a2a.v1.OpenIdConnectSecuritySchemeH\x00R\x1bopenIdConnectSecuritySchemeB\x08\n\x06scheme\"h\n\x14\x41PIKeySecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08location\x18\x02 \x01(\tR\x08location\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\"w\n\x16HTTPAuthSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x16\n\x06scheme\x18\x02 \x01(\tR\x06scheme\x12#\n\rbearer_format\x18\x03 \x01(\tR\x0c\x62\x65\x61rerFormat\"b\n\x14OAuth2SecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12(\n\x05\x66lows\x18\x02 \x01(\x0b\x32\x12.a2a.v1.OAuthFlowsR\x05\x66lows\"n\n\x1bOpenIdConnectSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12-\n\x13open_id_connect_url\x18\x02 \x01(\tR\x10openIdConnectUrl\"\xb0\x02\n\nOAuthFlows\x12S\n\x12\x61uthorization_code\x18\x01 \x01(\x0b\x32\".a2a.v1.AuthorizationCodeOAuthFlowH\x00R\x11\x61uthorizationCode\x12S\n\x12\x63lient_credentials\x18\x02 \x01(\x0b\x32\".a2a.v1.ClientCredentialsOAuthFlowH\x00R\x11\x63lientCredentials\x12\x37\n\x08implicit\x18\x03 \x01(\x0b\x32\x19.a2a.v1.ImplicitOAuthFlowH\x00R\x08implicit\x12\x37\n\x08password\x18\x04 \x01(\x0b\x32\x19.a2a.v1.PasswordOAuthFlowH\x00R\x08passwordB\x06\n\x04\x66low\"\x8a\x02\n\x1a\x41uthorizationCodeOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1b\n\ttoken_url\x18\x02 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x03 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x04 \x03(\x0b\x32..a2a.v1.AuthorizationCodeOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdd\x01\n\x1a\x43lientCredentialsOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x03 \x03(\x0b\x32..a2a.v1.ClientCredentialsOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdb\x01\n\x11ImplicitOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.ImplicitOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xcb\x01\n\x11PasswordOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.PasswordOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xc1\x01\n\x12SendMessageRequest\x12.\n\x07request\x18\x01 \x01(\x0b\x32\x0f.a2a.v1.MessageB\x03\xe0\x41\x02R\x07message\x12\x46\n\rconfiguration\x18\x02 \x01(\x0b\x32 .a2a.v1.SendMessageConfigurationR\rconfiguration\x12\x33\n\x08metadata\x18\x03 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"P\n\x0eGetTaskRequest\x12\x17\n\x04name\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x04name\x12%\n\x0ehistory_length\x18\x02 \x01(\x05R\rhistoryLength\"\'\n\x11\x43\x61ncelTaskRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\":\n$GetTaskPushNotificationConfigRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"=\n\'DeleteTaskPushNotificationConfigRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"\xa9\x01\n\'CreateTaskPushNotificationConfigRequest\x12\x1b\n\x06parent\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x06parent\x12 \n\tconfig_id\x18\x02 \x01(\tB\x03\xe0\x41\x02R\x08\x63onfigId\x12?\n\x06\x63onfig\x18\x03 \x01(\x0b\x32\".a2a.v1.TaskPushNotificationConfigB\x03\xe0\x41\x02R\x06\x63onfig\"-\n\x17TaskSubscriptionRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"{\n%ListTaskPushNotificationConfigRequest\x12\x16\n\x06parent\x18\x01 \x01(\tR\x06parent\x12\x1b\n\tpage_size\x18\x02 \x01(\x05R\x08pageSize\x12\x1d\n\npage_token\x18\x03 \x01(\tR\tpageToken\"\x15\n\x13GetAgentCardRequest\"m\n\x13SendMessageResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12\'\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x07messageB\t\n\x07payload\"\xfa\x01\n\x0eStreamResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12\'\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x07message\x12\x44\n\rstatus_update\x18\x03 \x01(\x0b\x32\x1d.a2a.v1.TaskStatusUpdateEventH\x00R\x0cstatusUpdate\x12J\n\x0f\x61rtifact_update\x18\x04 \x01(\x0b\x32\x1f.a2a.v1.TaskArtifactUpdateEventH\x00R\x0e\x61rtifactUpdateB\t\n\x07payload\"\x8e\x01\n&ListTaskPushNotificationConfigResponse\x12<\n\x07\x63onfigs\x18\x01 \x03(\x0b\x32\".a2a.v1.TaskPushNotificationConfigR\x07\x63onfigs\x12&\n\x0fnext_page_token\x18\x02 \x01(\tR\rnextPageToken*\xfa\x01\n\tTaskState\x12\x1a\n\x16TASK_STATE_UNSPECIFIED\x10\x00\x12\x18\n\x14TASK_STATE_SUBMITTED\x10\x01\x12\x16\n\x12TASK_STATE_WORKING\x10\x02\x12\x18\n\x14TASK_STATE_COMPLETED\x10\x03\x12\x15\n\x11TASK_STATE_FAILED\x10\x04\x12\x18\n\x14TASK_STATE_CANCELLED\x10\x05\x12\x1d\n\x19TASK_STATE_INPUT_REQUIRED\x10\x06\x12\x17\n\x13TASK_STATE_REJECTED\x10\x07\x12\x1c\n\x18TASK_STATE_AUTH_REQUIRED\x10\x08*;\n\x04Role\x12\x14\n\x10ROLE_UNSPECIFIED\x10\x00\x12\r\n\tROLE_USER\x10\x01\x12\x0e\n\nROLE_AGENT\x10\x02\x32\xba\n\n\nA2AService\x12\x63\n\x0bSendMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x1b.a2a.v1.SendMessageResponse\"\x1b\x82\xd3\xe4\x93\x02\x15\"\x10/v1/message:send:\x01*\x12k\n\x14SendStreamingMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x16.a2a.v1.StreamResponse\"\x1d\x82\xd3\xe4\x93\x02\x17\"\x12/v1/message:stream:\x01*0\x01\x12R\n\x07GetTask\x12\x16.a2a.v1.GetTaskRequest\x1a\x0c.a2a.v1.Task\"!\xda\x41\x04name\x82\xd3\xe4\x93\x02\x14\x12\x12/v1/{name=tasks/*}\x12[\n\nCancelTask\x12\x19.a2a.v1.CancelTaskRequest\x1a\x0c.a2a.v1.Task\"$\x82\xd3\xe4\x93\x02\x1e\"\x19/v1/{name=tasks/*}:cancel:\x01*\x12s\n\x10TaskSubscription\x12\x1f.a2a.v1.TaskSubscriptionRequest\x1a\x16.a2a.v1.StreamResponse\"$\x82\xd3\xe4\x93\x02\x1e\x12\x1c/v1/{name=tasks/*}:subscribe0\x01\x12\xc4\x01\n CreateTaskPushNotificationConfig\x12/.a2a.v1.CreateTaskPushNotificationConfigRequest\x1a\".a2a.v1.TaskPushNotificationConfig\"K\xda\x41\rparent,config\x82\xd3\xe4\x93\x02\x35\"+/v1/{parent=task/*/pushNotificationConfigs}:\x06\x63onfig\x12\xae\x01\n\x1dGetTaskPushNotificationConfig\x12,.a2a.v1.GetTaskPushNotificationConfigRequest\x1a\".a2a.v1.TaskPushNotificationConfig\";\xda\x41\x04name\x82\xd3\xe4\x93\x02.\x12,/v1/{name=tasks/*/pushNotificationConfigs/*}\x12\xbe\x01\n\x1eListTaskPushNotificationConfig\x12-.a2a.v1.ListTaskPushNotificationConfigRequest\x1a..a2a.v1.ListTaskPushNotificationConfigResponse\"=\xda\x41\x06parent\x82\xd3\xe4\x93\x02.\x12,/v1/{parent=tasks/*}/pushNotificationConfigs\x12P\n\x0cGetAgentCard\x12\x1b.a2a.v1.GetAgentCardRequest\x1a\x11.a2a.v1.AgentCard\"\x10\x82\xd3\xe4\x93\x02\n\x12\x08/v1/card\x12\xa8\x01\n DeleteTaskPushNotificationConfig\x12/.a2a.v1.DeleteTaskPushNotificationConfigRequest\x1a\x16.google.protobuf.Empty\";\xda\x41\x04name\x82\xd3\xe4\x93\x02.*,/v1/{name=tasks/*/pushNotificationConfigs/*}B=\n\x11\x63om.google.a2a.v1B\x03\x41\x32\x41P\x01Z\x18google.golang.org/a2a/v1\xaa\x02\x06\x41\x32\x61.V1b\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ta2a.proto\x12\x06\x61\x32\x61.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xde\x01\n\x18SendMessageConfiguration\x12\x32\n\x15\x61\x63\x63\x65pted_output_modes\x18\x01 \x03(\tR\x13\x61\x63\x63\x65ptedOutputModes\x12K\n\x11push_notification\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x10pushNotification\x12%\n\x0ehistory_length\x18\x03 \x01(\x05R\rhistoryLength\x12\x1a\n\x08\x62locking\x18\x04 \x01(\x08R\x08\x62locking\"\xf1\x01\n\x04Task\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12.\n\tartifacts\x18\x04 \x03(\x0b\x32\x10.a2a.v1.ArtifactR\tartifacts\x12)\n\x07history\x18\x05 \x03(\x0b\x32\x0f.a2a.v1.MessageR\x07history\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x99\x01\n\nTaskStatus\x12\'\n\x05state\x18\x01 \x01(\x0e\x32\x11.a2a.v1.TaskStateR\x05state\x12(\n\x06update\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageR\x07message\x12\x38\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\"t\n\x04Part\x12\x14\n\x04text\x18\x01 \x01(\tH\x00R\x04text\x12&\n\x04\x66ile\x18\x02 \x01(\x0b\x32\x10.a2a.v1.FilePartH\x00R\x04\x66ile\x12&\n\x04\x64\x61ta\x18\x03 \x01(\x0b\x32\x10.a2a.v1.DataPartH\x00R\x04\x64\x61taB\x06\n\x04part\"\x7f\n\x08\x46ilePart\x12$\n\rfile_with_uri\x18\x01 \x01(\tH\x00R\x0b\x66ileWithUri\x12(\n\x0f\x66ile_with_bytes\x18\x02 \x01(\x0cH\x00R\rfileWithBytes\x12\x1b\n\tmime_type\x18\x03 \x01(\tR\x08mimeTypeB\x06\n\x04\x66ile\"7\n\x08\x44\x61taPart\x12+\n\x04\x64\x61ta\x18\x01 \x01(\x0b\x32\x17.google.protobuf.StructR\x04\x64\x61ta\"\xff\x01\n\x07Message\x12\x1d\n\nmessage_id\x18\x01 \x01(\tR\tmessageId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12\x17\n\x07task_id\x18\x03 \x01(\tR\x06taskId\x12 \n\x04role\x18\x04 \x01(\x0e\x32\x0c.a2a.v1.RoleR\x04role\x12&\n\x07\x63ontent\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x07\x63ontent\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xda\x01\n\x08\x41rtifact\x12\x1f\n\x0b\x61rtifact_id\x18\x01 \x01(\tR\nartifactId\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x04 \x01(\tR\x0b\x64\x65scription\x12\"\n\x05parts\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x05parts\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xc6\x01\n\x15TaskStatusUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12\x14\n\x05\x66inal\x18\x04 \x01(\x08R\x05\x66inal\x12\x33\n\x08metadata\x18\x05 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\xeb\x01\n\x17TaskArtifactUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12,\n\x08\x61rtifact\x18\x03 \x01(\x0b\x32\x10.a2a.v1.ArtifactR\x08\x61rtifact\x12\x16\n\x06\x61ppend\x18\x04 \x01(\x08R\x06\x61ppend\x12\x1d\n\nlast_chunk\x18\x05 \x01(\x08R\tlastChunk\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x94\x01\n\x16PushNotificationConfig\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x10\n\x03url\x18\x02 \x01(\tR\x03url\x12\x14\n\x05token\x18\x03 \x01(\tR\x05token\x12\x42\n\x0e\x61uthentication\x18\x04 \x01(\x0b\x32\x1a.a2a.v1.AuthenticationInfoR\x0e\x61uthentication\"P\n\x12\x41uthenticationInfo\x12\x18\n\x07schemes\x18\x01 \x03(\tR\x07schemes\x12 \n\x0b\x63redentials\x18\x02 \x01(\tR\x0b\x63redentials\"@\n\x0e\x41gentInterface\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\x1c\n\ttransport\x18\x02 \x01(\tR\ttransport\"\xf1\x06\n\tAgentCard\x12)\n\x10protocol_version\x18\x10 \x01(\tR\x0fprotocolVersion\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x10\n\x03url\x18\x03 \x01(\tR\x03url\x12/\n\x13preferred_transport\x18\x0e \x01(\tR\x12preferredTransport\x12K\n\x15\x61\x64\x64itional_interfaces\x18\x0f \x03(\x0b\x32\x16.a2a.v1.AgentInterfaceR\x14\x61\x64\x64itionalInterfaces\x12\x31\n\x08provider\x18\x04 \x01(\x0b\x32\x15.a2a.v1.AgentProviderR\x08provider\x12\x18\n\x07version\x18\x05 \x01(\tR\x07version\x12+\n\x11\x64ocumentation_url\x18\x06 \x01(\tR\x10\x64ocumentationUrl\x12=\n\x0c\x63\x61pabilities\x18\x07 \x01(\x0b\x32\x19.a2a.v1.AgentCapabilitiesR\x0c\x63\x61pabilities\x12Q\n\x10security_schemes\x18\x08 \x03(\x0b\x32&.a2a.v1.AgentCard.SecuritySchemesEntryR\x0fsecuritySchemes\x12,\n\x08security\x18\t \x03(\x0b\x32\x10.a2a.v1.SecurityR\x08security\x12.\n\x13\x64\x65\x66\x61ult_input_modes\x18\n \x03(\tR\x11\x64\x65\x66\x61ultInputModes\x12\x30\n\x14\x64\x65\x66\x61ult_output_modes\x18\x0b \x03(\tR\x12\x64\x65\x66\x61ultOutputModes\x12*\n\x06skills\x18\x0c \x03(\x0b\x32\x12.a2a.v1.AgentSkillR\x06skills\x12O\n$supports_authenticated_extended_card\x18\r \x01(\x08R!supportsAuthenticatedExtendedCard\x1aZ\n\x14SecuritySchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x16.a2a.v1.SecuritySchemeR\x05value:\x02\x38\x01\"E\n\rAgentProvider\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\"\n\x0corganization\x18\x02 \x01(\tR\x0corganization\"\x98\x01\n\x11\x41gentCapabilities\x12\x1c\n\tstreaming\x18\x01 \x01(\x08R\tstreaming\x12-\n\x12push_notifications\x18\x02 \x01(\x08R\x11pushNotifications\x12\x36\n\nextensions\x18\x03 \x03(\x0b\x32\x16.a2a.v1.AgentExtensionR\nextensions\"\x91\x01\n\x0e\x41gentExtension\x12\x10\n\x03uri\x18\x01 \x01(\tR\x03uri\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08required\x18\x03 \x01(\x08R\x08required\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.StructR\x06params\"\xc6\x01\n\nAgentSkill\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x03 \x01(\tR\x0b\x64\x65scription\x12\x12\n\x04tags\x18\x04 \x03(\tR\x04tags\x12\x1a\n\x08\x65xamples\x18\x05 \x03(\tR\x08\x65xamples\x12\x1f\n\x0binput_modes\x18\x06 \x03(\tR\ninputModes\x12!\n\x0coutput_modes\x18\x07 \x03(\tR\x0boutputModes\"\x8a\x01\n\x1aTaskPushNotificationConfig\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12X\n\x18push_notification_config\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x16pushNotificationConfig\" \n\nStringList\x12\x12\n\x04list\x18\x01 \x03(\tR\x04list\"\x93\x01\n\x08Security\x12\x37\n\x07schemes\x18\x01 \x03(\x0b\x32\x1d.a2a.v1.Security.SchemesEntryR\x07schemes\x1aN\n\x0cSchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x12.a2a.v1.StringListR\x05value:\x02\x38\x01\"\x91\x03\n\x0eSecurityScheme\x12U\n\x17\x61pi_key_security_scheme\x18\x01 \x01(\x0b\x32\x1c.a2a.v1.APIKeySecuritySchemeH\x00R\x14\x61piKeySecurityScheme\x12[\n\x19http_auth_security_scheme\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.HTTPAuthSecuritySchemeH\x00R\x16httpAuthSecurityScheme\x12T\n\x16oauth2_security_scheme\x18\x03 \x01(\x0b\x32\x1c.a2a.v1.OAuth2SecuritySchemeH\x00R\x14oauth2SecurityScheme\x12k\n\x1fopen_id_connect_security_scheme\x18\x04 \x01(\x0b\x32#.a2a.v1.OpenIdConnectSecuritySchemeH\x00R\x1bopenIdConnectSecuritySchemeB\x08\n\x06scheme\"h\n\x14\x41PIKeySecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08location\x18\x02 \x01(\tR\x08location\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\"w\n\x16HTTPAuthSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x16\n\x06scheme\x18\x02 \x01(\tR\x06scheme\x12#\n\rbearer_format\x18\x03 \x01(\tR\x0c\x62\x65\x61rerFormat\"b\n\x14OAuth2SecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12(\n\x05\x66lows\x18\x02 \x01(\x0b\x32\x12.a2a.v1.OAuthFlowsR\x05\x66lows\"n\n\x1bOpenIdConnectSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12-\n\x13open_id_connect_url\x18\x02 \x01(\tR\x10openIdConnectUrl\"\xb0\x02\n\nOAuthFlows\x12S\n\x12\x61uthorization_code\x18\x01 \x01(\x0b\x32\".a2a.v1.AuthorizationCodeOAuthFlowH\x00R\x11\x61uthorizationCode\x12S\n\x12\x63lient_credentials\x18\x02 \x01(\x0b\x32\".a2a.v1.ClientCredentialsOAuthFlowH\x00R\x11\x63lientCredentials\x12\x37\n\x08implicit\x18\x03 \x01(\x0b\x32\x19.a2a.v1.ImplicitOAuthFlowH\x00R\x08implicit\x12\x37\n\x08password\x18\x04 \x01(\x0b\x32\x19.a2a.v1.PasswordOAuthFlowH\x00R\x08passwordB\x06\n\x04\x66low\"\x8a\x02\n\x1a\x41uthorizationCodeOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1b\n\ttoken_url\x18\x02 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x03 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x04 \x03(\x0b\x32..a2a.v1.AuthorizationCodeOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdd\x01\n\x1a\x43lientCredentialsOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x03 \x03(\x0b\x32..a2a.v1.ClientCredentialsOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdb\x01\n\x11ImplicitOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.ImplicitOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xcb\x01\n\x11PasswordOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.PasswordOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xc1\x01\n\x12SendMessageRequest\x12.\n\x07request\x18\x01 \x01(\x0b\x32\x0f.a2a.v1.MessageB\x03\xe0\x41\x02R\x07request\x12\x46\n\rconfiguration\x18\x02 \x01(\x0b\x32 .a2a.v1.SendMessageConfigurationR\rconfiguration\x12\x33\n\x08metadata\x18\x03 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"P\n\x0eGetTaskRequest\x12\x17\n\x04name\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x04name\x12%\n\x0ehistory_length\x18\x02 \x01(\x05R\rhistoryLength\"\'\n\x11\x43\x61ncelTaskRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\":\n$GetTaskPushNotificationConfigRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"=\n\'DeleteTaskPushNotificationConfigRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"\xa9\x01\n\'CreateTaskPushNotificationConfigRequest\x12\x1b\n\x06parent\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x06parent\x12 \n\tconfig_id\x18\x02 \x01(\tB\x03\xe0\x41\x02R\x08\x63onfigId\x12?\n\x06\x63onfig\x18\x03 \x01(\x0b\x32\".a2a.v1.TaskPushNotificationConfigB\x03\xe0\x41\x02R\x06\x63onfig\"-\n\x17TaskSubscriptionRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"{\n%ListTaskPushNotificationConfigRequest\x12\x16\n\x06parent\x18\x01 \x01(\tR\x06parent\x12\x1b\n\tpage_size\x18\x02 \x01(\x05R\x08pageSize\x12\x1d\n\npage_token\x18\x03 \x01(\tR\tpageToken\"\x15\n\x13GetAgentCardRequest\"m\n\x13SendMessageResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12\'\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x07messageB\t\n\x07payload\"\xfa\x01\n\x0eStreamResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12\'\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x07message\x12\x44\n\rstatus_update\x18\x03 \x01(\x0b\x32\x1d.a2a.v1.TaskStatusUpdateEventH\x00R\x0cstatusUpdate\x12J\n\x0f\x61rtifact_update\x18\x04 \x01(\x0b\x32\x1f.a2a.v1.TaskArtifactUpdateEventH\x00R\x0e\x61rtifactUpdateB\t\n\x07payload\"\x8e\x01\n&ListTaskPushNotificationConfigResponse\x12<\n\x07\x63onfigs\x18\x01 \x03(\x0b\x32\".a2a.v1.TaskPushNotificationConfigR\x07\x63onfigs\x12&\n\x0fnext_page_token\x18\x02 \x01(\tR\rnextPageToken*\xfa\x01\n\tTaskState\x12\x1a\n\x16TASK_STATE_UNSPECIFIED\x10\x00\x12\x18\n\x14TASK_STATE_SUBMITTED\x10\x01\x12\x16\n\x12TASK_STATE_WORKING\x10\x02\x12\x18\n\x14TASK_STATE_COMPLETED\x10\x03\x12\x15\n\x11TASK_STATE_FAILED\x10\x04\x12\x18\n\x14TASK_STATE_CANCELLED\x10\x05\x12\x1d\n\x19TASK_STATE_INPUT_REQUIRED\x10\x06\x12\x17\n\x13TASK_STATE_REJECTED\x10\x07\x12\x1c\n\x18TASK_STATE_AUTH_REQUIRED\x10\x08*;\n\x04Role\x12\x14\n\x10ROLE_UNSPECIFIED\x10\x00\x12\r\n\tROLE_USER\x10\x01\x12\x0e\n\nROLE_AGENT\x10\x02\x32\xba\n\n\nA2AService\x12\x63\n\x0bSendMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x1b.a2a.v1.SendMessageResponse\"\x1b\x82\xd3\xe4\x93\x02\x15\"\x10/v1/message:send:\x01*\x12k\n\x14SendStreamingMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x16.a2a.v1.StreamResponse\"\x1d\x82\xd3\xe4\x93\x02\x17\"\x12/v1/message:stream:\x01*0\x01\x12R\n\x07GetTask\x12\x16.a2a.v1.GetTaskRequest\x1a\x0c.a2a.v1.Task\"!\xda\x41\x04name\x82\xd3\xe4\x93\x02\x14\x12\x12/v1/{name=tasks/*}\x12[\n\nCancelTask\x12\x19.a2a.v1.CancelTaskRequest\x1a\x0c.a2a.v1.Task\"$\x82\xd3\xe4\x93\x02\x1e\"\x19/v1/{name=tasks/*}:cancel:\x01*\x12s\n\x10TaskSubscription\x12\x1f.a2a.v1.TaskSubscriptionRequest\x1a\x16.a2a.v1.StreamResponse\"$\x82\xd3\xe4\x93\x02\x1e\x12\x1c/v1/{name=tasks/*}:subscribe0\x01\x12\xc4\x01\n CreateTaskPushNotificationConfig\x12/.a2a.v1.CreateTaskPushNotificationConfigRequest\x1a\".a2a.v1.TaskPushNotificationConfig\"K\xda\x41\rparent,config\x82\xd3\xe4\x93\x02\x35\"+/v1/{parent=task/*/pushNotificationConfigs}:\x06\x63onfig\x12\xae\x01\n\x1dGetTaskPushNotificationConfig\x12,.a2a.v1.GetTaskPushNotificationConfigRequest\x1a\".a2a.v1.TaskPushNotificationConfig\";\xda\x41\x04name\x82\xd3\xe4\x93\x02.\x12,/v1/{name=tasks/*/pushNotificationConfigs/*}\x12\xbe\x01\n\x1eListTaskPushNotificationConfig\x12-.a2a.v1.ListTaskPushNotificationConfigRequest\x1a..a2a.v1.ListTaskPushNotificationConfigResponse\"=\xda\x41\x06parent\x82\xd3\xe4\x93\x02.\x12,/v1/{parent=tasks/*}/pushNotificationConfigs\x12P\n\x0cGetAgentCard\x12\x1b.a2a.v1.GetAgentCardRequest\x1a\x11.a2a.v1.AgentCard\"\x10\x82\xd3\xe4\x93\x02\n\x12\x08/v1/card\x12\xa8\x01\n DeleteTaskPushNotificationConfig\x12/.a2a.v1.DeleteTaskPushNotificationConfigRequest\x1a\x16.google.protobuf.Empty\";\xda\x41\x04name\x82\xd3\xe4\x93\x02.*,/v1/{name=tasks/*/pushNotificationConfigs/*}Bi\n\ncom.a2a.v1B\x08\x41\x32\x61ProtoP\x01Z\x18google.golang.org/a2a/v1\xa2\x02\x03\x41XX\xaa\x02\x06\x41\x32\x61.V1\xca\x02\x06\x41\x32\x61\\V1\xe2\x02\x12\x41\x32\x61\\V1\\GPBMetadata\xea\x02\x07\x41\x32\x61::V1b\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'a2a_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None - _globals['DESCRIPTOR']._serialized_options = b'\n\021com.google.a2a.v1B\003A2AP\001Z\030google.golang.org/a2a/v1\252\002\006A2a.V1' + _globals['DESCRIPTOR']._serialized_options = b'\n\ncom.a2a.v1B\010A2aProtoP\001Z\030google.golang.org/a2a/v1\242\002\003AXX\252\002\006A2a.V1\312\002\006A2a\\V1\342\002\022A2a\\V1\\GPBMetadata\352\002\007A2a::V1' _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._loaded_options = None _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._serialized_options = b'8\001' _globals['_SECURITY_SCHEMESENTRY']._loaded_options = None From 2648238ec1c421f80f6f4475b4fb87b76619f463 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 24 Jul 2025 14:29:44 +0000 Subject: [PATCH 06/17] Fix typos, add comments, address gemini feedback. --- src/a2a/client/client.py | 19 +++++- src/a2a/client/client_factory.py | 65 +++++++++++++++---- src/a2a/client/client_task_manager.py | 13 ++-- src/a2a/client/grpc_client.py | 21 +++--- src/a2a/client/jsonrpc_client.py | 1 + src/a2a/client/rest_client.py | 36 +--------- src/a2a/server/apps/rest/fastapi_app.py | 4 -- src/a2a/server/apps/rest/rest_app.py | 48 -------------- .../server/request_handlers/rest_handler.py | 18 ++--- 9 files changed, 97 insertions(+), 128 deletions(-) diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 5dfe0906..e058a799 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -144,7 +144,9 @@ class ClientConfig: """Whether client supports streaming""" polling: bool = False - """Whether client prefers to poll for updates from message:send""" + """Whether client prefers to poll for updates from message:send. It is + the callers job to check if the response is completed and if not run a + polling loop.""" httpx_client: httpx.AsyncClient | None = None """Http client to use to connect to agent.""" @@ -164,10 +166,10 @@ class ClientConfig: """Whether to use client transport preferences over server preferences. Recommended to use server preferences in most situations.""" - acceptedOutputModes: list[str] = dataclasses.field(default_factory=list) + accepted_outputModes: list[str] = dataclasses.field(default_factory=list) """The set of accepted output modes for the client.""" - pushNotificationConfigs: list[PushNotificationConfig] = dataclasses.field(default_factory=list) + push_notification_configs: list[PushNotificationConfig] = dataclasses.field(default_factory=list) """Push notification callbacks to use for every request.""" UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None @@ -197,6 +199,14 @@ async def send_message( *, context: ClientCallContext | None = None, ) -> AsyncIterator[ClientEvent | Message]: + """Sends a message to the server. + + This will automatically use the streaming or non-streaming approach + as supported by the server and the client config. Client will + aggregate update events and return an iterator of (`Task`,`Update`) + pairs, or a `Message`. Client will also send these values to any + configured `Consumer`s in the client. + """ pass yield @@ -255,9 +265,11 @@ async def get_card( pass async def add_event_consumer(self, consumer: Consumer): + """Attaches additional consumers to the `Client`""" self._consumers.append(consumer) async def add_request_middleware(self, middleware: ClientCallInterceptor): + """Attaches additional middleware to the `Client`""" self._middleware.append(middleware) async def consume( @@ -265,6 +277,7 @@ async def consume( event: tuple[Task, UpdateEvent] | Message | None, card: AgentCard, ): + """Processes the event via all the registered `Consumer`s.""" if not event: return for c in self._consumers: diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index f149d5d0..9f85c2b4 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -42,15 +42,31 @@ ] class ClientFactory: + """ClientFactory is used to generate the appropriate client for the agent. + + The factory is configured with a `ClientConfig` and optionally a list of + `Consumer`s to use for all generated `Client`s. The expected use is: + + factory = ClientFactory(config, consumers) + # Optionally register custom client implementations + factory.register('my_customer_transport', NewCustomTransportClient) + # Then with an agent card make a client with additional consumers and + # interceptors + client = factory.create(card, additional_consumers, interceptors) + # Now the client can be used the same regardless of transport and + # aligns client config with server capabilities. + """ def __init__( self, config: ClientConfig, - consumers: list[Consumer], + consumers: list[Consumer] = [], ): self._config = config self._consumers = consumers self._registry: dict[str, ClientProducer] = {} + # By default register the 3 core transports if in the config. + # Can be overridden with custom clients via the register method. if Transports.JSONRPC in self._config.supported_transports: self._registry[Transports.JSONRPC] = NewJsonRpcClient if Transports.RESTful in self._config.supported_transports: @@ -59,6 +75,7 @@ def __init__( self._registry[Transports.GRPC] = NewGrpcClient def register(self, label: str, generator: ClientProducer): + """Register a new client producer for a given transport label.""" self._registry[label] = generator def create( @@ -67,14 +84,31 @@ def create( consumers: list[Consumer] | None = None, interceptors: list[ClientCallInterceptor] | None = None, ) -> Client: + """Create a new `Client` for the provided `AgentCard`. + + Args: + card: An `AgentCard` defining the characteristics of the agent. + consumers: A list of `Consumer` methods to pass responses to. + interceptors: A list of interceptors to use for each request. These + are used for things like attaching credentials or http headers + to all outbound requests. + + Returns: + A `Client` object. + + Raises: + If there is no valid matching of the client configuration with the + server configuration, a `ValueError` is raised. + """ # Determine preferential transport - server_set = [card.preferredTransport or 'JSONRPC'] - if card.additionalInterfaces: + server_set = [card.preferred_transport or 'JSONRPC'] + if card.additional_interfaces: server_set.extend( - [x.transport for x in card.additionalInterfaces] + [x.transport for x in card.additional_interfaces] ) client_set = self._config.supported_transports or ['JSONRPC'] transport = None + # Two options, use the client ordering or the server ordering. if self._config.use_client_preference: for x in client_set: if x in server_set: @@ -86,9 +120,9 @@ def create( transport = x break if not transport: - raise Exception('no compatible transports found.') + raise ValueError('no compatible transports found.') if transport not in self._registry: - raise Exception(f'no client available for {transport}') + raise ValueError(f'no client available for {transport}') all_consumers = self._consumers if consumers: all_consumers.extend(consumers) @@ -97,15 +131,22 @@ def create( ) def minimal_agent_card(url: str, transports: list[str] = []) -> AgentCard: - """Generates a minimal card to simplify bootstrapping client creation.""" + """Generates a minimal card to simplify bootstrapping client creation. + + This minimal card is not viable itself to interact with the remote agent. + Instead this is a short hand way to take a known url and transport option + and interact with the get card endpoint of the agent server to get the + correct agent card. This pattern is necessary for gRPC based card access + as typically these servers won't expose a well known path card. + """ return AgentCard( url=url, - preferredTransport=transports[0] if transports else None, - additionalInterfaces=transports[1:] if len(transports) > 1 else [], - supportsAuthenticatedExtendedCard=True, + preferred_transport=transports[0] if transports else None, + additional_interfaces=transports[1:] if len(transports) > 1 else [], + supports_authenticated_extended_card=True, capabilities=AgentCapabilities(), - defaultInputModes=[], - defaultOutputModes=[], + default_input_modes=[], + default_output_modes=[], description='', skills=[], version='', diff --git a/src/a2a/client/client_task_manager.py b/src/a2a/client/client_task_manager.py index 733e51bd..99206cbb 100644 --- a/src/a2a/client/client_task_manager.py +++ b/src/a2a/client/client_task_manager.py @@ -27,7 +27,7 @@ class ClientTaskManager: def __init__( self, ): - """Initializes the TaskManager. + """Initializes the `TaskManager`. Args: task_id: The ID of the task, if known from the request. @@ -41,10 +41,9 @@ def __init__( self._context_id: str | None = None def get_task(self) -> Task | None: - """Retrieves the current task object, either from memory or the store. + """Retrieves the current task object, either from memory. - If `task_id` is set, it first checks the in-memory `_current_task`, - then attempts to load it from the `task_store`. + If `task_id` is set, it returns `_current_task` otherwise None. Returns: The `Task` object if found, otherwise `None`. @@ -69,7 +68,7 @@ async def save_task_event( The updated `Task` object after processing the event. Raises: - ServerError: If the task ID in the event conflicts with the TaskManager's ID + ClientError: If the task ID in the event conflicts with the TaskManager's ID when the TaskManager's ID is already set. """ if isinstance(event, Task): @@ -142,7 +141,7 @@ async def process(self, event: Event) -> Event: return event async def _save_task(self, task: Task) -> None: - """Saves the given task to the task store and updates the in-memory `_current_task`. + """Saves the given task to the `_current_task` and updated `_task_id` and `_context_id` Args: task: The `Task` object to save. @@ -155,7 +154,7 @@ async def _save_task(self, task: Task) -> None: self._context_id = task.contextId def update_with_message(self, message: Message, task: Task) -> Task: - """Updates a task object in memory by adding a new message to its history. + """Updates a task object adding a new message to its history. If the task has a message in its current status, that message is moved to the history first. diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/grpc_client.py index 4dabdc71..a16cf42e 100644 --- a/src/a2a/client/grpc_client.py +++ b/src/a2a/client/grpc_client.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) -#@trace_class(kind=SpanKind.CLIENT) +@trace_class(kind=SpanKind.CLIENT) class GrpcTransportClient: """Transport specific details for interacting with an A2A agent via gRPC.""" @@ -248,7 +248,6 @@ async def get_card( return card -#@trace_class(kind=SpanKind.CLIENT) class GrpcClient(Client): """GrpcClient provides the Client interface for the gRPC transport.""" @@ -276,30 +275,35 @@ async def send_message( *, context: ClientCallContext | None = None, ) -> AsyncIterator[ClientEvent | Message]: - # TODO: Set the request params from config + config = MessageSendConfiguration( + accepted_output_modes=self._config.accepted_output_modes, + blocking=not self._config.polling, + push_notification_config=( + self._config.push_notification_configs[0] + if self._config.push_notification_configs + else None + ), + ) if not self._config.streaming or not self._card.capabilities.streaming: - print("Using blocking interaction") response = await self._transport_client.send_message( MessageSendParams( message=request, - # TODO: set params + configuration=config, ), context=context, ) result = ( (response, None) if isinstance(response, Task) else response ) - # Spin off consumers - in thread, out of thread, etc? await self.consume(result, self._card) yield result return # Get Task tracker - print("Using streaming interactions") tracker = ClientTaskManager() async for event in self._transport_client.send_message_streaming( MessageSendParams( message=request, - # TODO: set params + configuration=config, ), context=context, ): @@ -399,4 +403,5 @@ def NewGrpcClient( consumers: list[Consumer], middleware: list[ClientCallInterceptor] ) -> Client: + """Generator for the `GrpcClient` implementation.""" return GrpcClient(card, config, consumers, middleware) diff --git a/src/a2a/client/jsonrpc_client.py b/src/a2a/client/jsonrpc_client.py index 64035650..e791eac3 100644 --- a/src/a2a/client/jsonrpc_client.py +++ b/src/a2a/client/jsonrpc_client.py @@ -716,4 +716,5 @@ def NewJsonRpcClient( consumers: list[Consumer], middleware: list[ClientCallInterceptor] ) -> Client: + """Generator for the `JsonRpcClient` implementation.""" return JsonRpcClient(card, config, consumers, middleware) diff --git a/src/a2a/client/rest_client.py b/src/a2a/client/rest_client.py index 51c29383..339eeb95 100644 --- a/src/a2a/client/rest_client.py +++ b/src/a2a/client/rest_client.py @@ -94,41 +94,6 @@ async def _apply_interceptors( # TODO: Implement interceptors for other transports return final_request_payload, final_http_kwargs - @staticmethod - async def get_client_from_agent_card_url( - httpx_client: httpx.AsyncClient, - base_url: str, - agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, - http_kwargs: dict[str, Any] | None = None, - ) -> 'A2AClient': - """[deprecated] Fetches the public AgentCard and initializes an A2A client. - - This method will always fetch the public agent card. If an authenticated - or extended agent card is required, the A2ACardResolver should be used - directly to fetch the specific card, and then the A2AClient should be - instantiated with it. - - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - base_url: The base URL of the agent's host. - agent_card_path: The path to the agent card endpoint, relative to the base URL. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.get request when fetching the agent card. - - Returns: - An initialized `A2AClient` instance. - - Raises: - A2AClientHTTPError: If an HTTP error occurs fetching the agent card. - A2AClientJSONError: If the agent card response is invalid. - """ - agent_card: AgentCard = await A2ACardResolver( - httpx_client, base_url=base_url, agent_card_path=agent_card_path - ).get_agent_card( - http_kwargs=http_kwargs - ) # Fetches public card by default - return A2AClient(httpx_client=httpx_client, agent_card=agent_card) - async def send_message( self, request: MessageSendParams, @@ -759,4 +724,5 @@ def NewRestfulClient( consumers: list[Consumer], middleware: list[ClientCallInterceptor] ) -> Client: + """Generator for the `RestClient` implementation.""" return RestClient(card, config, consumers, middleware) diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index ed2579bd..28922a33 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -69,10 +69,6 @@ def build( app = FastAPI(**kwargs) router = APIRouter() for route, callback in self._handler.routes().items(): -<<<<<<< Updated upstream -======= - print("Attaching ", route) ->>>>>>> Stashed changes router.add_api_route( f'{rpc_url}{route[0]}', callback, diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index afc82038..810f3813 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -193,53 +193,6 @@ async def handle_authenticated_agent_card(self, request: Request) -> JSONRespons def routes(self) -> dict[Tuple[str, str], Callable[[Request],Any]]: routes = { -<<<<<<< Updated upstream - ('/v1/message:send', 'POST'): ( - functools.partial( - self._handle_request, - self.handler.on_message_send), - ), - ('/v1/message:stream', 'POST'): ( - functools.partial( - self._handle_streaming_request, - self.handler.on_message_send_stream), - ), - ('/v1/tasks/{id}:subscribe', 'POST'): ( - functools.partial( - self._handle_streaming_request, - self.handler.on_resubscribe_to_task), - ), - ('/v1/tasks/{id}', 'GET'): ( - functools.partial( - self._handle_request, - self.handler.on_get_task), - ), - ('/v1/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): ( - functools.partial( - self._handle_request, - self.handler.get_push_notification), - ), - ('/v1/tasks/{id}/pushNotificationConfigs', 'POST'): ( - functools.partial( - self._handle_request, - self.handler.set_push_notification), - ), - ('/v1/tasks/{id}/pushNotificationConfigs', 'GET'): ( - functools.partial( - self._handle_request, - self.handler.list_push_notifications), - ), - ('/v1/tasks', 'GET'): ( - functools.partial( - self._handle_request, - self.handler.list_tasks), - ), - } - if self.agent_card.supportsAuthenticatedExtendedCard: - routes['/v1/card'] = ( - self.handle_authenticated_agent_card, - 'GET') -======= ('/v1/message:send', 'POST'): functools.partial( self._handle_request, self.handler.on_message_send @@ -278,5 +231,4 @@ def routes(self) -> dict[Tuple[str, str], Callable[[Request],Any]]: } if self.agent_card.supportsAuthenticatedExtendedCard: routes[('/v1/card', 'GET')] = self.handle_authenticated_agent_card ->>>>>>> Stashed changes return routes diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 71a1ba80..677148f8 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -117,10 +117,6 @@ async def on_message_send_stream( """ try: body = await request.body() -<<<<<<< Updated upstream -======= - print('Request body', body) ->>>>>>> Stashed changes params = a2a_pb2.SendMessageRequest() Parse(body, params) # Transform the proto object to the python internal objects @@ -223,7 +219,7 @@ async def get_push_notification( if push_id: params = GetTaskPushNotificationConfigParams(id=task_id, push_id=push_id) else: - params = TaskIdParams['id'] + params = TaskIdParams(id=task_id) config = await self.request_handler.on_get_task_push_notification_config( params, context ) @@ -265,10 +261,10 @@ async def set_push_notification( body = await request.body() params = a2a_pb2.TaskPushNotificationConfig() Parse(body, params) - params = TaskPushNotificationConfig.validate_model(body) + params = TaskPushNotificationConfig.model_validate(body) a2a_request = proto_utils.FromProto.task_push_notification_config( params, - ), + ) config = await self.request_handler.on_set_task_push_notification_config( a2a_request, context ) @@ -299,10 +295,10 @@ async def on_get_task( """ try: task_id = request.path_params['id'] - historyLength = None - if 'historyLength' in request.query_params: - historyLength = request.query_params['historyLength'] - params = TaskQueryParams(id=task_id, historyLength=historyLength) + history_length = request.query_params.get('historyLength', None) + if historyLength: + history_length = int(history_length) + params = TaskQueryParams(id=task_id, history_length=history_length) task = await self.request_handler.on_get_task(params, context) if task: return MessageToJson(proto_utils.ToProto.task(task)) From 069fff6ecd3790e49d289459b27aed4ec8888d35 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 24 Jul 2025 14:36:17 +0000 Subject: [PATCH 07/17] More updates from gemini review. --- src/a2a/client/client_factory.py | 2 +- src/a2a/client/client_task_manager.py | 8 +++----- src/a2a/client/errors.py | 13 +++++++++++++ 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 9f85c2b4..5ec12f17 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -123,7 +123,7 @@ def create( raise ValueError('no compatible transports found.') if transport not in self._registry: raise ValueError(f'no client available for {transport}') - all_consumers = self._consumers + all_consumers = self._consumers.copy() if consumers: all_consumers.extend(consumers) return self._registry[transport]( diff --git a/src/a2a/client/client_task_manager.py b/src/a2a/client/client_task_manager.py index 99206cbb..305c4e01 100644 --- a/src/a2a/client/client_task_manager.py +++ b/src/a2a/client/client_task_manager.py @@ -11,7 +11,7 @@ TaskStatusUpdateEvent, ) from a2a.utils import append_artifact_to_task -from a2a.utils.errors import ServerError +from a2a.client.errors import A2AClientInvalidArgsError logger = logging.getLogger(__name__) @@ -73,10 +73,8 @@ async def save_task_event( """ if isinstance(event, Task): if self._current_task: - raise ClientError( - error=InvalidParamsError( - message="Task is already set, create new manager for new tasks." - ) + raise A2AClientInvalidArgsError( + "Task is already set, create new manager for new tasks." ) await self._save_task(event) return event diff --git a/src/a2a/client/errors.py b/src/a2a/client/errors.py index 5fe5512a..1d49eb55 100644 --- a/src/a2a/client/errors.py +++ b/src/a2a/client/errors.py @@ -44,3 +44,16 @@ def __init__(self, message: str): """ self.message = message super().__init__(f'Timeout Error: {message}') + + +class A2AClientInvalidArgsError(A2AClientError): + """Client exception for timeout errors during a request.""" + + def __init__(self, message: str): + """Initializes the A2AClientInvalidArgsError. + + Args: + message: A descriptive error message. + """ + self.message = message + super().__init__(f'Invalid arguments error: {message}') From 38b516a5b83dc6b98b6ba6dd3836e4444735ca0b Mon Sep 17 00:00:00 2001 From: pstephengoogle Date: Thu, 24 Jul 2025 08:39:44 -0600 Subject: [PATCH 08/17] Update src/a2a/client/jsonrpc_client.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/a2a/client/jsonrpc_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/a2a/client/jsonrpc_client.py b/src/a2a/client/jsonrpc_client.py index e791eac3..cef541be 100644 --- a/src/a2a/client/jsonrpc_client.py +++ b/src/a2a/client/jsonrpc_client.py @@ -504,12 +504,12 @@ async def get_card( A2AClientJSONError: If the response body cannot be decoded as JSON or validated. """ # If we don't have the public card, try to get that first. - card = self.card + card = self.agent_card if not card: resolver = A2ACardResolver(self.httpx_client, self.url) - card = resolver.get_agent_card(http_kwargs=http_kwargs) + card = await resolver.get_agent_card(http_kwargs=http_kwargs) self._needs_extended_card = card.supportsAuthenticatedExtendedCard - self.card = card + self.agent_card = card if not self._needs_extended_card: return card From 8b91b3440d7055501ad7da22719fe453c1b21c13 Mon Sep 17 00:00:00 2001 From: pstephengoogle Date: Thu, 24 Jul 2025 08:40:12 -0600 Subject: [PATCH 09/17] Update src/a2a/client/client.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/a2a/client/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index e058a799..47d20994 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -207,7 +207,6 @@ async def send_message( pairs, or a `Message`. Client will also send these values to any configured `Consumer`s in the client. """ - pass yield @abstractmethod From 0d220a3e4cbfc7f418c2a8e984314e23ec45ed03 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 24 Jul 2025 14:42:15 +0000 Subject: [PATCH 10/17] Fix merge conflict --- src/a2a/server/apps/rest/rest_app.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index 810f3813..2d3460f4 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -139,18 +139,11 @@ async def event_generator( ) -> AsyncGenerator[dict[str, str]]: async for item in stream: yield {'data': item} -<<<<<<< Updated upstream - return EventSourceResponse(event_generator(method(request, call_context))) - except Exception as e: - # Since the stream has started, we can't return a JSONResponse. - # Instead, we runt the error handling logic (provides logging) -======= return EventSourceResponse( event_generator(method(request, call_context))) except Exception as e: # Since the stream has started, we can't return a JSONResponse. # Instead, we run the error handling logic (provides logging) ->>>>>>> Stashed changes # and reraise the error and let server framework manage self._handle_error(e) raise e From eae6e21fc9607235923f9ba237aa3f001a83d4ab Mon Sep 17 00:00:00 2001 From: pstephengoogle Date: Thu, 24 Jul 2025 08:46:23 -0600 Subject: [PATCH 11/17] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/a2a/client/jsonrpc_client.py | 4 ++-- src/a2a/client/rest_client.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/a2a/client/jsonrpc_client.py b/src/a2a/client/jsonrpc_client.py index cef541be..ddec1b37 100644 --- a/src/a2a/client/jsonrpc_client.py +++ b/src/a2a/client/jsonrpc_client.py @@ -640,7 +640,7 @@ async def cancel_task( response = await self._transport_client.cancel_task( CancelTaskRequest( params=request, - id=srt(uuid4()), + id=str(uuid4()), ), http_kwargs=self.get_http_args(context), context=context, @@ -691,7 +691,7 @@ async def resubscribe( ) async for event in self._transport_client.resubscribe( TaskResubscriptionRequest( - params=TaskIdParams, + params=request, id=str(uuid4()), ), http_kwargs=self.get_http_args(context), diff --git a/src/a2a/client/rest_client.py b/src/a2a/client/rest_client.py index 339eeb95..1aa0c40d 100644 --- a/src/a2a/client/rest_client.py +++ b/src/a2a/client/rest_client.py @@ -169,7 +169,7 @@ async def send_message_streaming( """ pb = a2a_pb2.SendMessageRequest( request=proto_utils.ToProto.message(request.message), - configuration=proto_utils.ToProto.message_send_configuration( + configuration=proto_utils.ToProto.send_message_config( request.configuration ), metadata=( @@ -359,7 +359,7 @@ async def cancel_task( context, ) response_data = await self._send_post_request( - f'/v1/tasks/{request.taskId}:cancel', + f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs ) @@ -445,7 +445,7 @@ async def get_task_callback( context, ) response_data = await self._send_get_request( - f'/v1/tasks/{request.taskId}/pushNotificationConfigs/{request.pushNotificationId}', + f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', {}, modified_kwargs ) @@ -536,12 +536,12 @@ async def get_card( A2AClientJSONError: If the response body cannot be decoded as JSON or validated. """ # If we don't have the public card, try to get that first. - card = self.card + card = self.agent_card if not card: resolver = A2ACardResolver(self.httpx_client, self.url) - card = resolver.get_agent_card(http_kwargs=http_kwargs) + card = await resolver.get_agent_card(http_kwargs=http_kwargs) self._needs_extended_card = card.supportsAuthenticatedExtendedCard - self.card = card + self.agent_card = card if not self._needs_extended_card: return card @@ -627,7 +627,7 @@ async def send_message( ): # Update task, check for errors, etc. if isinstance(event, Message): - yield result + yield event return await tracker.process(event) result = ( @@ -688,7 +688,7 @@ async def get_task_callback( http_kwargs=self.get_http_args(context), context=context, ) - return response.result + return response async def resubscribe( self, @@ -701,7 +701,7 @@ async def resubscribe( 'client and/or server do not support resubscription.' ) async for event in self._transport_client.resubscribe( - TaskIdParams, + request, http_kwargs=self.get_http_args(context), context=context, ): From e19feb96d68dc739f00a7170ba4a21518f41fdc7 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 24 Jul 2025 14:51:44 +0000 Subject: [PATCH 12/17] More fixes from Gemini --- src/a2a/client/jsonrpc_client.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/a2a/client/jsonrpc_client.py b/src/a2a/client/jsonrpc_client.py index ddec1b37..668b26e7 100644 --- a/src/a2a/client/jsonrpc_client.py +++ b/src/a2a/client/jsonrpc_client.py @@ -508,15 +508,12 @@ async def get_card( if not card: resolver = A2ACardResolver(self.httpx_client, self.url) card = await resolver.get_agent_card(http_kwargs=http_kwargs) - self._needs_extended_card = card.supportsAuthenticatedExtendedCard + self._needs_extended_card = card.supports_authenticated_extended_card self.agent_card = card if not self._needs_extended_card: return card - if not request.id: - request.id = str(uuid4()) - # Apply interceptors before sending payload, modified_kwargs = await self._apply_interceptors( 'card/getAuthenticated', From fcd98282c387d44478ff003f69b983da3723eb5c Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Thu, 24 Jul 2025 16:46:49 +0100 Subject: [PATCH 13/17] Spelling --- .github/actions/spelling/allow.txt | 18 +----------------- src/a2a/client/jsonrpc_client.py | 26 ++++++++++++++------------ 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 043ec104..39370e0f 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -17,23 +17,6 @@ AServers AService AStarlette AUser -DSNs -EUR -GBP -GVsb -INR -JPY -JSONRPCt -Llm -POSTGRES -RUF -Tful -aconnect -adk -agentic -aio -aiomysql -aproject autouse backticks cla @@ -82,6 +65,7 @@ pyi pypistats pyupgrade pyversions +redef respx resub RUF diff --git a/src/a2a/client/jsonrpc_client.py b/src/a2a/client/jsonrpc_client.py index 668b26e7..f40693ac 100644 --- a/src/a2a/client/jsonrpc_client.py +++ b/src/a2a/client/jsonrpc_client.py @@ -8,12 +8,11 @@ import httpx from httpx_sse import SSEError, aconnect_sse -from pydantic import ValidationError -from a2a.client.client import Client, ClientConfig, A2ACardResolver, Consumer +from a2a.client.client import A2ACardResolver, Client, ClientConfig, Consumer +from a2a.client.client_task_manager import ClientTaskManager from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError from a2a.client.middleware import ClientCallContext, ClientCallInterceptor -from a2a.client.client_task_manager import ClientTaskManager from a2a.types import ( AgentCard, CancelTaskRequest, @@ -34,8 +33,8 @@ SetTaskPushNotificationConfigResponse, Task, TaskIdParams, - TaskQueryParams, TaskPushNotificationConfig, + TaskQueryParams, TaskResubscriptionRequest, ) from a2a.utils.constants import ( @@ -87,7 +86,9 @@ def __init__( # card. self._needs_extended_card = ( not agent_card.supportsAuthenticatedExtendedCard - if agent_card else True) + if agent_card + else True + ) async def _apply_interceptors( self, @@ -379,7 +380,7 @@ async def set_task_callback( 'tasks/pushNotificationConfig/set', request.model_dump(mode='json', exclude_none=True), http_kwargs, - context + context, ) response_data = await self._send_request(payload, modified_kwargs) return SetTaskPushNotificationConfigResponse.model_validate( @@ -448,7 +449,6 @@ async def resubscribe( A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. """ - # Apply interceptors before sending payload, modified_kwargs = await self._apply_interceptors( 'tasks/resubscribe', @@ -508,7 +508,9 @@ async def get_card( if not card: resolver = A2ACardResolver(self.httpx_client, self.url) card = await resolver.get_agent_card(http_kwargs=http_kwargs) - self._needs_extended_card = card.supports_authenticated_extended_card + self._needs_extended_card = ( + card.supports_authenticated_extended_card + ) self.agent_card = card if not self._needs_extended_card: @@ -533,7 +535,7 @@ class JsonRpcClient(Client): """JsonRpcClient is the implementation of the JSONRPC A2A client. This client proxies requests to the JsonRpcTransportClient implementation - and manages the JSONRPC specific details. If passing additional arguements + and manages the JSONRPC specific details. If passing additional arguments in the http.post command, these should be attached to the ClientCallContext under the dictionary key 'http_kwargs'. """ @@ -606,8 +608,7 @@ async def send_message( await tracker.process(result) result = ( tracker.get_task(), - None if isinstance(result, Task) - else result + None if isinstance(result, Task) else result, ) await self.consume(result, self._card) yield result @@ -707,11 +708,12 @@ async def get_card( context=context, ) + def NewJsonRpcClient( card: AgentCard, config: ClientConfig, consumers: list[Consumer], - middleware: list[ClientCallInterceptor] + middleware: list[ClientCallInterceptor], ) -> Client: """Generator for the `JsonRpcClient` implementation.""" return JsonRpcClient(card, config, consumers, middleware) From 1ae30ee29f63eba3de938e634ed1508f45d22cf6 Mon Sep 17 00:00:00 2001 From: Holt Skinner Date: Thu, 24 Jul 2025 16:49:35 +0100 Subject: [PATCH 14/17] Formatting --- src/a2a/client/__init__.py | 69 ++++++------ src/a2a/client/client.py | 32 +++--- src/a2a/client/client_factory.py | 35 ++---- src/a2a/client/client_task_manager.py | 9 +- src/a2a/client/grpc_client.py | 13 ++- src/a2a/client/rest_client.py | 84 +++++++-------- src/a2a/server/apps/rest/__init__.py | 1 + src/a2a/server/apps/rest/fastapi_app.py | 6 +- src/a2a/server/apps/rest/rest_app.py | 101 +++++++++--------- src/a2a/server/request_handlers/__init__.py | 3 +- .../server/request_handlers/rest_handler.py | 67 ++++++------ src/a2a/utils/__init__.py | 3 +- src/a2a/utils/proto_utils.py | 18 ++-- src/a2a/utils/transports.py | 8 +- 14 files changed, 211 insertions(+), 238 deletions(-) diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 0e0eaba7..40a326a4 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -7,41 +7,42 @@ CredentialService, InMemoryContextCredentialStore, ) +from a2a.client.client import ( + A2ACardResolver, + Client, + ClientConfig, + ClientEvent, + Consumer, +) +from a2a.client.client_factory import ( + ClientFactory, + ClientProducer, + minimal_agent_card, +) from a2a.client.errors import ( A2AClientError, A2AClientHTTPError, A2AClientJSONError, A2AClientTimeoutError, ) +from a2a.client.grpc_client import ( + GrpcClient, + GrpcTransportClient, + NewGrpcClient, +) +from a2a.client.helpers import create_text_message_object from a2a.client.jsonrpc_client import ( JsonRpcClient, JsonRpcTransportClient, NewJsonRpcClient, ) -from a2a.client.grpc_client import ( - GrpcTransportClient, - GrpcClient, - NewGrpcClient, -) +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.rest_client import ( - RestTransportClient, - RestClient, NewRestfulClient, + RestClient, + RestTransportClient, ) -from a2a.client.helpers import create_text_message_object -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor -from a2a.client.client import ( - A2ACardResolver, - Client, - ClientConfig, - Consumer, - ClientEvent, -) -from a2a.client.client_factory import ( - ClientFactory, - ClientProducer, - minimal_agent_card -) + # For backward compatability define this alias. This will be deprecated in # a future release. @@ -71,32 +72,32 @@ def __init__(self, *args, **kwargs): __all__ = [ 'A2ACardResolver', + 'A2AClient', # for backward compatability 'A2AClientError', 'A2AClientHTTPError', 'A2AClientJSONError', 'A2AClientTimeoutError', + 'A2AGrpcClient', # for backward compatability 'AuthInterceptor', + 'Client', 'ClientCallContext', 'ClientCallInterceptor', - 'Consumer', - 'CredentialService', - 'InMemoryContextCredentialStore', - 'create_text_message_object', - 'A2AClient', # for backward compatability - 'A2AGrpcClient', # for backward compatability - 'Client', + 'ClientConfig', 'ClientEvent', 'ClientFactory', - 'ClientConfig', 'ClientProducer', - 'GrpcTransportClient', + 'Consumer', + 'CredentialService', 'GrpcClient', - 'NewGrpcClient', + 'GrpcTransportClient', + 'InMemoryContextCredentialStore', 'JsonRpcClient', 'JsonRpcTransportClient', + 'NewGrpcClient', 'NewJsonRpcClient', - 'minimal_agent_card', - 'RestTransportClient', - 'RestClient', 'NewRestfulClient', + 'RestClient', + 'RestTransportClient', + 'create_text_message_object', + 'minimal_agent_card', ] diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 47d20994..450e42ab 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -3,15 +3,14 @@ import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterator -from typing import Any, Callable, Coroutine -from uuid import uuid4 +from collections.abc import AsyncIterator, Callable, Coroutine +from typing import Any import httpx -from httpx_sse import SSEError, aconnect_sse from pydantic import ValidationError + # Attempt to import the optional module try: from grpc.aio import Channel @@ -19,15 +18,15 @@ # If grpc.aio is not available, define a dummy type for type checking. # This dummy type will only be used by type checkers. if TYPE_CHECKING: + class Channel: # type: ignore[no-redef] pass else: - Channel = None # At runtime, pd will be None if the import failed. + Channel = None # At runtime, pd will be None if the import failed. from a2a.client.errors import ( A2AClientHTTPError, A2AClientJSONError, - A2AClientTimeoutError, ) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.types import ( @@ -36,14 +35,13 @@ class Channel: # type: ignore[no-redef] Message, PushNotificationConfig, Task, + TaskArtifactUpdateEvent, TaskIdParams, - TaskQueryParams, TaskPushNotificationConfig, + TaskQueryParams, TaskStatusUpdateEvent, - TaskArtifactUpdateEvent, ) from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH -from a2a.utils.telemetry import SpanKind, trace_class logger = logging.getLogger(__name__) @@ -136,6 +134,7 @@ async def get_agent_card( return agent_card + @dataclasses.dataclass class ClientConfig: """Configuration class for the A2A Client Factory""" @@ -154,7 +153,7 @@ class ClientConfig: grpc_channel_factory: Callable[[str], Channel] | None = None """Generates a grpc connection channel for a given url.""" - supported_transports: list[str] = dataclasses.field(default_factory=list) + supported_transports: list[str] = dataclasses.field(default_factory=list) """Ordered list of transports for connecting to agent (in order of preference). Empty implies JSONRPC only. @@ -166,12 +165,15 @@ class ClientConfig: """Whether to use client transport preferences over server preferences. Recommended to use server preferences in most situations.""" - accepted_outputModes: list[str] = dataclasses.field(default_factory=list) + accepted_outputModes: list[str] = dataclasses.field(default_factory=list) """The set of accepted output modes for the client.""" - push_notification_configs: list[PushNotificationConfig] = dataclasses.field(default_factory=list) + push_notification_configs: list[PushNotificationConfig] = dataclasses.field( + default_factory=list + ) """Push notification callbacks to use for every request.""" + UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None # Alias for emitted events from client ClientEvent = tuple[Task, UpdateEvent] @@ -183,7 +185,6 @@ class ClientConfig: class Client(ABC): - def __init__( self, consumers: list[Consumer] = [], @@ -252,14 +253,11 @@ async def resubscribe( *, context: ClientCallContext | None = None, ) -> AsyncIterator[Task | Message]: - pass yield @abstractmethod async def get_card( - self, - *, - context: ClientCallContext | None = None + self, *, context: ClientCallContext | None = None ) -> AgentCard: pass diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 5ec12f17..2dd37546 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -1,33 +1,20 @@ from __future__ import annotations -import json -import logging - -from collections.abc import AsyncGenerator -from typing import Any, TYPE_CHECKING, Callable - -import httpx -from httpx_sse import SSEError, aconnect_sse -from pydantic import ValidationError +import logging -from a2a.utils import Transports +from collections.abc import Callable from a2a.client.client import Client, ClientConfig, Consumer -from a2a.client.jsonrpc_client import NewJsonRpcClient from a2a.client.grpc_client import NewGrpcClient +from a2a.client.jsonrpc_client import NewJsonRpcClient +from a2a.client.middleware import ClientCallInterceptor from a2a.client.rest_client import NewRestfulClient -from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError -from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.types import ( AgentCapabilities, AgentCard, - Message, - Task, - TaskIdParams, - TaskQueryParams, - GetTaskPushNotificationConfigParams, - TaskPushNotificationConfig, ) +from a2a.utils import Transports + logger = logging.getLogger(__name__) @@ -36,11 +23,12 @@ AgentCard | str, ClientConfig, list[Consumer], - list[ClientCallInterceptor] + list[ClientCallInterceptor], ], - Client + Client, ] + class ClientFactory: """ClientFactory is used to generate the appropriate client for the agent. @@ -103,9 +91,7 @@ def create( # Determine preferential transport server_set = [card.preferred_transport or 'JSONRPC'] if card.additional_interfaces: - server_set.extend( - [x.transport for x in card.additional_interfaces] - ) + server_set.extend([x.transport for x in card.additional_interfaces]) client_set = self._config.supported_transports or ['JSONRPC'] transport = None # Two options, use the client ordering or the server ordering. @@ -130,6 +116,7 @@ def create( card, self._config, all_consumers, interceptors ) + def minimal_agent_card(url: str, transports: list[str] = []) -> AgentCard: """Generates a minimal card to simplify bootstrapping client creation. diff --git a/src/a2a/client/client_task_manager.py b/src/a2a/client/client_task_manager.py index 305c4e01..acccf313 100644 --- a/src/a2a/client/client_task_manager.py +++ b/src/a2a/client/client_task_manager.py @@ -1,8 +1,8 @@ import logging +from a2a.client.errors import A2AClientInvalidArgsError from a2a.server.events.event_queue import Event from a2a.types import ( - InvalidParamsError, Message, Task, TaskArtifactUpdateEvent, @@ -11,7 +11,6 @@ TaskStatusUpdateEvent, ) from a2a.utils import append_artifact_to_task -from a2a.client.errors import A2AClientInvalidArgsError logger = logging.getLogger(__name__) @@ -74,7 +73,7 @@ async def save_task_event( if isinstance(event, Task): if self._current_task: raise A2AClientInvalidArgsError( - "Task is already set, create new manager for new tasks." + 'Task is already set, create new manager for new tasks.' ) await self._save_task(event) return event @@ -101,7 +100,9 @@ async def save_task_event( ) if isinstance(event, TaskStatusUpdateEvent): logger.debug( - 'Updating task %s status to: %s', event.taskId, event.status.state + 'Updating task %s status to: %s', + event.taskId, + event.status.state, ) if event.status.message: if not task.history: diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/grpc_client.py index 84393f97..5a002a2d 100644 --- a/src/a2a/client/grpc_client.py +++ b/src/a2a/client/grpc_client.py @@ -17,11 +17,11 @@ Client, ClientCallContext, ClientConfig, - Consumer, ClientEvent, + Consumer, ) -from a2a.client.middleware import ClientCallInterceptor from a2a.client.client_task_manager import ClientTaskManager +from a2a.client.middleware import ClientCallInterceptor from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCard, @@ -64,8 +64,7 @@ def __init__( # If they don't provide an agent card, but do have a stub, lookup the # card from the stub. self._needs_extended_card = ( - agent_card.supportsAuthenticatedExtendedCard - if agent_card else True + agent_card.supportsAuthenticatedExtendedCard if agent_card else True ) async def send_message( @@ -251,7 +250,7 @@ async def get_card( card_pb = await self.stub.GetAgentCard( a2a_pb2.GetAgentCardRequest(), ) - card = proto_utils.FromProto.agent_card(card_pb) + card = proto_utils.FromProto.agent_card(card_pb) self.agent_card = card self._needs_extended_card = False return card @@ -324,7 +323,7 @@ async def send_message( await tracker.process(event) result = ( tracker.get_task(), - None if isinstance(event, Task) else event + None if isinstance(event, Task) else event, ) await self.consume(result, self._card) yield result @@ -410,7 +409,7 @@ def NewGrpcClient( card: AgentCard, config: ClientConfig, consumers: list[Consumer], - middleware: list[ClientCallInterceptor] + middleware: list[ClientCallInterceptor], ) -> Client: """Generator for the `GrpcClient` implementation.""" return GrpcClient(card, config, consumers, middleware) diff --git a/src/a2a/client/rest_client.py b/src/a2a/client/rest_client.py index 1aa0c40d..140a8b01 100644 --- a/src/a2a/client/rest_client.py +++ b/src/a2a/client/rest_client.py @@ -3,17 +3,17 @@ from collections.abc import AsyncGenerator, AsyncIterator from typing import Any -from uuid import uuid4 import httpx +from google.protobuf.json_format import MessageToDict, Parse from httpx_sse import SSEError, aconnect_sse -from pydantic import ValidationError -from a2a.client.client import Client, ClientConfig, A2ACardResolver, Consumer +from a2a.client.client import A2ACardResolver, Client, ClientConfig, Consumer +from a2a.client.client_task_manager import ClientTaskManager from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError from a2a.client.middleware import ClientCallContext, ClientCallInterceptor -from a2a.client.client_task_manager import ClientTaskManager +from a2a.grpc import a2a_pb2 from a2a.types import ( AgentCard, GetTaskPushNotificationConfigParams, @@ -22,17 +22,12 @@ Task, TaskArtifactUpdateEvent, TaskIdParams, - TaskQueryParams, TaskPushNotificationConfig, + TaskQueryParams, TaskStatusUpdateEvent, ) -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, -) -from a2a.utils.telemetry import SpanKind, trace_class -from a2a.grpc import a2a_pb2 from a2a.utils import proto_utils -from google.protobuf.json_format import Parse, MessageToDict +from a2a.utils.telemetry import SpanKind, trace_class logger = logging.getLogger(__name__) @@ -69,7 +64,7 @@ def __init__( else: raise ValueError('Must provide either agent_card or url') # If the url ends in / remove it as this is added by the routes - if self.url.endswith("/"): + if self.url.endswith('/'): self.url = self.url[:-1] self.httpx_client = httpx_client self.agent_card = agent_card @@ -80,7 +75,9 @@ def __init__( # card. self._needs_extended_card = ( not agent_card.supportsAuthenticatedExtendedCard - if agent_card else True) + if agent_card + else True + ) async def _apply_interceptors( self, @@ -123,7 +120,8 @@ async def send_message( ), metadata=( proto_utils.ToProto.metadata(request.metadata) - if request.metadata else None + if request.metadata + else None ), ) payload = MessageToDict(pb) @@ -134,9 +132,7 @@ async def send_message( context, ) response_data = await self._send_post_request( - '/v1/message:send', - payload, - modified_kwargs + '/v1/message:send', payload, modified_kwargs ) response_pb = a2a_pb2.SendMessageResponse() Parse(response_data, response_pb) @@ -148,7 +144,9 @@ async def send_message_streaming( *, http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> AsyncGenerator[Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message]: + ) -> AsyncGenerator[ + Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message + ]: """Sends a streaming message request to the agent and yields responses as they arrive. This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. @@ -174,7 +172,8 @@ async def send_message_streaming( ), metadata=( proto_utils.ToProto.metadata(request.metadata) - if request.metadata else None + if request.metadata + else None ), ) payload = MessageToDict(pb) @@ -236,7 +235,7 @@ async def _send_post_request( response = await self.httpx_client.post( f'{self.url}{target}', json=rpc_request_payload, - **(http_kwargs or {}) + **(http_kwargs or {}), ) response.raise_for_status() return response.json() @@ -274,7 +273,7 @@ async def _send_get_request( response = await self.httpx_client.get( f'{self.url}{target}', params=query_params, - **(http_kwargs or {}) + **(http_kwargs or {}), ) response.raise_for_status() return response.json() @@ -317,10 +316,10 @@ async def get_task( ) response_data = await self._send_get_request( f'/v1/tasks/{request.taskId}', - { - 'historyLength': request.historyLength - } if request.historyLength else {}, - modified_kwargs + {'historyLength': request.historyLength} + if request.historyLength + else {}, + modified_kwargs, ) task = a2a_pb2.Task() Parse(response_data, task) @@ -348,9 +347,7 @@ async def cancel_task( A2AClientHTTPError: If an HTTP error occurs during the request. A2AClientJSONError: If the response body cannot be decoded as JSON or validated. """ - pb = a2a_pb2.CancelTaskRequest( - name=f'tasks/{request.id}' - ) + pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') payload = MessageToDict(pb) # Apply interceptors before sending payload, modified_kwargs = await self._apply_interceptors( @@ -359,9 +356,7 @@ async def cancel_task( context, ) response_data = await self._send_post_request( - f'/v1/tasks/{request.id}:cancel', - payload, - modified_kwargs + f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs ) task = a2a_pb2.Task() Parse(response_data, task) @@ -399,14 +394,12 @@ async def set_task_callback( payload = MessageToDict(pb) # Apply interceptors before sending payload, modified_kwargs = await self._apply_interceptors( - payload, - http_kwargs, - context + payload, http_kwargs, context ) response_data = await self._send_post_request( f'/v1/tasks/{request.taskId}/pushNotificationConfigs/', payload, - modified_kwargs + modified_kwargs, ) config = a2a_pb2.TaskPushNotificationConfig() Parse(response_data, config) @@ -447,7 +440,7 @@ async def get_task_callback( response_data = await self._send_get_request( f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', {}, - modified_kwargs + modified_kwargs, ) config = a2a_pb2.TaskPushNotificationConfig() Parse(response_data, config) @@ -459,7 +452,9 @@ async def resubscribe( *, http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> AsyncGenerator[Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message]: + ) -> AsyncGenerator[ + Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message + ]: """Reconnects to get task updates This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. @@ -553,9 +548,8 @@ async def get_card( context, ) response_data = await self._send_get_request( - '/v1/card/get', - {}, - modified_kwargs) + '/v1/card/get', {}, modified_kwargs + ) card = AgentCard.model_validate(response_data) self.card = card self._needs_extended_card = False @@ -610,9 +604,7 @@ async def send_message( context=context, ) result = ( - response - if isinstance(response, Message) - else (response, None) + response if isinstance(response, Message) else (response, None) ) await self.consume(result, self._card) yield result @@ -632,8 +624,7 @@ async def send_message( await tracker.process(event) result = ( tracker.get_task(), - None if isinstance(event, Task) - else event + None if isinstance(event, Task) else event, ) await self.consume(result, self._card) yield result @@ -718,11 +709,12 @@ async def get_card( context=context, ) + def NewRestfulClient( card: AgentCard, config: ClientConfig, consumers: list[Consumer], - middleware: list[ClientCallInterceptor] + middleware: list[ClientCallInterceptor], ) -> Client: """Generator for the `RestClient` implementation.""" return RestClient(card, config, consumers, middleware) diff --git a/src/a2a/server/apps/rest/__init__.py b/src/a2a/server/apps/rest/__init__.py index 81ee7b7a..c57b0f38 100644 --- a/src/a2a/server/apps/rest/__init__.py +++ b/src/a2a/server/apps/rest/__init__.py @@ -3,6 +3,7 @@ from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication from a2a.server.apps.rest.rest_app import RESTApplication + __all__ = [ 'A2ARESTFastAPIApplication', 'RESTApplication', diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index e8622d75..76fa887f 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -2,7 +2,7 @@ from typing import Any -from fastapi import FastAPI, Request, Response, APIRouter +from fastapi import APIRouter, FastAPI, Request, Response from a2a.server.apps.jsonrpc.jsonrpc_app import ( CallContextBuilder, @@ -70,9 +70,7 @@ def build( router = APIRouter() for route, callback in self._handler.routes().items(): router.add_api_route( - f'{rpc_url}{route[0]}', - callback, - methods=[route[1]] + f'{rpc_url}{route[0]}', callback, methods=[route[1]] ) @router.get(f'{rpc_url}{agent_card_url}') diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index baeed93a..fa9076db 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -1,42 +1,33 @@ -import contextlib +import functools import json import logging import traceback -import functools -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, AsyncIterator, Awaitable -from typing import Any, Tuple, Callable -from fastapi import FastAPI -from pydantic import BaseModel, ValidationError +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable +from typing import Any +from pydantic import ValidationError from sse_starlette.sse import EventSourceResponse -from starlette.applications import Starlette -from starlette.authentication import BaseUser from starlette.requests import Request -from starlette.responses import JSONResponse, Response +from starlette.responses import JSONResponse -from a2a.auth.user import UnauthenticatedUser -from a2a.auth.user import User as A2AUser +from a2a.server.apps.jsonrpc import ( + CallContextBuilder, + DefaultCallContextBuilder, +) from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.rest_handler import ( RESTHandler, ) -from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import ( - A2AError, AgentCard, - JSONParseError, - UnsupportedOperationError, InternalError, InvalidRequestError, + JSONParseError, + UnsupportedOperationError, ) from a2a.utils.errors import MethodNotImplementedError -from a2a.server.apps.jsonrpc import ( - CallContextBuilder, - StarletteUserProxy, - DefaultCallContextBuilder -) logger = logging.getLogger(__name__) @@ -102,23 +93,21 @@ def _handle_error(self, error: Exception) -> JSONResponse: traceback.print_exc() if isinstance(error, MethodNotImplementedError): return self._generate_error_response(UnsupportedOperationError()) - elif isinstance(error, json.decoder.JSONDecodeError): + if isinstance(error, json.decoder.JSONDecodeError): return self._generate_error_response( JSONParseError(message=str(error)) ) - elif isinstance(error, ValidationError): + if isinstance(error, ValidationError): return self._generate_error_response( InvalidRequestError(data=json.loads(error.json())), ) logger.error(f'Unhandled exception: {error}') - return self._generate_error_response( - InternalError(message=str(error)) - ) + return self._generate_error_response(InternalError(message=str(error))) async def _handle_request( self, method: Callable[[Request, ServerCallContext], Awaitable[str]], - request: Request + request: Request, ) -> JSONResponse: try: call_context = self._context_builder.build(request) @@ -130,18 +119,21 @@ async def _handle_request( async def _handle_streaming_request( self, method: Callable[[Request, ServerCallContext], AsyncIterator[str]], - request: Request + request: Request, ) -> EventSourceResponse: try: call_context = self._context_builder.build(request) + async def event_generator( stream: AsyncGenerator[str], ) -> AsyncGenerator[dict[str, str]]: async for item in stream: yield {'data': item} + return EventSourceResponse( - event_generator(method(request, call_context))) - except Exception as e: + event_generator(method(request, call_context)) + ) + except Exception: # Since the stream has started, we can't return a JSONResponse. # Instead, we run the error handling logic (provides logging) return EventSourceResponse( @@ -169,7 +161,9 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse: self.agent_card.model_dump(mode='json', exclude_none=True) ) - async def handle_authenticated_agent_card(self, request: Request) -> JSONResponse: + async def handle_authenticated_agent_card( + self, request: Request + ) -> JSONResponse: """Hook for per credential agent card response. If a dynamic card is needed based on the credentials provided in the request @@ -183,48 +177,49 @@ async def handle_authenticated_agent_card(self, request: Request) -> JSONRespons """ if not self.agent_card.supportsAuthenticatedExtendedCard: return JSONResponse( - '{"detail": "Authenticated card not supported"}', status_code=404 + '{"detail": "Authenticated card not supported"}', + status_code=404, ) return JSONResponse( self.agent_card.model_dump(mode='json', exclude_none=True) ) - def routes(self) -> dict[Tuple[str, str], Callable[[Request],Any]]: + def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: routes = { ('/v1/message:send', 'POST'): functools.partial( - self._handle_request, - self.handler.on_message_send + self._handle_request, self.handler.on_message_send ), ('/v1/message:stream', 'POST'): functools.partial( self._handle_streaming_request, - self.handler.on_message_send_stream + self.handler.on_message_send_stream, ), ('/v1/tasks/{id}:subscribe', 'POST'): functools.partial( self._handle_streaming_request, - self.handler.on_resubscribe_to_task + self.handler.on_resubscribe_to_task, ), ('/v1/tasks/{id}', 'GET'): functools.partial( - self._handle_request, - self.handler.on_get_task + self._handle_request, self.handler.on_get_task ), - ('/v1/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): - functools.partial( - self._handle_request, - self.handler.get_push_notification + ( + '/v1/tasks/{id}/pushNotificationConfigs/{push_id}', + 'GET', + ): functools.partial( + self._handle_request, self.handler.get_push_notification ), - ('/v1/tasks/{id}/pushNotificationConfigs', 'POST'): - functools.partial( - self._handle_request, - self.handler.set_push_notification + ( + '/v1/tasks/{id}/pushNotificationConfigs', + 'POST', + ): functools.partial( + self._handle_request, self.handler.set_push_notification ), - ('/v1/tasks/{id}/pushNotificationConfigs', 'GET'): - functools.partial( - self._handle_request, - self.handler.list_push_notifications + ( + '/v1/tasks/{id}/pushNotificationConfigs', + 'GET', + ): functools.partial( + self._handle_request, self.handler.list_push_notifications ), ('/v1/tasks', 'GET'): functools.partial( - self._handle_request, - self.handler.list_tasks + self._handle_request, self.handler.list_tasks ), } if self.agent_card.supportsAuthenticatedExtendedCard: diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 82036851..54a1617f 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -7,7 +7,6 @@ ) from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.server.request_handlers.rest_handler import RESTHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, prepare_response_object, @@ -43,8 +42,8 @@ def __init__(self, *args, **kwargs): 'GrpcHandler', 'JSONRPCHandler', 'RESTHandler', - 'RequestHandler', 'RESTHandler', + 'RequestHandler', 'build_error_response', 'prepare_response_object', ] diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 17051a68..b226c710 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -1,33 +1,28 @@ import logging from collections.abc import AsyncIterable + +from google.protobuf.json_format import MessageToJson, Parse from starlette.requests import Request -from pydantic import BaseModel, Field, RootModel +from a2a.grpc import a2a_pb2 from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import ( A2AError, AgentCard, + GetTaskPushNotificationConfigParams, InternalError, - Message, Task, - TaskArtifactUpdateEvent, - TaskNotFoundError, - TaskPushNotificationConfig, - TaskStatusUpdateEvent, - GetTaskPushNotificationConfigParams, - MessageSendParams, TaskIdParams, + TaskNotFoundError, TaskPushNotificationConfig, TaskQueryParams, ) +from a2a.utils import proto_utils from a2a.utils.errors import ServerError from a2a.utils.helpers import validate from a2a.utils.telemetry import SpanKind, trace_class -from a2a.grpc import a2a_pb2 -from a2a.utils import proto_utils -from google.protobuf.json_format import Parse, MessageToJson logger = logging.getLogger(__name__) @@ -87,11 +82,11 @@ async def on_message_send( task_or_message = await self.request_handler.on_message_send( a2a_request, context ) - return MessageToJson(proto_utils.ToProto.task_or_message(task_or_message)) + return MessageToJson( + proto_utils.ToProto.task_or_message(task_or_message) + ) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) from e + raise A2AError(error=e.error if e.error else InternalError()) from e @validate( lambda self: self.agent_card.capabilities.streaming, @@ -130,9 +125,7 @@ async def on_message_send_stream( response = proto_utils.ToProto.stream_response(event) yield MessageToJson(response) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) from e + raise A2AError(error=e.error if e.error else InternalError()) from e return async def on_cancel_task( @@ -192,12 +185,11 @@ async def on_resubscribe_to_task( async for event in self.request_handler.on_resubscribe_to_task( TaskIdParams(id=task_id), context ): - yield(MessageToJson(proto_utils.ToProto.stream_response(event))) + yield ( + MessageToJson(proto_utils.ToProto.stream_response(event)) + ) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) from e - + raise A2AError(error=e.error if e.error else InternalError()) from e async def get_push_notification( self, @@ -218,19 +210,22 @@ async def get_push_notification( try: task_id = request.path_params['id'] push_id = request.path_params['push_id'] - if push_id: params = GetTaskPushNotificationConfigParams(id=task_id, push_id=push_id) + if push_id: + params = GetTaskPushNotificationConfigParams( + id=task_id, push_id=push_id + ) else: params = TaskIdParams(id=task_id) - config = await self.request_handler.on_get_task_push_notification_config( - params, context + config = ( + await self.request_handler.on_get_task_push_notification_config( + params, context + ) ) return MessageToJson( proto_utils.ToProto.task_push_notification_config(config) ) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) + raise A2AError(error=e.error if e.error else InternalError()) @validate( lambda self: self.agent_card.capabilities.pushNotifications, @@ -266,16 +261,16 @@ async def set_push_notification( a2a_request = proto_utils.FromProto.task_push_notification_config( params, ) - config = await self.request_handler.on_set_task_push_notification_config( - a2a_request, context + config = ( + await self.request_handler.on_set_task_push_notification_config( + a2a_request, context + ) ) return MessageToJson( proto_utils.ToProto.task_push_notification_config(config) ) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) from e + raise A2AError(error=e.error if e.error else InternalError()) from e async def on_get_task( self, @@ -305,9 +300,7 @@ async def on_get_task( return MessageToJson(proto_utils.ToProto.task(task)) raise ServerError(error=TaskNotFoundError()) except ServerError as e: - raise A2AError( - error=e.error if e.error else InternalError() - ) from e + raise A2AError(error=e.error if e.error else InternalError()) from e async def list_push_notifications( self, diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index e3d6fb6a..f47881a0 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -30,10 +30,12 @@ ) from a2a.utils.transports import Transports + __all__ = [ 'AGENT_CARD_WELL_KNOWN_PATH', 'DEFAULT_RPC_URL', 'EXTENDED_AGENT_CARD_PATH', + 'Transports', 'append_artifact_to_task', 'are_modalities_compatible', 'build_text_artifact', @@ -49,5 +51,4 @@ 'new_data_artifact', 'new_task', 'new_text_artifact', - 'Transports', ] diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 7012f221..933968c8 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -290,7 +290,9 @@ def agent_card( protocol_version=card.protocol_version, additional_interfaces=[ cls.agent_interface(x) for x in card.additional_interfaces - ] if card.additional_interfaces else None, + ] + if card.additional_interfaces + else None, ) @classmethod @@ -682,7 +684,9 @@ def agent_card( protocol_version=card.protocol_version, additional_interfaces=[ cls.agent_interface(x) for x in card.additional_interfaces - ] if card.additional_interfaces else None, + ] + if card.additional_interfaces + else None, ) @classmethod @@ -827,10 +831,12 @@ def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: def stream_response( cls, response: a2a_pb2.StreamResponse, - ) -> (types.Message - | types.Task - | types.TaskStatusUpdateEvent - | types.TaskArtifactUpdateEvent): + ) -> ( + types.Message + | types.Task + | types.TaskStatusUpdateEvent + | types.TaskArtifactUpdateEvent + ): if response.HasField('msg'): return cls.message(response.msg) if response.HasField('task'): diff --git a/src/a2a/utils/transports.py b/src/a2a/utils/transports.py index 33a8f9ed..50f8aa07 100644 --- a/src/a2a/utils/transports.py +++ b/src/a2a/utils/transports.py @@ -1,7 +1,9 @@ """Defines standard protocol transport labels.""" + from enum import Enum + class Transports(str, Enum): - GRPC = "GRPC" - JSONRPC = "JSONRPC" - RESTful = "HTTP+JSON" + GRPC = 'GRPC' + JSONRPC = 'JSONRPC' + RESTful = 'HTTP+JSON' From 4ed965ffb9e80cd87e9f3b557e0782a201a10c19 Mon Sep 17 00:00:00 2001 From: pstephengoogle Date: Thu, 24 Jul 2025 10:46:27 -0600 Subject: [PATCH 15/17] Update src/a2a/client/errors.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/a2a/client/errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/a2a/client/errors.py b/src/a2a/client/errors.py index 1d49eb55..d6c43256 100644 --- a/src/a2a/client/errors.py +++ b/src/a2a/client/errors.py @@ -47,7 +47,7 @@ def __init__(self, message: str): class A2AClientInvalidArgsError(A2AClientError): - """Client exception for timeout errors during a request.""" + """Client exception for invalid arguments passed to a method.""" def __init__(self, message: str): """Initializes the A2AClientInvalidArgsError. From a1a782c556e8e081ae548a76267880e7c93fef5a Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 24 Jul 2025 16:51:55 +0000 Subject: [PATCH 16/17] Updates for bugs --- src/a2a/client/client_task_manager.py | 10 +--------- src/a2a/client/grpc_client.py | 1 + src/a2a/server/request_handlers/__init__.py | 1 - 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/a2a/client/client_task_manager.py b/src/a2a/client/client_task_manager.py index acccf313..62238593 100644 --- a/src/a2a/client/client_task_manager.py +++ b/src/a2a/client/client_task_manager.py @@ -26,15 +26,7 @@ class ClientTaskManager: def __init__( self, ): - """Initializes the `TaskManager`. - - Args: - task_id: The ID of the task, if known from the request. - context_id: The ID of the context, if known from the request. - task_store: The `TaskStore` instance for persistence. - initial_message: The `Message` that initiated the task, if any. - Used when creating a new task object. - """ + """Initializes the `ClientTaskManager`.""" self._current_task: Task | None = None self._task_id: str | None = None self._context_id: str | None = None diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/grpc_client.py index 5a002a2d..d13b7649 100644 --- a/src/a2a/client/grpc_client.py +++ b/src/a2a/client/grpc_client.py @@ -28,6 +28,7 @@ GetTaskPushNotificationConfigParams, Message, MessageSendParams, + MessageSendConfiguration, Task, TaskArtifactUpdateEvent, TaskIdParams, diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 54a1617f..43ebc8e2 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -42,7 +42,6 @@ def __init__(self, *args, **kwargs): 'GrpcHandler', 'JSONRPCHandler', 'RESTHandler', - 'RESTHandler', 'RequestHandler', 'build_error_response', 'prepare_response_object', From f11246a87a228529947b9aaa9c4babbc1f3613a8 Mon Sep 17 00:00:00 2001 From: Phil Stephens Date: Thu, 24 Jul 2025 17:21:35 +0000 Subject: [PATCH 17/17] Fix more issues found by Gemini --- src/a2a/client/jsonrpc_client.py | 22 +++++++++++++++++++--- src/a2a/client/rest_client.py | 16 +++++++++++++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/a2a/client/jsonrpc_client.py b/src/a2a/client/jsonrpc_client.py index f40693ac..0b829f0d 100644 --- a/src/a2a/client/jsonrpc_client.py +++ b/src/a2a/client/jsonrpc_client.py @@ -11,7 +11,11 @@ from a2a.client.client import A2ACardResolver, Client, ClientConfig, Consumer from a2a.client.client_task_manager import ClientTaskManager -from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.errors import ( + A2AClientHTTPError, + A2AClientJSONError, + A2AClientTimeoutError, +) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.types import ( AgentCard, @@ -271,6 +275,8 @@ async def _send_request( ) response.raise_for_status() return response.json() + except httpx.ReadTimeout as e: + raise A2AClientTimeoutError('Client Request timed out') from e except httpx.HTTPStatusError as e: raise A2AClientHTTPError(e.response.status_code, str(e)) from e except json.JSONDecodeError as e: @@ -525,7 +531,7 @@ async def get_card( ) response_data = await self._send_request(payload, modified_kwargs) card = AgentCard.model_validate(response_data) - self.card = card + self.agent_card = card self._needs_extended_card = False return card @@ -568,12 +574,21 @@ async def send_message( *, context: ClientCallContext | None = None, ) -> AsyncIterator[Task | Message]: - # TODO: Set the request params from config + config = MessageSendConfiguration( + accepted_output_modes=self._config.accepted_output_modes, + blocking=not self._config.polling, + push_notification_config=( + self._config.push_notification_configs[0] + if self._config.push_notification_configs + else None + ), + ) if not self._config.streaming or not self._card.capabilities.streaming: response = await self._transport_client.send_message( SendMessageRequest( params=MessageSendParams( message=request, + configuration=config, ), id=str(uuid4()), ), @@ -592,6 +607,7 @@ async def send_message( SendStreamingMessageRequest( params=MessageSendParams( message=request, + configuration=config, ), id=str(uuid4()), ), diff --git a/src/a2a/client/rest_client.py b/src/a2a/client/rest_client.py index 140a8b01..6e9f0b9b 100644 --- a/src/a2a/client/rest_client.py +++ b/src/a2a/client/rest_client.py @@ -489,7 +489,7 @@ async def resubscribe( async with aconnect_sse( self.httpx_client, 'POST', - f'{self.url}/v1/tasks/{request.taskId}:subscribe', + f'{self.url}/v1/tasks/{request.id}:subscribe', json=payload, **modified_kwargs, ) as event_source: @@ -551,7 +551,7 @@ async def get_card( '/v1/card/get', {}, modified_kwargs ) card = AgentCard.model_validate(response_data) - self.card = card + self.agent_card = card self._needs_extended_card = False return card @@ -594,11 +594,20 @@ async def send_message( *, context: ClientCallContext | None = None, ) -> AsyncIterator[Task | Message]: - # TODO: Set the request params from config + config = MessageSendConfiguration( + accepted_output_modes=self._config.accepted_output_modes, + blocking=not self._config.polling, + push_notification_config=( + self._config.push_notification_configs[0] + if self._config.push_notification_configs + else None + ), + ) if not self._config.streaming or not self._card.capabilities.streaming: response = await self._transport_client.send_message( MessageSendParams( message=request, + configuration=config, ), http_kwargs=self.get_http_args(context), context=context, @@ -613,6 +622,7 @@ async def send_message( async for event in self._transport_client.send_message_streaming( MessageSendParams( message=request, + configuration=config, ), http_kwargs=self.get_http_args(context), context=context,