diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 97d884a2..39370e0f 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -17,23 +17,6 @@ AServers AService AStarlette AUser -DSNs -EUR -GBP -GVsb -INR -JPY -JSONRPCt -Llm -POSTGRES -RUF -Tful -aconnect -adk -agentic -aio -aiomysql -aproject autouse backticks cla @@ -82,6 +65,7 @@ pyi pypistats pyupgrade pyversions +redef respx resub RUF @@ -91,5 +75,6 @@ sse tagwords taskupdate testuuid +Tful typeerror vulnz diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 33200ad1..40a326a4 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -7,16 +7,47 @@ CredentialService, InMemoryContextCredentialStore, ) -from a2a.client.client import A2ACardResolver, A2AClient +from a2a.client.client import ( + A2ACardResolver, + Client, + ClientConfig, + ClientEvent, + Consumer, +) +from a2a.client.client_factory import ( + ClientFactory, + ClientProducer, + minimal_agent_card, +) from a2a.client.errors import ( A2AClientError, A2AClientHTTPError, A2AClientJSONError, A2AClientTimeoutError, ) +from a2a.client.grpc_client import ( + GrpcClient, + GrpcTransportClient, + NewGrpcClient, +) from a2a.client.helpers import create_text_message_object +from a2a.client.jsonrpc_client import ( + JsonRpcClient, + JsonRpcTransportClient, + NewJsonRpcClient, +) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.rest_client import ( + NewRestfulClient, + RestClient, + RestTransportClient, +) + +# For backward compatability define this alias. This will be deprecated in +# a future release. +A2AClient = JsonRpcTransportClient +A2AGrpcClient = GrpcTransportClient logger = logging.getLogger(__name__) @@ -41,16 +72,32 @@ def __init__(self, *args, **kwargs): __all__ = [ 'A2ACardResolver', - 'A2AClient', + 'A2AClient', # for backward compatability 'A2AClientError', 'A2AClientHTTPError', 'A2AClientJSONError', 'A2AClientTimeoutError', - 'A2AGrpcClient', + 'A2AGrpcClient', # for backward compatability 'AuthInterceptor', + 'Client', 'ClientCallContext', 'ClientCallInterceptor', + 'ClientConfig', + 'ClientEvent', + 'ClientFactory', + 'ClientProducer', + 'Consumer', 'CredentialService', + 'GrpcClient', + 'GrpcTransportClient', 'InMemoryContextCredentialStore', + 'JsonRpcClient', + 'JsonRpcTransportClient', + 'NewGrpcClient', + 'NewJsonRpcClient', + 'NewRestfulClient', + 'RestClient', + 'RestTransportClient', 'create_text_message_object', + 'minimal_agent_card', ] diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 66dfe0a4..450e42ab 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -1,40 +1,47 @@ +import dataclasses import json import logging -from collections.abc import AsyncGenerator +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Callable, Coroutine from typing import Any -from uuid import uuid4 import httpx -from httpx_sse import SSEError, aconnect_sse from pydantic import ValidationError + +# Attempt to import the optional module +try: + from grpc.aio import Channel +except ImportError: + # If grpc.aio is not available, define a dummy type for type checking. + # This dummy type will only be used by type checkers. + if TYPE_CHECKING: + + class Channel: # type: ignore[no-redef] + pass + else: + Channel = None # At runtime, pd will be None if the import failed. + from a2a.client.errors import ( A2AClientHTTPError, A2AClientJSONError, - A2AClientTimeoutError, ) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.types import ( AgentCard, - CancelTaskRequest, - CancelTaskResponse, - GetTaskPushNotificationConfigRequest, - GetTaskPushNotificationConfigResponse, - GetTaskRequest, - GetTaskResponse, - SendMessageRequest, - SendMessageResponse, - SendStreamingMessageRequest, - SendStreamingMessageResponse, - SetTaskPushNotificationConfigRequest, - SetTaskPushNotificationConfigResponse, -) -from a2a.utils.constants import ( - AGENT_CARD_WELL_KNOWN_PATH, + GetTaskPushNotificationConfigParams, + Message, + PushNotificationConfig, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, ) -from a2a.utils.telemetry import SpanKind, trace_class +from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH logger = logging.getLogger(__name__) @@ -128,373 +135,147 @@ async def get_agent_card( return agent_card -@trace_class(kind=SpanKind.CLIENT) -class A2AClient: - """A2A Client for interacting with an A2A agent.""" +@dataclasses.dataclass +class ClientConfig: + """Configuration class for the A2A Client Factory""" - def __init__( - self, - httpx_client: httpx.AsyncClient, - agent_card: AgentCard | None = None, - url: str | None = None, - interceptors: list[ClientCallInterceptor] | None = None, - ): - """Initializes the A2AClient. + streaming: bool = True + """Whether client supports streaming""" - Requires either an `AgentCard` or a direct `url` to the agent's RPC endpoint. + polling: bool = False + """Whether client prefers to poll for updates from message:send. It is + the callers job to check if the response is completed and if not run a + polling loop.""" - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. - url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. - interceptors: An optional list of client call interceptors to apply to requests. + httpx_client: httpx.AsyncClient | None = None + """Http client to use to connect to agent.""" - Raises: - ValueError: If neither `agent_card` nor `url` is provided. - """ - if agent_card: - self.url = agent_card.url - elif url: - self.url = url - else: - raise ValueError('Must provide either agent_card or url') + grpc_channel_factory: Callable[[str], Channel] | None = None + """Generates a grpc connection channel for a given url.""" - self.httpx_client = httpx_client - self.agent_card = agent_card - self.interceptors = interceptors or [] + supported_transports: list[str] = dataclasses.field(default_factory=list) + """Ordered list of transports for connecting to agent + (in order of preference). Empty implies JSONRPC only. - async def _apply_interceptors( - self, - method_name: str, - request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None, - context: ClientCallContext | None, - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Applies all registered interceptors to the request.""" - final_http_kwargs = http_kwargs or {} - final_request_payload = request_payload - - for interceptor in self.interceptors: - ( - final_request_payload, - final_http_kwargs, - ) = await interceptor.intercept( - method_name, - final_request_payload, - final_http_kwargs, - self.agent_card, - context, - ) - return final_request_payload, final_http_kwargs + This is a string type and not a Transports enum type to allow custom + transports to exist in closed ecosystems. + """ - @staticmethod - async def get_client_from_agent_card_url( - httpx_client: httpx.AsyncClient, - base_url: str, - agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, - http_kwargs: dict[str, Any] | None = None, - ) -> 'A2AClient': - """Fetches the public AgentCard and initializes an A2A client. - - This method will always fetch the public agent card. If an authenticated - or extended agent card is required, the A2ACardResolver should be used - directly to fetch the specific card, and then the A2AClient should be - instantiated with it. + use_client_preference: bool = False + """Whether to use client transport preferences over server preferences. + Recommended to use server preferences in most situations.""" - Args: - httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). - base_url: The base URL of the agent's host. - agent_card_path: The path to the agent card endpoint, relative to the base URL. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.get request when fetching the agent card. + accepted_outputModes: list[str] = dataclasses.field(default_factory=list) + """The set of accepted output modes for the client.""" - Returns: - An initialized `A2AClient` instance. + push_notification_configs: list[PushNotificationConfig] = dataclasses.field( + default_factory=list + ) + """Push notification callbacks to use for every request.""" - Raises: - A2AClientHTTPError: If an HTTP error occurs fetching the agent card. - A2AClientJSONError: If the agent card response is invalid. - """ - agent_card: AgentCard = await A2ACardResolver( - httpx_client, base_url=base_url, agent_card_path=agent_card_path - ).get_agent_card( - http_kwargs=http_kwargs - ) # Fetches public card by default - return A2AClient(httpx_client=httpx_client, agent_card=agent_card) - async def send_message( - self, - request: SendMessageRequest, - *, - http_kwargs: dict[str, Any] | None = None, - context: ClientCallContext | None = None, - ) -> SendMessageResponse: - """Sends a non-streaming message request to the agent. +UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None +# Alias for emitted events from client +ClientEvent = tuple[Task, UpdateEvent] +# Alias for an event consuming callback. It takes either a (task, update) pair +# or a message as well as the agent card for the agent this came from. +Consumer = Callable[ + [ClientEvent | Message, AgentCard], Coroutine[None, Any, Any] +] - Args: - request: The `SendMessageRequest` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - Returns: - A `SendMessageResponse` object containing the agent's response (Task or Message) or an error. +class Client(ABC): + def __init__( + self, + consumers: list[Consumer] = [], + middleware: list[ClientCallInterceptor] = [], + ): + self._consumers = consumers or [] + self._middleware = middleware or [] - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'message/send', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return SendMessageResponse.model_validate(response_data) - - async def send_message_streaming( + @abstractmethod + async def send_message( self, - request: SendStreamingMessageRequest, + request: Message, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> AsyncGenerator[SendStreamingMessageResponse]: - """Sends a streaming message request to the agent and yields responses as they arrive. - - This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. - - Args: - request: The `SendStreamingMessageRequest` object containing the message and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. A default `timeout=None` is set but can be overridden. - context: The client call context. - - Yields: - `SendStreamingMessageResponse` objects as they are received in the SSE stream. - These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. - - Raises: - A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. - A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'message/stream', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - - modified_kwargs.setdefault('timeout', None) - - async with aconnect_sse( - self.httpx_client, - 'POST', - self.url, - json=payload, - **modified_kwargs, - ) as event_source: - try: - async for sse in event_source.aiter_sse(): - yield SendStreamingMessageResponse.model_validate( - json.loads(sse.data) - ) - except SSEError as e: - raise A2AClientHTTPError( - 400, - f'Invalid SSE response or protocol error: {e}', - ) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e - - async def _send_request( - self, - rpc_request_payload: dict[str, Any], - http_kwargs: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Sends a non-streaming JSON-RPC request to the agent. - - Args: - rpc_request_payload: JSON RPC payload for sending the request. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - - Returns: - The JSON response payload as a dictionary. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON. + ) -> AsyncIterator[ClientEvent | Message]: + """Sends a message to the server. + + This will automatically use the streaming or non-streaming approach + as supported by the server and the client config. Client will + aggregate update events and return an iterator of (`Task`,`Update`) + pairs, or a `Message`. Client will also send these values to any + configured `Consumer`s in the client. """ - try: - response = await self.httpx_client.post( - self.url, json=rpc_request_payload, **(http_kwargs or {}) - ) - response.raise_for_status() - return response.json() - except httpx.ReadTimeout as e: - raise A2AClientTimeoutError('Client Request timed out') from e - except httpx.HTTPStatusError as e: - raise A2AClientHTTPError(e.response.status_code, str(e)) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e - except httpx.RequestError as e: - raise A2AClientHTTPError( - 503, f'Network communication error: {e}' - ) from e + yield + @abstractmethod async def get_task( self, - request: GetTaskRequest, + request: TaskQueryParams, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> GetTaskResponse: - """Retrieves the current state and history of a specific task. - - Args: - request: The `GetTaskRequest` object specifying the task ID and history length. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `GetTaskResponse` object containing the Task or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/get', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return GetTaskResponse.model_validate(response_data) + ) -> Task: + pass + @abstractmethod async def cancel_task( self, - request: CancelTaskRequest, + request: TaskIdParams, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> CancelTaskResponse: - """Requests the agent to cancel a specific task. - - Args: - request: The `CancelTaskRequest` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `CancelTaskResponse` object containing the updated Task with canceled status or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/cancel', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return CancelTaskResponse.model_validate(response_data) + ) -> Task: + pass + @abstractmethod async def set_task_callback( self, - request: SetTaskPushNotificationConfigRequest, + request: TaskPushNotificationConfig, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> SetTaskPushNotificationConfigResponse: - """Sets or updates the push notification configuration for a specific task. - - Args: - request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. - - Returns: - A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error. - - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/pushNotificationConfig/set', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return SetTaskPushNotificationConfigResponse.model_validate( - response_data - ) + ) -> TaskPushNotificationConfig: + pass + @abstractmethod async def get_task_callback( self, - request: GetTaskPushNotificationConfigRequest, + request: GetTaskPushNotificationConfigParams, *, - http_kwargs: dict[str, Any] | None = None, context: ClientCallContext | None = None, - ) -> GetTaskPushNotificationConfigResponse: - """Retrieves the push notification configuration for a specific task. + ) -> TaskPushNotificationConfig: + pass - Args: - request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID. - http_kwargs: Optional dictionary of keyword arguments to pass to the - underlying httpx.post request. - context: The client call context. + @abstractmethod + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + yield - Returns: - A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error. + @abstractmethod + async def get_card( + self, *, context: ClientCallContext | None = None + ) -> AgentCard: + pass - Raises: - A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON or validated. - """ - if not request.id: - request.id = str(uuid4()) - - # Apply interceptors before sending - payload, modified_kwargs = await self._apply_interceptors( - 'tasks/pushNotificationConfig/get', - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - context, - ) - response_data = await self._send_request(payload, modified_kwargs) - return GetTaskPushNotificationConfigResponse.model_validate( - response_data - ) + async def add_event_consumer(self, consumer: Consumer): + """Attaches additional consumers to the `Client`""" + self._consumers.append(consumer) + + async def add_request_middleware(self, middleware: ClientCallInterceptor): + """Attaches additional middleware to the `Client`""" + self._middleware.append(middleware) + + async def consume( + self, + event: tuple[Task, UpdateEvent] | Message | None, + card: AgentCard, + ): + """Processes the event via all the registered `Consumer`s.""" + if not event: + return + for c in self._consumers: + await c(event, card) diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py new file mode 100644 index 00000000..2dd37546 --- /dev/null +++ b/src/a2a/client/client_factory.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +import logging + +from collections.abc import Callable + +from a2a.client.client import Client, ClientConfig, Consumer +from a2a.client.grpc_client import NewGrpcClient +from a2a.client.jsonrpc_client import NewJsonRpcClient +from a2a.client.middleware import ClientCallInterceptor +from a2a.client.rest_client import NewRestfulClient +from a2a.types import ( + AgentCapabilities, + AgentCard, +) +from a2a.utils import Transports + + +logger = logging.getLogger(__name__) + +ClientProducer = Callable[ + [ + AgentCard | str, + ClientConfig, + list[Consumer], + list[ClientCallInterceptor], + ], + Client, +] + + +class ClientFactory: + """ClientFactory is used to generate the appropriate client for the agent. + + The factory is configured with a `ClientConfig` and optionally a list of + `Consumer`s to use for all generated `Client`s. The expected use is: + + factory = ClientFactory(config, consumers) + # Optionally register custom client implementations + factory.register('my_customer_transport', NewCustomTransportClient) + # Then with an agent card make a client with additional consumers and + # interceptors + client = factory.create(card, additional_consumers, interceptors) + # Now the client can be used the same regardless of transport and + # aligns client config with server capabilities. + """ + + def __init__( + self, + config: ClientConfig, + consumers: list[Consumer] = [], + ): + self._config = config + self._consumers = consumers + self._registry: dict[str, ClientProducer] = {} + # By default register the 3 core transports if in the config. + # Can be overridden with custom clients via the register method. + if Transports.JSONRPC in self._config.supported_transports: + self._registry[Transports.JSONRPC] = NewJsonRpcClient + if Transports.RESTful in self._config.supported_transports: + self._registry[Transports.RESTful] = NewRestfulClient + if Transports.GRPC in self._config.supported_transports: + self._registry[Transports.GRPC] = NewGrpcClient + + def register(self, label: str, generator: ClientProducer): + """Register a new client producer for a given transport label.""" + self._registry[label] = generator + + def create( + self, + card: AgentCard, + consumers: list[Consumer] | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ) -> Client: + """Create a new `Client` for the provided `AgentCard`. + + Args: + card: An `AgentCard` defining the characteristics of the agent. + consumers: A list of `Consumer` methods to pass responses to. + interceptors: A list of interceptors to use for each request. These + are used for things like attaching credentials or http headers + to all outbound requests. + + Returns: + A `Client` object. + + Raises: + If there is no valid matching of the client configuration with the + server configuration, a `ValueError` is raised. + """ + # Determine preferential transport + server_set = [card.preferred_transport or 'JSONRPC'] + if card.additional_interfaces: + server_set.extend([x.transport for x in card.additional_interfaces]) + client_set = self._config.supported_transports or ['JSONRPC'] + transport = None + # Two options, use the client ordering or the server ordering. + if self._config.use_client_preference: + for x in client_set: + if x in server_set: + transport = x + break + else: + for x in server_set: + if x in client_set: + transport = x + break + if not transport: + raise ValueError('no compatible transports found.') + if transport not in self._registry: + raise ValueError(f'no client available for {transport}') + all_consumers = self._consumers.copy() + if consumers: + all_consumers.extend(consumers) + return self._registry[transport]( + card, self._config, all_consumers, interceptors + ) + + +def minimal_agent_card(url: str, transports: list[str] = []) -> AgentCard: + """Generates a minimal card to simplify bootstrapping client creation. + + This minimal card is not viable itself to interact with the remote agent. + Instead this is a short hand way to take a known url and transport option + and interact with the get card endpoint of the agent server to get the + correct agent card. This pattern is necessary for gRPC based card access + as typically these servers won't expose a well known path card. + """ + return AgentCard( + url=url, + preferred_transport=transports[0] if transports else None, + additional_interfaces=transports[1:] if len(transports) > 1 else [], + supports_authenticated_extended_card=True, + capabilities=AgentCapabilities(), + default_input_modes=[], + default_output_modes=[], + description='', + skills=[], + version='', + name='', + ) diff --git a/src/a2a/client/client_task_manager.py b/src/a2a/client/client_task_manager.py new file mode 100644 index 00000000..62238593 --- /dev/null +++ b/src/a2a/client/client_task_manager.py @@ -0,0 +1,171 @@ +import logging + +from a2a.client.errors import A2AClientInvalidArgsError +from a2a.server.events.event_queue import Event +from a2a.types import ( + Message, + Task, + TaskArtifactUpdateEvent, + TaskState, + TaskStatus, + TaskStatusUpdateEvent, +) +from a2a.utils import append_artifact_to_task + + +logger = logging.getLogger(__name__) + + +class ClientTaskManager: + """Helps manage a task's lifecycle during execution of a request. + + Responsible for retrieving, saving, and updating the `Task` object based on + events received from the agent. + """ + + def __init__( + self, + ): + """Initializes the `ClientTaskManager`.""" + self._current_task: Task | None = None + self._task_id: str | None = None + self._context_id: str | None = None + + def get_task(self) -> Task | None: + """Retrieves the current task object, either from memory. + + If `task_id` is set, it returns `_current_task` otherwise None. + + Returns: + The `Task` object if found, otherwise `None`. + """ + if not self._task_id: + logger.debug('task_id is not set, cannot get task.') + return None + + return self._current_task + + async def save_task_event( + self, event: Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ) -> Task | None: + """Processes a task-related event (Task, Status, Artifact) and saves the updated task state. + + Ensures task and context IDs match or are set from the event. + + Args: + event: The task-related event (`Task`, `TaskStatusUpdateEvent`, or `TaskArtifactUpdateEvent`). + + Returns: + The updated `Task` object after processing the event. + + Raises: + ClientError: If the task ID in the event conflicts with the TaskManager's ID + when the TaskManager's ID is already set. + """ + if isinstance(event, Task): + if self._current_task: + raise A2AClientInvalidArgsError( + 'Task is already set, create new manager for new tasks.' + ) + await self._save_task(event) + return event + task_id_from_event = ( + event.id if isinstance(event, Task) else event.taskId + ) + if not self._task_id: + self._task_id = task_id_from_event + if not self._context_id: + self._context_id = event.contextId + + logger.debug( + 'Processing save of task event of type %s for task_id: %s', + type(event).__name__, + task_id_from_event, + ) + + task = self._current_task + if not task: + task = Task( + status=TaskStatus(state=TaskState.unknown), + id=task_id_from_event, + contextId=self._context_id if self._context_id else '', + ) + if isinstance(event, TaskStatusUpdateEvent): + logger.debug( + 'Updating task %s status to: %s', + event.taskId, + event.status.state, + ) + if event.status.message: + if not task.history: + task.history = [event.status.message] + else: + task.history.append(event.status.message) + if event.metadata: + if not task.metadata: + task.metadata = {} + task.metadata.update(event.metadata) + task.status = event.status + else: + logger.debug('Appending artifact to task %s', task.id) + append_artifact_to_task(task, event) + self._current_task = task + return task + + async def process(self, event: Event) -> Event: + """Processes an event, updates the task state if applicable, stores it, and returns the event. + + If the event is task-related (`Task`, `TaskStatusUpdateEvent`, `TaskArtifactUpdateEvent`), + the internal task state is updated and persisted. + + Args: + event: The event object received from the agent. + + Returns: + The same event object that was processed. + """ + if isinstance( + event, Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ): + await self.save_task_event(event) + + return event + + async def _save_task(self, task: Task) -> None: + """Saves the given task to the `_current_task` and updated `_task_id` and `_context_id` + + Args: + task: The `Task` object to save. + """ + logger.debug('Saving task with id: %s', task.id) + self._current_task = task + if not self._task_id: + logger.info('New task created with id: %s', task.id) + self._task_id = task.id + self._context_id = task.contextId + + def update_with_message(self, message: Message, task: Task) -> Task: + """Updates a task object adding a new message to its history. + + If the task has a message in its current status, that message is moved + to the history first. + + Args: + message: The new `Message` to add to the history. + task: The `Task` object to update. + + Returns: + The updated `Task` object (updated in-place). + """ + if task.status.message: + if task.history: + task.history.append(task.status.message) + else: + task.history = [task.status.message] + task.status.message = None + if task.history: + task.history.append(message) + else: + task.history = [message] + self._current_task = task + return task diff --git a/src/a2a/client/errors.py b/src/a2a/client/errors.py index 5fe5512a..d6c43256 100644 --- a/src/a2a/client/errors.py +++ b/src/a2a/client/errors.py @@ -44,3 +44,16 @@ def __init__(self, message: str): """ self.message = message super().__init__(f'Timeout Error: {message}') + + +class A2AClientInvalidArgsError(A2AClientError): + """Client exception for invalid arguments passed to a method.""" + + def __init__(self, message: str): + """Initializes the A2AClientInvalidArgsError. + + Args: + message: A descriptive error message. + """ + self.message = message + super().__init__(f'Invalid arguments error: {message}') diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/grpc_client.py index d224b201..d13b7649 100644 --- a/src/a2a/client/grpc_client.py +++ b/src/a2a/client/grpc_client.py @@ -1,6 +1,6 @@ import logging -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, AsyncIterator try: @@ -13,11 +13,22 @@ ) from e +from a2a.client.client import ( + Client, + ClientCallContext, + ClientConfig, + ClientEvent, + Consumer, +) +from a2a.client.client_task_manager import ClientTaskManager +from a2a.client.middleware import ClientCallInterceptor from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCard, + GetTaskPushNotificationConfigParams, Message, MessageSendParams, + MessageSendConfiguration, Task, TaskArtifactUpdateEvent, TaskIdParams, @@ -33,17 +44,17 @@ @trace_class(kind=SpanKind.CLIENT) -class A2AGrpcClient: - """A2A Client for interacting with an A2A agent via gRPC.""" +class GrpcTransportClient: + """Transport specific details for interacting with an A2A agent via gRPC.""" def __init__( self, grpc_stub: a2a_pb2_grpc.A2AServiceStub, - agent_card: AgentCard, + agent_card: AgentCard | None, ): - """Initializes the A2AGrpcClient. + """Initializes the GrpcTransportClient. - Requires an `AgentCard` + Requires an `AgentCard` and a grpc `A2AServiceStub`. Args: grpc_stub: A grpc client stub. @@ -51,10 +62,17 @@ def __init__( """ self.agent_card = agent_card self.stub = grpc_stub + # If they don't provide an agent card, but do have a stub, lookup the + # card from the stub. + self._needs_extended_card = ( + agent_card.supportsAuthenticatedExtendedCard if agent_card else True + ) async def send_message( self, request: MessageSendParams, + *, + context: ClientCallContext | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent. @@ -80,6 +98,8 @@ async def send_message( async def send_message_streaming( self, request: MessageSendParams, + *, + context: ClientCallContext | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -125,6 +145,8 @@ async def send_message_streaming( async def get_task( self, request: TaskQueryParams, + *, + context: ClientCallContext | None = None, ) -> Task: """Retrieves the current state and history of a specific task. @@ -142,6 +164,8 @@ async def get_task( async def cancel_task( self, request: TaskIdParams, + *, + context: ClientCallContext | None = None, ) -> Task: """Requests the agent to cancel a specific task. @@ -159,6 +183,8 @@ async def cancel_task( async def set_task_callback( self, request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task. @@ -182,6 +208,8 @@ async def set_task_callback( async def get_task_callback( self, request: TaskIdParams, # TODO: Update to a push id params + *, + context: ClientCallContext | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task. @@ -197,3 +225,192 @@ async def get_task_callback( ) ) return proto_utils.FromProto.task_push_notification_config(config) + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the authenticated card (if necessary) or the public one. + + Args: + context: The client call context. + + Returns: + A `AgentCard` object containing the card. + + Raises: + grpc.RpcError: If a gRPC error occurs during the request. + """ + # If we don't have the public card, try to get that first. + card = self.agent_card + + if not self._needs_extended_card: + return card + + card_pb = await self.stub.GetAgentCard( + a2a_pb2.GetAgentCardRequest(), + ) + card = proto_utils.FromProto.agent_card(card_pb) + self.agent_card = card + self._needs_extended_card = False + return card + + +class GrpcClient(Client): + """GrpcClient provides the Client interface for the gRPC transport.""" + + def __init__( + self, + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], + ): + super().__init__(consumers, middleware) + if not config.grpc_channel_factory: + raise Exception('GRPC client requires channel factory.') + self._card = card + self._config = config + # Defer init to first use. + self._transport_client = None + channel = self._config.grpc_channel_factory(self._card.url) + stub = a2a_pb2_grpc.A2AServiceStub(channel) + self._transport_client = GrpcTransportClient(stub, self._card) + + async def send_message( + self, + request: Message, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[ClientEvent | Message]: + config = MessageSendConfiguration( + accepted_output_modes=self._config.accepted_output_modes, + blocking=not self._config.polling, + push_notification_config=( + self._config.push_notification_configs[0] + if self._config.push_notification_configs + else None + ), + ) + if not self._config.streaming or not self._card.capabilities.streaming: + response = await self._transport_client.send_message( + MessageSendParams( + message=request, + configuration=config, + ), + context=context, + ) + result = ( + (response, None) if isinstance(response, Task) else response + ) + await self.consume(result, self._card) + yield result + return + # Get Task tracker + tracker = ClientTaskManager() + async for event in self._transport_client.send_message_streaming( + MessageSendParams( + message=request, + configuration=config, + ), + context=context, + ): + # Update task, check for errors, etc. + if isinstance(event, Message): + await self.consume(event, self._card) + yield event + return + await tracker.process(event) + result = ( + tracker.get_task(), + None if isinstance(event, Task) else event, + ) + await self.consume(result, self._card) + yield result + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.get_task( + request, + context=context, + ) + return response + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.cancel_task( + request, + context=context, + ) + return response + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.set_task_callback( + request, + context=context, + ) + return response + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.get_task_callback( + request, + context=context, + ) + return response + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + if not self._config.streaming or not self._card.capabilities.streaming: + raise Exception( + 'client and/or server do not support resubscription.' + ) + async for event in self._transport_client.resubscribe( + request, + context=context, + ): + # Update task, check for errors, etc. + yield event + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + card = await self._transport_client.get_card( + context=context, + ) + self._card = card + return card + + +def NewGrpcClient( + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], +) -> Client: + """Generator for the `GrpcClient` implementation.""" + return GrpcClient(card, config, consumers, middleware) diff --git a/src/a2a/client/jsonrpc_client.py b/src/a2a/client/jsonrpc_client.py new file mode 100644 index 00000000..0b829f0d --- /dev/null +++ b/src/a2a/client/jsonrpc_client.py @@ -0,0 +1,735 @@ +import json +import logging + +from collections.abc import AsyncGenerator, AsyncIterator +from typing import Any +from uuid import uuid4 + +import httpx + +from httpx_sse import SSEError, aconnect_sse + +from a2a.client.client import A2ACardResolver, Client, ClientConfig, Consumer +from a2a.client.client_task_manager import ClientTaskManager +from a2a.client.errors import ( + A2AClientHTTPError, + A2AClientJSONError, + A2AClientTimeoutError, +) +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.types import ( + AgentCard, + CancelTaskRequest, + CancelTaskResponse, + GetTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigResponse, + GetTaskRequest, + GetTaskResponse, + JSONRPCErrorResponse, + Message, + MessageSendParams, + SendMessageRequest, + SendMessageResponse, + SendStreamingMessageRequest, + SendStreamingMessageResponse, + SetTaskPushNotificationConfigRequest, + SetTaskPushNotificationConfigResponse, + Task, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskResubscriptionRequest, +) +from a2a.utils.constants import ( + AGENT_CARD_WELL_KNOWN_PATH, +) +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class JsonRpcTransportClient: + """A2A Client for interacting with an A2A agent.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + """Initializes the A2AClient. + + Requires either an `AgentCard` or a direct `url` to the agent's RPC endpoint. + + Args: + httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). + agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. + url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. + interceptors: An optional list of client call interceptors to apply to requests. + + Raises: + ValueError: If neither `agent_card` nor `url` is provided. + """ + if agent_card: + self.url = agent_card.url + elif url: + self.url = url + else: + raise ValueError('Must provide either agent_card or url') + + self.httpx_client = httpx_client + self.agent_card = agent_card + self.interceptors = interceptors or [] + # Indicate if we have captured an extended card details so we can update + # on first call if needed. It is done this way so the caller can setup + # their auth credentials based on the public card and get the updated + # card. + self._needs_extended_card = ( + not agent_card.supportsAuthenticatedExtendedCard + if agent_card + else True + ) + + async def _apply_interceptors( + self, + method_name: str, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Applies all registered interceptors to the request.""" + final_http_kwargs = http_kwargs or {} + final_request_payload = request_payload + + for interceptor in self.interceptors: + ( + final_request_payload, + final_http_kwargs, + ) = await interceptor.intercept( + method_name, + final_request_payload, + final_http_kwargs, + self.agent_card, + context, + ) + return final_request_payload, final_http_kwargs + + @staticmethod + async def get_client_from_agent_card_url( + httpx_client: httpx.AsyncClient, + base_url: str, + agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH, + http_kwargs: dict[str, Any] | None = None, + ) -> 'A2AClient': + """[deprecated] Fetches the public AgentCard and initializes an A2A client. + + This method will always fetch the public agent card. If an authenticated + or extended agent card is required, the A2ACardResolver should be used + directly to fetch the specific card, and then the A2AClient should be + instantiated with it. + + Args: + httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). + base_url: The base URL of the agent's host. + agent_card_path: The path to the agent card endpoint, relative to the base URL. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.get request when fetching the agent card. + + Returns: + An initialized `A2AClient` instance. + + Raises: + A2AClientHTTPError: If an HTTP error occurs fetching the agent card. + A2AClientJSONError: If the agent card response is invalid. + """ + agent_card: AgentCard = await A2ACardResolver( + httpx_client, base_url=base_url, agent_card_path=agent_card_path + ).get_agent_card( + http_kwargs=http_kwargs + ) # Fetches public card by default + return A2AClient(httpx_client=httpx_client, agent_card=agent_card) + + async def send_message( + self, + request: SendMessageRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> SendMessageResponse: + """Sends a non-streaming message request to the agent. + + Args: + request: The `SendMessageRequest` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `SendMessageResponse` object containing the agent's response (Task or Message) or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'message/send', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + return SendMessageResponse.model_validate(response_data) + + async def send_message_streaming( + self, + request: SendStreamingMessageRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[SendStreamingMessageResponse]: + """Sends a streaming message request to the agent and yields responses as they arrive. + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `SendStreamingMessageRequest` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + `SendStreamingMessageResponse` objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'message/stream', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + self.url, + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + yield SendStreamingMessageResponse.model_validate( + json.loads(sse.data) + ) + except SSEError as e: + raise A2AClientHTTPError( + 400, + f'Invalid SSE response or protocol error: {e}', + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_request( + self, + rpc_request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Sends a non-streaming JSON-RPC request to the agent. + + Args: + rpc_request_payload: JSON RPC payload for sending the request. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + + Returns: + The JSON response payload as a dictionary. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON. + """ + try: + response = await self.httpx_client.post( + self.url, json=rpc_request_payload, **(http_kwargs or {}) + ) + response.raise_for_status() + return response.json() + except httpx.ReadTimeout as e: + raise A2AClientTimeoutError('Client Request timed out') from e + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_task( + self, + request: GetTaskRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> GetTaskResponse: + """Retrieves the current state and history of a specific task. + + Args: + request: The `GetTaskRequest` object specifying the task ID and history length. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `GetTaskResponse` object containing the Task or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/get', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + return GetTaskResponse.model_validate(response_data) + + async def cancel_task( + self, + request: CancelTaskRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> CancelTaskResponse: + """Requests the agent to cancel a specific task. + + Args: + request: The `CancelTaskRequest` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `CancelTaskResponse` object containing the updated Task with canceled status or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/cancel', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + return CancelTaskResponse.model_validate(response_data) + + async def set_task_callback( + self, + request: SetTaskPushNotificationConfigRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> SetTaskPushNotificationConfigResponse: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/pushNotificationConfig/set', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + return SetTaskPushNotificationConfigResponse.model_validate( + response_data + ) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> GetTaskPushNotificationConfigResponse: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + if not request.id: + request.id = str(uuid4()) + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/pushNotificationConfig/get', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + return GetTaskPushNotificationConfigResponse.model_validate( + response_data + ) + + async def resubscribe( + self, + request: TaskResubscriptionRequest, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[SendStreamingMessageResponse]: + """Reconnects to get task updates + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `TaskResubscriptionRequest` object containing the task information to reconnect to. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + `SendStreamingMessageResponse` objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/resubscribe', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + self.url, + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + yield SendStreamingMessageResponse.model_validate( + json.loads(sse.data) + ) + except SSEError as e: + raise A2AClientHTTPError( + 400, + f'Invalid SSE response or protocol error: {e}', + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_card( + self, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the authenticated card (if necessary) or the public one. + + Args: + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `AgentCard` object containing the card or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + # If we don't have the public card, try to get that first. + card = self.agent_card + if not card: + resolver = A2ACardResolver(self.httpx_client, self.url) + card = await resolver.get_agent_card(http_kwargs=http_kwargs) + self._needs_extended_card = ( + card.supports_authenticated_extended_card + ) + self.agent_card = card + + if not self._needs_extended_card: + return card + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'card/getAuthenticated', + '', + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + card = AgentCard.model_validate(response_data) + self.agent_card = card + self._needs_extended_card = False + return card + + +@trace_class(kind=SpanKind.CLIENT) +class JsonRpcClient(Client): + """JsonRpcClient is the implementation of the JSONRPC A2A client. + + This client proxies requests to the JsonRpcTransportClient implementation + and manages the JSONRPC specific details. If passing additional arguments + in the http.post command, these should be attached to the ClientCallContext + under the dictionary key 'http_kwargs'. + """ + + def __init__( + self, + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], + ): + super().__init__(consumers, middleware) + if not config.httpx_client: + raise Exception('JsonRpc client requires httpx client.') + self._card = card + url = card.url + self._config = config + self._transport_client = JsonRpcTransportClient( + config.httpx_client, self._card, url, middleware + ) + + def get_http_args( + self, context: ClientCallContext + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs', None) if context else None + + async def send_message( + self, + request: Message, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + config = MessageSendConfiguration( + accepted_output_modes=self._config.accepted_output_modes, + blocking=not self._config.polling, + push_notification_config=( + self._config.push_notification_configs[0] + if self._config.push_notification_configs + else None + ), + ) + if not self._config.streaming or not self._card.capabilities.streaming: + response = await self._transport_client.send_message( + SendMessageRequest( + params=MessageSendParams( + message=request, + configuration=config, + ), + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + if isinstance(response.root, JSONRPCErrorResponse): + raise response.root.error + result = response.root.result + result = result if isinstance(result, Message) else (result, None) + await self.consume(result, self._card) + yield result + return + tracker = ClientTaskManager() + async for event in self._transport_client.send_message_streaming( + SendStreamingMessageRequest( + params=MessageSendParams( + message=request, + configuration=config, + ), + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ): + if isinstance(event.root, JSONRPCErrorResponse): + raise event.root.error + result = event.root.result + # Update task, check for errors, etc. + if isinstance(result, Message): + yield result + return + await tracker.process(result) + result = ( + tracker.get_task(), + None if isinstance(result, Task) else result, + ) + await self.consume(result, self._card) + yield result + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.get_task( + GetTaskRequest( + params=request, + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + return response.result + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.cancel_task( + CancelTaskRequest( + params=request, + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + return response.result + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.set_task_callback( + SetTaskPushNotificationConfigRequest( + params=request, + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + return response.result + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.get_task_callback( + GetTaskPushNotificationConfigRequest( + params=request, + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + return response.result + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + if not self._config.streaming or not self._card.capabilities.streaming: + raise Exception( + 'client and/or server do not support resubscription.' + ) + async for event in self._transport_client.resubscribe( + TaskResubscriptionRequest( + params=request, + id=str(uuid4()), + ), + http_kwargs=self.get_http_args(context), + context=context, + ): + # Update task, check for errors, etc. + yield event + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + return await self._transport_client.get_card( + http_kwargs=self.get_http_args(context), + context=context, + ) + + +def NewJsonRpcClient( + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], +) -> Client: + """Generator for the `JsonRpcClient` implementation.""" + return JsonRpcClient(card, config, consumers, middleware) diff --git a/src/a2a/client/rest_client.py b/src/a2a/client/rest_client.py new file mode 100644 index 00000000..6e9f0b9b --- /dev/null +++ b/src/a2a/client/rest_client.py @@ -0,0 +1,730 @@ +import json +import logging + +from collections.abc import AsyncGenerator, AsyncIterator +from typing import Any + +import httpx + +from google.protobuf.json_format import MessageToDict, Parse +from httpx_sse import SSEError, aconnect_sse + +from a2a.client.client import A2ACardResolver, Client, ClientConfig, Consumer +from a2a.client.client_task_manager import ClientTaskManager +from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.grpc import a2a_pb2 +from a2a.types import ( + AgentCard, + GetTaskPushNotificationConfigParams, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) +from a2a.utils import proto_utils +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class RestTransportClient: + """A2A Client for interacting with an A2A agent.""" + + def __init__( + self, + httpx_client: httpx.AsyncClient, + agent_card: AgentCard | None = None, + url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + ): + """Initializes the A2AClient. + + Requires either an `AgentCard` or a direct `url` to the agent's RPC endpoint. + + Args: + httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). + agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. + url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. + interceptors: An optional list of client call interceptors to apply to requests. + + Raises: + ValueError: If neither `agent_card` nor `url` is provided. + """ + if agent_card: + self.url = agent_card.url + elif url: + self.url = url + else: + raise ValueError('Must provide either agent_card or url') + # If the url ends in / remove it as this is added by the routes + if self.url.endswith('/'): + self.url = self.url[:-1] + self.httpx_client = httpx_client + self.agent_card = agent_card + self.interceptors = interceptors or [] + # Indicate if we have captured an extended card details so we can update + # on first call if needed. It is done this way so the caller can setup + # their auth credentials based on the public card and get the updated + # card. + self._needs_extended_card = ( + not agent_card.supportsAuthenticatedExtendedCard + if agent_card + else True + ) + + async def _apply_interceptors( + self, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Applies all registered interceptors to the request.""" + final_http_kwargs = http_kwargs or {} + final_request_payload = request_payload + # TODO: Implement interceptors for other transports + return final_request_payload, final_http_kwargs + + async def send_message( + self, + request: MessageSendParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> Task | Message: + """Sends a non-streaming message request to the agent. + + Args: + request: The `MessageSendParams` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `Task` or `Message` object containing the agent's response. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.send_message_config( + request.config + ), + metadata=( + proto_utils.ToProto.metadata(request.metadata) + if request.metadata + else None + ), + ) + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, + http_kwargs, + context, + ) + response_data = await self._send_post_request( + '/v1/message:send', payload, modified_kwargs + ) + response_pb = a2a_pb2.SendMessageResponse() + Parse(response_data, response_pb) + return proto_utils.FromProto.task_or_message(response_pb) + + async def send_message_streaming( + self, + request: MessageSendParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message + ]: + """Sends a streaming message request to the agent and yields responses as they arrive. + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `MessageSendParams` object containing the message and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + Objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.send_message_config( + request.configuration + ), + metadata=( + proto_utils.ToProto.metadata(request.metadata) + if request.metadata + else None + ), + ) + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, + http_kwargs, + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + f'{self.url}/v1/message:stream', + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + event = a2a_pb2.StreamResponse() + Parse(sse.data, event) + yield proto_utils.FromProto.stream_response(event) + except SSEError as e: + raise A2AClientHTTPError( + 400, + f'Invalid SSE response or protocol error: {e}', + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_post_request( + self, + target: str, + rpc_request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Sends a non-streaming JSON-RPC request to the agent. + + Args: + target: url path + rpc_request_payload: JSON payload for sending the request. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + + Returns: + The JSON response payload as a dictionary. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON. + """ + try: + response = await self.httpx_client.post( + f'{self.url}{target}', + json=rpc_request_payload, + **(http_kwargs or {}), + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def _send_get_request( + self, + target: str, + query_params: dict[str, str], + http_kwargs: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Sends a non-streaming JSON-RPC request to the agent. + + Args: + target: url path + query_params: HTTP query params for the request. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + + Returns: + The JSON response payload as a dictionary. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON. + """ + try: + response = await self.httpx_client.get( + f'{self.url}{target}', + params=query_params, + **(http_kwargs or {}), + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise A2AClientHTTPError(e.response.status_code, str(e)) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_task( + self, + request: TaskQueryParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> Task: + """Retrieves the current state and history of a specific task. + + Args: + request: The `TaskQueryParams` object specifying the task ID and history length. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `Task` object containing the Task. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + # Apply interceptors before sending - only for the http kwargs + payload, modified_kwargs = await self._apply_interceptors( + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_get_request( + f'/v1/tasks/{request.taskId}', + {'historyLength': request.historyLength} + if request.historyLength + else {}, + modified_kwargs, + ) + task = a2a_pb2.Task() + Parse(response_data, task) + return proto_utils.FromProto.task(task) + + async def cancel_task( + self, + request: TaskIdParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> Task: + """Requests the agent to cancel a specific task. + + Args: + request: The `TaskIdParams` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `Task` object containing the updated Task with canceled status + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, + http_kwargs, + context, + ) + response_data = await self._send_post_request( + f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs + ) + task = a2a_pb2.Task() + Parse(response_data, task) + return proto_utils.FromProto.task(task) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `TaskPushNotificationConfig` object specifying the task ID and configuration. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `TaskPushNotificationConfig` object containing the confirmation. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.CreateTaskPushNotificationConfigRequest( + parent=f'tasks/{request.taskId}', + config_id=request.pushNotificationConfig.id, + config=proto_utils.ToProto.push_notification_config( + request.pushNotificationConfig + ), + ) + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, http_kwargs, context + ) + response_data = await self._send_post_request( + f'/v1/tasks/{request.taskId}/pushNotificationConfigs/', + payload, + modified_kwargs, + ) + config = a2a_pb2.TaskPushNotificationConfig() + Parse(response_data, config) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `GetTaskPushNotificationConfigParams` object specifying the task ID. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `TaskPushNotificationConfig` object containing the configuration. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.GetTaskPushNotificationConfigRequest( + name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', + ) + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, + http_kwargs, + context, + ) + response_data = await self._send_get_request( + f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', + {}, + modified_kwargs, + ) + config = a2a_pb2.TaskPushNotificationConfig() + Parse(response_data, config) + return proto_utils.FromProto.task_push_notification_config(config) + + async def resubscribe( + self, + request: TaskIdParams, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AsyncGenerator[ + Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message + ]: + """Reconnects to get task updates + + This method uses Server-Sent Events (SSE) to receive a stream of updates from the agent. + + Args: + request: The `TaskIdParams` object containing the task information to reconnect to. + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. + + Yields: + Objects as they are received in the SSE stream. + These can be Task, Message, TaskStatusUpdateEvent, or TaskArtifactUpdateEvent. + + Raises: + A2AClientHTTPError: If an HTTP or SSE protocol error occurs during the request. + A2AClientJSONError: If an SSE event data cannot be decoded as JSON or validated. + """ + pb = a2a_pb2.TaskSubscriptionRequest( + name=f'tasks/{request.id}', + ) + payload = MessageToDict(pb) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + payload, + http_kwargs, + context, + ) + + modified_kwargs.setdefault('timeout', None) + + async with aconnect_sse( + self.httpx_client, + 'POST', + f'{self.url}/v1/tasks/{request.id}:subscribe', + json=payload, + **modified_kwargs, + ) as event_source: + try: + async for sse in event_source.aiter_sse(): + event = a2a_pb2.StreamResponse() + Parse(sse.data, event) + yield proto_utils.FromProto.stream_response(event) + except SSEError as e: + raise A2AClientHTTPError( + 400, + f'Invalid SSE response or protocol error: {e}', + ) from e + except json.JSONDecodeError as e: + raise A2AClientJSONError(str(e)) from e + except httpx.RequestError as e: + raise A2AClientHTTPError( + 503, f'Network communication error: {e}' + ) from e + + async def get_card( + self, + *, + http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, + ) -> AgentCard: + """Retrieves the authenticated card (if necessary) or the public one. + + Args: + http_kwargs: Optional dictionary of keyword arguments to pass to the + underlying httpx.post request. + context: The client call context. + + Returns: + A `AgentCard` object containing the card or an error. + + Raises: + A2AClientHTTPError: If an HTTP error occurs during the request. + A2AClientJSONError: If the response body cannot be decoded as JSON or validated. + """ + # If we don't have the public card, try to get that first. + card = self.agent_card + if not card: + resolver = A2ACardResolver(self.httpx_client, self.url) + card = await resolver.get_agent_card(http_kwargs=http_kwargs) + self._needs_extended_card = card.supportsAuthenticatedExtendedCard + self.agent_card = card + + if not self._needs_extended_card: + return card + + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + '', + http_kwargs, + context, + ) + response_data = await self._send_get_request( + '/v1/card/get', {}, modified_kwargs + ) + card = AgentCard.model_validate(response_data) + self.agent_card = card + self._needs_extended_card = False + return card + + +@trace_class(kind=SpanKind.CLIENT) +class RestClient(Client): + """RestClient is the implementation of the RESTful A2A client. + + This client proxies requests to the RestTransportClient implementation + and manages the REST specific details. If passing additional arguments + in the http.post command, these should be attached to the ClientCallContext + under the dictionary key 'http_kwargs'. + """ + + def __init__( + self, + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], + ): + super().__init__(consumers, middleware) + if not config.httpx_client: + raise Exception('JsonRpc client requires httpx client.') + self._card = card + url = card.url + self._config = config + self._transport_client = RestTransportClient( + config.httpx_client, self._card, url, middleware + ) + + def get_http_args( + self, context: ClientCallContext + ) -> dict[str, Any] | None: + return context.state.get('http_kwargs', None) if context else None + + async def send_message( + self, + request: Message, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + config = MessageSendConfiguration( + accepted_output_modes=self._config.accepted_output_modes, + blocking=not self._config.polling, + push_notification_config=( + self._config.push_notification_configs[0] + if self._config.push_notification_configs + else None + ), + ) + if not self._config.streaming or not self._card.capabilities.streaming: + response = await self._transport_client.send_message( + MessageSendParams( + message=request, + configuration=config, + ), + http_kwargs=self.get_http_args(context), + context=context, + ) + result = ( + response if isinstance(response, Message) else (response, None) + ) + await self.consume(result, self._card) + yield result + return + tracker = ClientTaskManager() + async for event in self._transport_client.send_message_streaming( + MessageSendParams( + message=request, + configuration=config, + ), + http_kwargs=self.get_http_args(context), + context=context, + ): + # Update task, check for errors, etc. + if isinstance(event, Message): + yield event + return + await tracker.process(event) + result = ( + tracker.get_task(), + None if isinstance(event, Task) else event, + ) + await self.consume(result, self._card) + yield result + + async def get_task( + self, + request: TaskQueryParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.get_task( + request, + http_kwargs=self.get_http_args(context), + context=context, + ) + return response + + async def cancel_task( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> Task: + response = await self._transport_client.cancel_task( + request, + http_kwargs=self.get_http_args(context), + context=context, + ) + return response + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.set_task_callback( + request, + http_kwargs=self.get_http_args(context), + context=context, + ) + return response + + async def get_task_callback( + self, + request: GetTaskPushNotificationConfigParams, + *, + context: ClientCallContext | None = None, + ) -> TaskPushNotificationConfig: + response = await self._transport_client.get_task_callback( + request, + http_kwargs=self.get_http_args(context), + context=context, + ) + return response + + async def resubscribe( + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + ) -> AsyncIterator[Task | Message]: + if not self._config.streaming or not self._card.capabilities.streaming: + raise Exception( + 'client and/or server do not support resubscription.' + ) + async for event in self._transport_client.resubscribe( + request, + http_kwargs=self.get_http_args(context), + context=context, + ): + # Update task, check for errors, etc. + yield event + + async def get_card( + self, + *, + context: ClientCallContext | None = None, + ) -> AgentCard: + return await self._transport_client.get_card( + http_kwargs=self.get_http_args(context), + context=context, + ) + + +def NewRestfulClient( + card: AgentCard, + config: ClientConfig, + consumers: list[Consumer], + middleware: list[ClientCallInterceptor], +) -> Client: + """Generator for the `RestClient` implementation.""" + return RestClient(card, config, consumers, middleware) diff --git a/src/a2a/server/apps/rest/rest_app.py b/src/a2a/server/apps/rest/rest_app.py index 717c6e9f..fa9076db 100644 --- a/src/a2a/server/apps/rest/rest_app.py +++ b/src/a2a/server/apps/rest/rest_app.py @@ -130,6 +130,12 @@ async def event_generator( async for item in stream: yield {'data': item} + return EventSourceResponse( + event_generator(method(request, call_context)) + ) + except Exception: + # Since the stream has started, we can't return a JSONResponse. + # Instead, we run the error handling logic (provides logging) return EventSourceResponse( event_generator(method(request, call_context)) ) @@ -180,49 +186,43 @@ async def handle_authenticated_agent_card( def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: routes = { - ('/v1/message:send', 'POST'): ( - functools.partial( - self._handle_request, self.handler.on_message_send - ), + ('/v1/message:send', 'POST'): functools.partial( + self._handle_request, self.handler.on_message_send ), - ('/v1/message:stream', 'POST'): ( - functools.partial( - self._handle_streaming_request, - self.handler.on_message_send_stream, - ), + ('/v1/message:stream', 'POST'): functools.partial( + self._handle_streaming_request, + self.handler.on_message_send_stream, ), - ('/v1/tasks/{id}:subscribe', 'POST'): ( - functools.partial( - self._handle_streaming_request, - self.handler.on_resubscribe_to_task, - ), + ('/v1/tasks/{id}:subscribe', 'POST'): functools.partial( + self._handle_streaming_request, + self.handler.on_resubscribe_to_task, ), - ('/v1/tasks/{id}', 'GET'): ( - functools.partial( - self._handle_request, self.handler.on_get_task - ), + ('/v1/tasks/{id}', 'GET'): functools.partial( + self._handle_request, self.handler.on_get_task ), - ('/v1/tasks/{id}/pushNotificationConfigs/{push_id}', 'GET'): ( - functools.partial( - self._handle_request, self.handler.get_push_notification - ), + ( + '/v1/tasks/{id}/pushNotificationConfigs/{push_id}', + 'GET', + ): functools.partial( + self._handle_request, self.handler.get_push_notification ), - ('/v1/tasks/{id}/pushNotificationConfigs', 'POST'): ( - functools.partial( - self._handle_request, self.handler.set_push_notification - ), + ( + '/v1/tasks/{id}/pushNotificationConfigs', + 'POST', + ): functools.partial( + self._handle_request, self.handler.set_push_notification ), - ('/v1/tasks/{id}/pushNotificationConfigs', 'GET'): ( - functools.partial( - self._handle_request, self.handler.list_push_notifications - ), + ( + '/v1/tasks/{id}/pushNotificationConfigs', + 'GET', + ): functools.partial( + self._handle_request, self.handler.list_push_notifications ), - ('/v1/tasks', 'GET'): ( - functools.partial( - self._handle_request, self.handler.list_tasks - ), + ('/v1/tasks', 'GET'): functools.partial( + self._handle_request, self.handler.list_tasks ), } if self.agent_card.supportsAuthenticatedExtendedCard: - routes['/v1/card'] = (self.handle_authenticated_agent_card, 'GET') + routes[('/v1/card', 'GET')] = self.handle_authenticated_agent_card + return routes diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 930318d2..b226c710 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -215,7 +215,7 @@ async def get_push_notification( id=task_id, push_id=push_id ) else: - params = TaskIdParams['id'] + params = TaskIdParams(id=task_id) config = ( await self.request_handler.on_get_task_push_notification_config( params, context @@ -257,11 +257,9 @@ async def set_push_notification( body = await request.body() params = a2a_pb2.TaskPushNotificationConfig() Parse(body, params) - params = TaskPushNotificationConfig.validate_model(body) - a2a_request = ( - proto_utils.FromProto.task_push_notification_config( - params, - ), + params = TaskPushNotificationConfig.model_validate(body) + a2a_request = proto_utils.FromProto.task_push_notification_config( + params, ) config = ( await self.request_handler.on_set_task_push_notification_config( @@ -293,10 +291,10 @@ async def on_get_task( """ try: task_id = request.path_params['id'] - historyLength = None - if 'historyLength' in request.query_params: - historyLength = request.query_params['historyLength'] - params = TaskQueryParams(id=task_id, historyLength=historyLength) + history_length = request.query_params.get('historyLength', None) + if historyLength: + history_length = int(history_length) + params = TaskQueryParams(id=task_id, history_length=history_length) task = await self.request_handler.on_get_task(params, context) if task: return MessageToJson(proto_utils.ToProto.task(task)) diff --git a/src/a2a/utils/__init__.py b/src/a2a/utils/__init__.py index 06ac1123..f47881a0 100644 --- a/src/a2a/utils/__init__.py +++ b/src/a2a/utils/__init__.py @@ -28,12 +28,14 @@ completed_task, new_task, ) +from a2a.utils.transports import Transports __all__ = [ 'AGENT_CARD_WELL_KNOWN_PATH', 'DEFAULT_RPC_URL', 'EXTENDED_AGENT_CARD_PATH', + 'Transports', 'append_artifact_to_task', 'are_modalities_compatible', 'build_text_artifact', diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index ddaa4f9e..933968c8 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -286,6 +286,23 @@ def agent_card( supports_authenticated_extended_card=bool( card.supports_authenticated_extended_card ), + preferred_transport=card.preferred_transport, + protocol_version=card.protocol_version, + additional_interfaces=[ + cls.agent_interface(x) for x in card.additional_interfaces + ] + if card.additional_interfaces + else None, + ) + + @classmethod + def agent_interface( + cls, + interface: types.AgentInterface, + ) -> a2a_pb2.AgentInterface: + return a2a_pb2.AgentInterface( + transport=interface.transport, + url=interface.url, ) @classmethod @@ -663,6 +680,23 @@ def agent_card( url=card.url, version=card.version, supports_authenticated_extended_card=card.supports_authenticated_extended_card, + preferred_transport=card.preferred_transport, + protocol_version=card.protocol_version, + additional_interfaces=[ + cls.agent_interface(x) for x in card.additional_interfaces + ] + if card.additional_interfaces + else None, + ) + + @classmethod + def agent_interface( + cls, + interface: a2a_pb2.AgentInterface, + ) -> types.AgentInterface: + return types.AgentInterface( + transport=interface.transport, + url=interface.url, ) @classmethod @@ -793,6 +827,24 @@ def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: ), ) + @classmethod + def stream_response( + cls, + response: a2a_pb2.StreamResponse, + ) -> ( + types.Message + | types.Task + | types.TaskStatusUpdateEvent + | types.TaskArtifactUpdateEvent + ): + if response.HasField('msg'): + return cls.message(response.msg) + if response.HasField('task'): + return cls.task(response.task) + if response.HasField('status_update'): + return cls.task_status_update_event(response.status_update) + return cls.task_artifact_update_event(response.artifact_update) + @classmethod def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill: return types.AgentSkill( diff --git a/src/a2a/utils/transports.py b/src/a2a/utils/transports.py new file mode 100644 index 00000000..50f8aa07 --- /dev/null +++ b/src/a2a/utils/transports.py @@ -0,0 +1,9 @@ +"""Defines standard protocol transport labels.""" + +from enum import Enum + + +class Transports(str, Enum): + GRPC = 'GRPC' + JSONRPC = 'JSONRPC' + RESTful = 'HTTP+JSON'