diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index d7ef1e3c..37a32afd 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -2,10 +2,12 @@ ACard AClient AError AFast +AGrpc ARequest ARun AServer AServers +AService AStarlette EUR GBP @@ -13,9 +15,11 @@ INR JPY JSONRPCt Llm +RUF aconnect adk agentic +aio autouse cla cls @@ -34,6 +38,7 @@ linting oauthoidc opensource protoc +pyi pyversions socio sse diff --git a/.github/actions/spelling/excludes.txt b/.github/actions/spelling/excludes.txt index dbbff998..d4c4eef1 100644 --- a/.github/actions/spelling/excludes.txt +++ b/.github/actions/spelling/excludes.txt @@ -85,7 +85,6 @@ \.zip$ ^\.github/actions/spelling/ ^\.github/workflows/ -^\Qsrc/a2a/auth/__init__.py\E$ -^\Qsrc/a2a/server/request_handlers/context.py\E$ CHANGELOG.md noxfile.py +^src/a2a/grpc/ diff --git a/.github/linters/.jscpd.json b/.github/linters/.jscpd.json index 5e86d6d8..5a6fcad7 100644 --- a/.github/linters/.jscpd.json +++ b/.github/linters/.jscpd.json @@ -1,5 +1,5 @@ { - "ignore": ["**/.github/**", "**/.git/**", "**/tests/**"], + "ignore": ["**/.github/**", "**/.git/**", "**/tests/**", "**/src/a2a/grpc/**", "**/.nox/**", "**/.venv/**"], "threshold": 3, "reporters": ["html", "markdown"] } diff --git a/.github/linters/.ruff.toml b/.github/linters/.ruff.toml index 29c4ff20..34dbfa2b 100644 --- a/.github/linters/.ruff.toml +++ b/.github/linters/.ruff.toml @@ -82,6 +82,7 @@ exclude = [ "venv", "*/migrations/*", "noxfile.py", + "src/a2a/grpc/**", ] [lint.isort] @@ -137,9 +138,14 @@ inline-quotes = "single" "SLF001", ] "types.py" = ["D", "E501", "N815"] # Ignore docstring and annotation issues in types.py +"proto_utils.py" = ["D102", "PLR0911"] +"helpers.py" = ["ANN001", "ANN201", "ANN202"] [format] -exclude = ["types.py"] +exclude = [ + "types.py", + "src/a2a/grpc/**", +] docstring-code-format = true docstring-code-line-length = "dynamic" # Or set to 80 quote-style = "single" diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index a657c5a3..890e81ae 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -64,4 +64,5 @@ jobs: VALIDATE_GIT_COMMITLINT: false PYTHON_MYPY_CONFIG_FILE: .mypy.ini FILTER_REGEX_INCLUDE: ".*src/**/*" + FILTER_REGEX_EXCLUDE: ".*src/a2a/grpc/**/*" PYTHON_RUFF_CONFIG_FILE: .ruff.toml diff --git a/buf.gen.yaml b/buf.gen.yaml index 7102471e..e5e18e65 100644 --- a/buf.gen.yaml +++ b/buf.gen.yaml @@ -19,9 +19,12 @@ managed: plugins: # Generate python protobuf related code # Generates *_pb2.py files, one for each .proto - - remote: buf.build/protocolbuffers/python + - remote: buf.build/protocolbuffers/python:v29.3 out: src/a2a/grpc # Generate python service code. # Generates *_pb2_grpc.py - remote: buf.build/grpc/python out: src/a2a/grpc + # Generates *_pb2.pyi files. + - remote: buf.build/protocolbuffers/pyi:v29.3 + out: src/a2a/grpc diff --git a/noxfile.py b/noxfile.py index 380d2a28..60dd15ef 100644 --- a/noxfile.py +++ b/noxfile.py @@ -103,7 +103,9 @@ def format(session) -> None: } ) - lint_paths_py = [f for f in changed_files if f.endswith('.py')] + lint_paths_py = [ + f for f in changed_files if f.endswith('.py') and 'grpc/' not in f + ] if not lint_paths_py: session.log('No changed Python files to lint.') @@ -111,6 +113,7 @@ def format(session) -> None: session.install( 'types-requests', + 'types-protobuf', 'pyupgrade', 'autoflake', 'ruff', diff --git a/pyproject.toml b/pyproject.toml index 991fc8df..c1e7f4db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,11 +11,15 @@ dependencies = [ "fastapi>=0.115.12", "httpx>=0.28.1", "httpx-sse>=0.4.0", + "google-api-core>=1.26.0", "opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0", "pydantic>=2.11.3", "sse-starlette>=2.3.3", "starlette>=0.46.2", + "grpcio>=1.60", + "grpcio-tools>=1.60", + "grpcio_reflection>=1.7.0", ] classifiers = [ diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 3455c867..e91c9eb7 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -6,6 +6,7 @@ A2AClientHTTPError, A2AClientJSONError, ) +from a2a.client.grpc_client import A2AGrpcClient from a2a.client.helpers import create_text_message_object @@ -15,5 +16,6 @@ 'A2AClientError', 'A2AClientHTTPError', 'A2AClientJSONError', + 'A2AGrpcClient', 'create_text_message_object', ] diff --git a/src/a2a/client/grpc_client.py b/src/a2a/client/grpc_client.py new file mode 100644 index 00000000..d9f14b7f --- /dev/null +++ b/src/a2a/client/grpc_client.py @@ -0,0 +1,190 @@ +import logging + +from collections.abc import AsyncGenerator + +import grpc + +from a2a.grpc import a2a_pb2, a2a_pb2_grpc +from a2a.types import ( + AgentCard, + Message, + MessageSendParams, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskStatusUpdateEvent, +) +from a2a.utils import proto_utils +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.CLIENT) +class A2AGrpcClient: + """A2A Client for interacting with an A2A agent via gRPC.""" + + def __init__( + self, + grpc_stub: a2a_pb2_grpc.A2AServiceStub, + agent_card: AgentCard, + ): + """Initializes the A2AGrpcClient. + + Requires an `AgentCard` + + Args: + grpc_stub: A grpc client stub. + agent_card: The agent card object. + """ + self.agent_card = agent_card + self.stub = grpc_stub + + async def send_message( + self, + request: MessageSendParams, + ) -> Task | Message: + """Sends a non-streaming message request to the agent. + + Args: + request: The `MessageSendParams` object containing the message and configuration. + + Returns: + A `Task` or `Message` object containing the agent's response. + """ + response = await self.stub.SendMessage( + a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=proto_utils.ToProto.metadata(request.metadata), + ) + ) + if response.task: + return proto_utils.FromProto.task(response.task) + return proto_utils.FromProto.message(response.msg) + + async def send_message_streaming( + self, + request: MessageSendParams, + ) -> AsyncGenerator[ + Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent + ]: + """Sends a streaming message request to the agent and yields responses as they arrive. + + This method uses gRPC streams to receive a stream of updates from the + agent. + + Args: + request: The `MessageSendParams` object containing the message and configuration. + + Yields: + `Message` or `Task` or `TaskStatusUpdateEvent` or + `TaskArtifactUpdateEvent` objects as they are received in the + stream. + """ + stream = self.stub.SendStreamingMessage( + a2a_pb2.SendMessageRequest( + request=proto_utils.ToProto.message(request.message), + configuration=proto_utils.ToProto.message_send_configuration( + request.configuration + ), + metadata=proto_utils.ToProto.metadata(request.metadata), + ) + ) + while True: + response = await stream.read() + if response == grpc.aio.EOF: + break + if response.HasField('msg'): + yield proto_utils.FromProto.message(response.msg) + elif response.HasField('task'): + yield proto_utils.FromProto.task(response.task) + elif response.HasField('status_update'): + yield proto_utils.FromProto.task_status_update_event( + response.status_update + ) + elif response.HasField('artifact_update'): + yield proto_utils.FromProto.task_artifact_update_event( + response.artifact_update + ) + + async def get_task( + self, + request: TaskQueryParams, + ) -> Task: + """Retrieves the current state and history of a specific task. + + Args: + request: The `TaskQueryParams` object specifying the task ID + + Returns: + A `Task` object containing the Task or None. + """ + task = await self.stub.GetTask( + a2a_pb2.GetTaskRequest(name=f'tasks/{request.id}') + ) + return proto_utils.FromProto.task(task) + + async def cancel_task( + self, + request: TaskIdParams, + ) -> Task: + """Requests the agent to cancel a specific task. + + Args: + request: The `TaskIdParams` object specifying the task ID. + + Returns: + A `Task` object containing the updated Task + """ + task = await self.stub.CancelTask( + a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + ) + return proto_utils.FromProto.task(task) + + async def set_task_callback( + self, + request: TaskPushNotificationConfig, + ) -> TaskPushNotificationConfig: + """Sets or updates the push notification configuration for a specific task. + + Args: + request: The `TaskPushNotificationConfig` object specifying the task ID and configuration. + + Returns: + A `TaskPushNotificationConfig` object containing the config. + """ + config = await self.stub.CreateTaskPushNotification( + a2a_pb2.CreateTaskPushNotificationRequest( + parent='', + config_id='', + config=proto_utils.ToProto.task_push_notification_config( + request + ), + ) + ) + return proto_utils.FromProto.task_push_notification_config(config) + + async def get_task_callback( + self, + request: TaskIdParams, # TODO: Update to a push id params + ) -> TaskPushNotificationConfig: + """Retrieves the push notification configuration for a specific task. + + Args: + request: The `TaskIdParams` object specifying the task ID. + + Returns: + A `TaskPushNotificationConfig` object containing the configuration. + """ + config = await self.stub.GetTaskPushNotification( + a2a_pb2.GetTaskPushNotificationRequest( + name=f'tasks/{request.id}/pushNotification/undefined', + ) + ) + return proto_utils.FromProto.task_push_notification_config(config) diff --git a/src/a2a/grpc/__init__.py b/src/a2a/grpc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/a2a/grpc/a2a_pb2.py b/src/a2a/grpc/a2a_pb2.py new file mode 100644 index 00000000..81078b8b --- /dev/null +++ b/src/a2a/grpc/a2a_pb2.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: a2a.proto +# Protobuf Python Version: 5.29.3 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 3, + '', + 'a2a.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.api import annotations_pb2 as google_dot_api_dot_annotations__pb2 +from google.api import client_pb2 as google_dot_api_dot_client__pb2 +from google.api import field_behavior_pb2 as google_dot_api_dot_field__behavior__pb2 +from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ta2a.proto\x12\x06\x61\x32\x61.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17google/api/client.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xde\x01\n\x18SendMessageConfiguration\x12\x32\n\x15\x61\x63\x63\x65pted_output_modes\x18\x01 \x03(\tR\x13\x61\x63\x63\x65ptedOutputModes\x12K\n\x11push_notification\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x10pushNotification\x12%\n\x0ehistory_length\x18\x03 \x01(\x05R\rhistoryLength\x12\x1a\n\x08\x62locking\x18\x04 \x01(\x08R\x08\x62locking\"\xf1\x01\n\x04Task\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12.\n\tartifacts\x18\x04 \x03(\x0b\x32\x10.a2a.v1.ArtifactR\tartifacts\x12)\n\x07history\x18\x05 \x03(\x0b\x32\x0f.a2a.v1.MessageR\x07history\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x98\x01\n\nTaskStatus\x12\'\n\x05state\x18\x01 \x01(\x0e\x32\x11.a2a.v1.TaskStateR\x05state\x12\'\n\x06update\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageR\x06update\x12\x38\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\"t\n\x04Part\x12\x14\n\x04text\x18\x01 \x01(\tH\x00R\x04text\x12&\n\x04\x66ile\x18\x02 \x01(\x0b\x32\x10.a2a.v1.FilePartH\x00R\x04\x66ile\x12&\n\x04\x64\x61ta\x18\x03 \x01(\x0b\x32\x10.a2a.v1.DataPartH\x00R\x04\x64\x61taB\x06\n\x04part\"\x7f\n\x08\x46ilePart\x12$\n\rfile_with_uri\x18\x01 \x01(\tH\x00R\x0b\x66ileWithUri\x12(\n\x0f\x66ile_with_bytes\x18\x02 \x01(\x0cH\x00R\rfileWithBytes\x12\x1b\n\tmime_type\x18\x03 \x01(\tR\x08mimeTypeB\x06\n\x04\x66ile\"7\n\x08\x44\x61taPart\x12+\n\x04\x64\x61ta\x18\x01 \x01(\x0b\x32\x17.google.protobuf.StructR\x04\x64\x61ta\"\xff\x01\n\x07Message\x12\x1d\n\nmessage_id\x18\x01 \x01(\tR\tmessageId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12\x17\n\x07task_id\x18\x03 \x01(\tR\x06taskId\x12 \n\x04role\x18\x04 \x01(\x0e\x32\x0c.a2a.v1.RoleR\x04role\x12&\n\x07\x63ontent\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x07\x63ontent\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xda\x01\n\x08\x41rtifact\x12\x1f\n\x0b\x61rtifact_id\x18\x01 \x01(\tR\nartifactId\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x04 \x01(\tR\x0b\x64\x65scription\x12\"\n\x05parts\x18\x05 \x03(\x0b\x32\x0c.a2a.v1.PartR\x05parts\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\x12\x1e\n\nextensions\x18\x07 \x03(\tR\nextensions\"\xc6\x01\n\x15TaskStatusUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12*\n\x06status\x18\x03 \x01(\x0b\x32\x12.a2a.v1.TaskStatusR\x06status\x12\x14\n\x05\x66inal\x18\x04 \x01(\x08R\x05\x66inal\x12\x33\n\x08metadata\x18\x05 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\xeb\x01\n\x17TaskArtifactUpdateEvent\x12\x17\n\x07task_id\x18\x01 \x01(\tR\x06taskId\x12\x1d\n\ncontext_id\x18\x02 \x01(\tR\tcontextId\x12,\n\x08\x61rtifact\x18\x03 \x01(\x0b\x32\x10.a2a.v1.ArtifactR\x08\x61rtifact\x12\x16\n\x06\x61ppend\x18\x04 \x01(\x08R\x06\x61ppend\x12\x1d\n\nlast_chunk\x18\x05 \x01(\x08R\tlastChunk\x12\x33\n\x08metadata\x18\x06 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"\x94\x01\n\x16PushNotificationConfig\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x10\n\x03url\x18\x02 \x01(\tR\x03url\x12\x14\n\x05token\x18\x03 \x01(\tR\x05token\x12\x42\n\x0e\x61uthentication\x18\x04 \x01(\x0b\x32\x1a.a2a.v1.AuthenticationInfoR\x0e\x61uthentication\"P\n\x12\x41uthenticationInfo\x12\x18\n\x07schemes\x18\x01 \x03(\tR\x07schemes\x12 \n\x0b\x63redentials\x18\x02 \x01(\tR\x0b\x63redentials\"\xc8\x05\n\tAgentCard\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x10\n\x03url\x18\x03 \x01(\tR\x03url\x12\x31\n\x08provider\x18\x04 \x01(\x0b\x32\x15.a2a.v1.AgentProviderR\x08provider\x12\x18\n\x07version\x18\x05 \x01(\tR\x07version\x12+\n\x11\x64ocumentation_url\x18\x06 \x01(\tR\x10\x64ocumentationUrl\x12=\n\x0c\x63\x61pabilities\x18\x07 \x01(\x0b\x32\x19.a2a.v1.AgentCapabilitiesR\x0c\x63\x61pabilities\x12Q\n\x10security_schemes\x18\x08 \x03(\x0b\x32&.a2a.v1.AgentCard.SecuritySchemesEntryR\x0fsecuritySchemes\x12,\n\x08security\x18\t \x03(\x0b\x32\x10.a2a.v1.SecurityR\x08security\x12.\n\x13\x64\x65\x66\x61ult_input_modes\x18\n \x03(\tR\x11\x64\x65\x66\x61ultInputModes\x12\x30\n\x14\x64\x65\x66\x61ult_output_modes\x18\x0b \x03(\tR\x12\x64\x65\x66\x61ultOutputModes\x12*\n\x06skills\x18\x0c \x03(\x0b\x32\x12.a2a.v1.AgentSkillR\x06skills\x12O\n$supports_authenticated_extended_card\x18\r \x01(\x08R!supportsAuthenticatedExtendedCard\x1aZ\n\x14SecuritySchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x16.a2a.v1.SecuritySchemeR\x05value:\x02\x38\x01\"E\n\rAgentProvider\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\"\n\x0corganization\x18\x02 \x01(\tR\x0corganization\"\x98\x01\n\x11\x41gentCapabilities\x12\x1c\n\tstreaming\x18\x01 \x01(\x08R\tstreaming\x12-\n\x12push_notifications\x18\x02 \x01(\x08R\x11pushNotifications\x12\x36\n\nextensions\x18\x03 \x03(\x0b\x32\x16.a2a.v1.AgentExtensionR\nextensions\"\x91\x01\n\x0e\x41gentExtension\x12\x10\n\x03uri\x18\x01 \x01(\tR\x03uri\x12 \n\x0b\x64\x65scription\x18\x02 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08required\x18\x03 \x01(\x08R\x08required\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.google.protobuf.StructR\x06params\"\xc6\x01\n\nAgentSkill\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12 \n\x0b\x64\x65scription\x18\x03 \x01(\tR\x0b\x64\x65scription\x12\x12\n\x04tags\x18\x04 \x03(\tR\x04tags\x12\x1a\n\x08\x65xamples\x18\x05 \x03(\tR\x08\x65xamples\x12\x1f\n\x0binput_modes\x18\x06 \x03(\tR\ninputModes\x12!\n\x0coutput_modes\x18\x07 \x03(\tR\x0boutputModes\"\x8a\x01\n\x1aTaskPushNotificationConfig\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12X\n\x18push_notification_config\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.PushNotificationConfigR\x16pushNotificationConfig\" \n\nStringList\x12\x12\n\x04list\x18\x01 \x03(\tR\x04list\"\x93\x01\n\x08Security\x12\x37\n\x07schemes\x18\x01 \x03(\x0b\x32\x1d.a2a.v1.Security.SchemesEntryR\x07schemes\x1aN\n\x0cSchemesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x12.a2a.v1.StringListR\x05value:\x02\x38\x01\"\x91\x03\n\x0eSecurityScheme\x12U\n\x17\x61pi_key_security_scheme\x18\x01 \x01(\x0b\x32\x1c.a2a.v1.APIKeySecuritySchemeH\x00R\x14\x61piKeySecurityScheme\x12[\n\x19http_auth_security_scheme\x18\x02 \x01(\x0b\x32\x1e.a2a.v1.HTTPAuthSecuritySchemeH\x00R\x16httpAuthSecurityScheme\x12T\n\x16oauth2_security_scheme\x18\x03 \x01(\x0b\x32\x1c.a2a.v1.OAuth2SecuritySchemeH\x00R\x14oauth2SecurityScheme\x12k\n\x1fopen_id_connect_security_scheme\x18\x04 \x01(\x0b\x32#.a2a.v1.OpenIdConnectSecuritySchemeH\x00R\x1bopenIdConnectSecuritySchemeB\x08\n\x06scheme\"h\n\x14\x41PIKeySecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x1a\n\x08location\x18\x02 \x01(\tR\x08location\x12\x12\n\x04name\x18\x03 \x01(\tR\x04name\"w\n\x16HTTPAuthSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12\x16\n\x06scheme\x18\x02 \x01(\tR\x06scheme\x12#\n\rbearer_format\x18\x03 \x01(\tR\x0c\x62\x65\x61rerFormat\"b\n\x14OAuth2SecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12(\n\x05\x66lows\x18\x02 \x01(\x0b\x32\x12.a2a.v1.OAuthFlowsR\x05\x66lows\"n\n\x1bOpenIdConnectSecurityScheme\x12 \n\x0b\x64\x65scription\x18\x01 \x01(\tR\x0b\x64\x65scription\x12-\n\x13open_id_connect_url\x18\x02 \x01(\tR\x10openIdConnectUrl\"\xb0\x02\n\nOAuthFlows\x12S\n\x12\x61uthorization_code\x18\x01 \x01(\x0b\x32\".a2a.v1.AuthorizationCodeOAuthFlowH\x00R\x11\x61uthorizationCode\x12S\n\x12\x63lient_credentials\x18\x02 \x01(\x0b\x32\".a2a.v1.ClientCredentialsOAuthFlowH\x00R\x11\x63lientCredentials\x12\x37\n\x08implicit\x18\x03 \x01(\x0b\x32\x19.a2a.v1.ImplicitOAuthFlowH\x00R\x08implicit\x12\x37\n\x08password\x18\x04 \x01(\x0b\x32\x19.a2a.v1.PasswordOAuthFlowH\x00R\x08passwordB\x06\n\x04\x66low\"\x8a\x02\n\x1a\x41uthorizationCodeOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1b\n\ttoken_url\x18\x02 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x03 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x04 \x03(\x0b\x32..a2a.v1.AuthorizationCodeOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdd\x01\n\x1a\x43lientCredentialsOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12\x46\n\x06scopes\x18\x03 \x03(\x0b\x32..a2a.v1.ClientCredentialsOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xdb\x01\n\x11ImplicitOAuthFlow\x12+\n\x11\x61uthorization_url\x18\x01 \x01(\tR\x10\x61uthorizationUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.ImplicitOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xcb\x01\n\x11PasswordOAuthFlow\x12\x1b\n\ttoken_url\x18\x01 \x01(\tR\x08tokenUrl\x12\x1f\n\x0brefresh_url\x18\x02 \x01(\tR\nrefreshUrl\x12=\n\x06scopes\x18\x03 \x03(\x0b\x32%.a2a.v1.PasswordOAuthFlow.ScopesEntryR\x06scopes\x1a\x39\n\x0bScopesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\xc1\x01\n\x12SendMessageRequest\x12.\n\x07request\x18\x01 \x01(\x0b\x32\x0f.a2a.v1.MessageB\x03\xe0\x41\x02R\x07request\x12\x46\n\rconfiguration\x18\x02 \x01(\x0b\x32 .a2a.v1.SendMessageConfigurationR\rconfiguration\x12\x33\n\x08metadata\x18\x03 \x01(\x0b\x32\x17.google.protobuf.StructR\x08metadata\"P\n\x0eGetTaskRequest\x12\x17\n\x04name\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x04name\x12%\n\x0ehistory_length\x18\x02 \x01(\x05R\rhistoryLength\"\'\n\x11\x43\x61ncelTaskRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"4\n\x1eGetTaskPushNotificationRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"\xa3\x01\n!CreateTaskPushNotificationRequest\x12\x1b\n\x06parent\x18\x01 \x01(\tB\x03\xe0\x41\x02R\x06parent\x12 \n\tconfig_id\x18\x02 \x01(\tB\x03\xe0\x41\x02R\x08\x63onfigId\x12?\n\x06\x63onfig\x18\x03 \x01(\x0b\x32\".a2a.v1.TaskPushNotificationConfigB\x03\xe0\x41\x02R\x06\x63onfig\"-\n\x17TaskSubscriptionRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\"u\n\x1fListTaskPushNotificationRequest\x12\x16\n\x06parent\x18\x01 \x01(\tR\x06parent\x12\x1b\n\tpage_size\x18\x02 \x01(\x05R\x08pageSize\x12\x1d\n\npage_token\x18\x03 \x01(\tR\tpageToken\"\x15\n\x13GetAgentCardRequest\"i\n\x13SendMessageResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12#\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x03msgB\t\n\x07payload\"\xf6\x01\n\x0eStreamResponse\x12\"\n\x04task\x18\x01 \x01(\x0b\x32\x0c.a2a.v1.TaskH\x00R\x04task\x12#\n\x03msg\x18\x02 \x01(\x0b\x32\x0f.a2a.v1.MessageH\x00R\x03msg\x12\x44\n\rstatus_update\x18\x03 \x01(\x0b\x32\x1d.a2a.v1.TaskStatusUpdateEventH\x00R\x0cstatusUpdate\x12J\n\x0f\x61rtifact_update\x18\x04 \x01(\x0b\x32\x1f.a2a.v1.TaskArtifactUpdateEventH\x00R\x0e\x61rtifactUpdateB\t\n\x07payload\"\x88\x01\n ListTaskPushNotificationResponse\x12<\n\x07\x63onfigs\x18\x01 \x03(\x0b\x32\".a2a.v1.TaskPushNotificationConfigR\x07\x63onfigs\x12&\n\x0fnext_page_token\x18\x02 \x01(\tR\rnextPageToken*\xfa\x01\n\tTaskState\x12\x1a\n\x16TASK_STATE_UNSPECIFIED\x10\x00\x12\x18\n\x14TASK_STATE_SUBMITTED\x10\x01\x12\x16\n\x12TASK_STATE_WORKING\x10\x02\x12\x18\n\x14TASK_STATE_COMPLETED\x10\x03\x12\x15\n\x11TASK_STATE_FAILED\x10\x04\x12\x18\n\x14TASK_STATE_CANCELLED\x10\x05\x12\x1d\n\x19TASK_STATE_INPUT_REQUIRED\x10\x06\x12\x17\n\x13TASK_STATE_REJECTED\x10\x07\x12\x1c\n\x18TASK_STATE_AUTH_REQUIRED\x10\x08*;\n\x04Role\x12\x14\n\x10ROLE_UNSPECIFIED\x10\x00\x12\r\n\tROLE_USER\x10\x01\x12\x0e\n\nROLE_AGENT\x10\x02\x32\xd0\x08\n\nA2AService\x12\x64\n\x0bSendMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x1b.a2a.v1.SendMessageResponse\"\x1c\x82\xd3\xe4\x93\x02\x16\"\x11/v1//message:send:\x01*\x12k\n\x14SendStreamingMessage\x12\x1a.a2a.v1.SendMessageRequest\x1a\x16.a2a.v1.StreamResponse\"\x1d\x82\xd3\xe4\x93\x02\x17\"\x12/v1/message:stream:\x01*0\x01\x12R\n\x07GetTask\x12\x16.a2a.v1.GetTaskRequest\x1a\x0c.a2a.v1.Task\"!\xda\x41\x04name\x82\xd3\xe4\x93\x02\x14\x12\x12/v1/{name=tasks/*}\x12W\n\nCancelTask\x12\x19.a2a.v1.CancelTaskRequest\x1a\x0c.a2a.v1.Task\" \x82\xd3\xe4\x93\x02\x1a\"\x15/v1/tasks/{id}:cancel:\x01*\x12s\n\x10TaskSubscription\x12\x1f.a2a.v1.TaskSubscriptionRequest\x1a\x16.a2a.v1.StreamResponse\"$\x82\xd3\xe4\x93\x02\x1e\x12\x1c/v1/{name=tasks/*}:subscribe0\x01\x12\xb2\x01\n\x1a\x43reateTaskPushNotification\x12).a2a.v1.CreateTaskPushNotificationRequest\x1a\".a2a.v1.TaskPushNotificationConfig\"E\xda\x41\rparent,config\x82\xd3\xe4\x93\x02/\"%/v1/{parent=task/*/pushNotifications}:\x06\x63onfig\x12\x9c\x01\n\x17GetTaskPushNotification\x12&.a2a.v1.GetTaskPushNotificationRequest\x1a\".a2a.v1.TaskPushNotificationConfig\"5\xda\x41\x04name\x82\xd3\xe4\x93\x02(\x12&/v1/{name=tasks/*/pushNotifications/*}\x12\xa6\x01\n\x18ListTaskPushNotification\x12\'.a2a.v1.ListTaskPushNotificationRequest\x1a(.a2a.v1.ListTaskPushNotificationResponse\"7\xda\x41\x06parent\x82\xd3\xe4\x93\x02(\x12&/v1/{parent=tasks/*}/pushNotifications\x12P\n\x0cGetAgentCard\x12\x1b.a2a.v1.GetAgentCardRequest\x1a\x11.a2a.v1.AgentCard\"\x10\x82\xd3\xe4\x93\x02\n\x12\x08/v1/cardBi\n\ncom.a2a.v1B\x08\x41\x32\x61ProtoP\x01Z\x18google.golang.org/a2a/v1\xa2\x02\x03\x41XX\xaa\x02\x06\x41\x32\x61.V1\xca\x02\x06\x41\x32\x61\\V1\xe2\x02\x12\x41\x32\x61\\V1\\GPBMetadata\xea\x02\x07\x41\x32\x61::V1b\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'a2a_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\ncom.a2a.v1B\010A2aProtoP\001Z\030google.golang.org/a2a/v1\242\002\003AXX\252\002\006A2a.V1\312\002\006A2a\\V1\342\002\022A2a\\V1\\GPBMetadata\352\002\007A2a::V1' + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._loaded_options = None + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._serialized_options = b'8\001' + _globals['_SECURITY_SCHEMESENTRY']._loaded_options = None + _globals['_SECURITY_SCHEMESENTRY']._serialized_options = b'8\001' + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._loaded_options = None + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._serialized_options = b'8\001' + _globals['_SENDMESSAGEREQUEST'].fields_by_name['request']._loaded_options = None + _globals['_SENDMESSAGEREQUEST'].fields_by_name['request']._serialized_options = b'\340A\002' + _globals['_GETTASKREQUEST'].fields_by_name['name']._loaded_options = None + _globals['_GETTASKREQUEST'].fields_by_name['name']._serialized_options = b'\340A\002' + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['parent']._loaded_options = None + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['parent']._serialized_options = b'\340A\002' + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config_id']._loaded_options = None + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config_id']._serialized_options = b'\340A\002' + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config']._loaded_options = None + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST'].fields_by_name['config']._serialized_options = b'\340A\002' + _globals['_A2ASERVICE'].methods_by_name['SendMessage']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['SendMessage']._serialized_options = b'\202\323\344\223\002\026\"\021/v1//message:send:\001*' + _globals['_A2ASERVICE'].methods_by_name['SendStreamingMessage']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['SendStreamingMessage']._serialized_options = b'\202\323\344\223\002\027\"\022/v1/message:stream:\001*' + _globals['_A2ASERVICE'].methods_by_name['GetTask']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['GetTask']._serialized_options = b'\332A\004name\202\323\344\223\002\024\022\022/v1/{name=tasks/*}' + _globals['_A2ASERVICE'].methods_by_name['CancelTask']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['CancelTask']._serialized_options = b'\202\323\344\223\002\032\"\025/v1/tasks/{id}:cancel:\001*' + _globals['_A2ASERVICE'].methods_by_name['TaskSubscription']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['TaskSubscription']._serialized_options = b'\202\323\344\223\002\036\022\034/v1/{name=tasks/*}:subscribe' + _globals['_A2ASERVICE'].methods_by_name['CreateTaskPushNotification']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['CreateTaskPushNotification']._serialized_options = b'\332A\rparent,config\202\323\344\223\002/\"%/v1/{parent=task/*/pushNotifications}:\006config' + _globals['_A2ASERVICE'].methods_by_name['GetTaskPushNotification']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['GetTaskPushNotification']._serialized_options = b'\332A\004name\202\323\344\223\002(\022&/v1/{name=tasks/*/pushNotifications/*}' + _globals['_A2ASERVICE'].methods_by_name['ListTaskPushNotification']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['ListTaskPushNotification']._serialized_options = b'\332A\006parent\202\323\344\223\002(\022&/v1/{parent=tasks/*}/pushNotifications' + _globals['_A2ASERVICE'].methods_by_name['GetAgentCard']._loaded_options = None + _globals['_A2ASERVICE'].methods_by_name['GetAgentCard']._serialized_options = b'\202\323\344\223\002\n\022\010/v1/card' + _globals['_TASKSTATE']._serialized_start=7161 + _globals['_TASKSTATE']._serialized_end=7411 + _globals['_ROLE']._serialized_start=7413 + _globals['_ROLE']._serialized_end=7472 + _globals['_SENDMESSAGECONFIGURATION']._serialized_start=173 + _globals['_SENDMESSAGECONFIGURATION']._serialized_end=395 + _globals['_TASK']._serialized_start=398 + _globals['_TASK']._serialized_end=639 + _globals['_TASKSTATUS']._serialized_start=642 + _globals['_TASKSTATUS']._serialized_end=794 + _globals['_PART']._serialized_start=796 + _globals['_PART']._serialized_end=912 + _globals['_FILEPART']._serialized_start=914 + _globals['_FILEPART']._serialized_end=1041 + _globals['_DATAPART']._serialized_start=1043 + _globals['_DATAPART']._serialized_end=1098 + _globals['_MESSAGE']._serialized_start=1101 + _globals['_MESSAGE']._serialized_end=1356 + _globals['_ARTIFACT']._serialized_start=1359 + _globals['_ARTIFACT']._serialized_end=1577 + _globals['_TASKSTATUSUPDATEEVENT']._serialized_start=1580 + _globals['_TASKSTATUSUPDATEEVENT']._serialized_end=1778 + _globals['_TASKARTIFACTUPDATEEVENT']._serialized_start=1781 + _globals['_TASKARTIFACTUPDATEEVENT']._serialized_end=2016 + _globals['_PUSHNOTIFICATIONCONFIG']._serialized_start=2019 + _globals['_PUSHNOTIFICATIONCONFIG']._serialized_end=2167 + _globals['_AUTHENTICATIONINFO']._serialized_start=2169 + _globals['_AUTHENTICATIONINFO']._serialized_end=2249 + _globals['_AGENTCARD']._serialized_start=2252 + _globals['_AGENTCARD']._serialized_end=2964 + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._serialized_start=2874 + _globals['_AGENTCARD_SECURITYSCHEMESENTRY']._serialized_end=2964 + _globals['_AGENTPROVIDER']._serialized_start=2966 + _globals['_AGENTPROVIDER']._serialized_end=3035 + _globals['_AGENTCAPABILITIES']._serialized_start=3038 + _globals['_AGENTCAPABILITIES']._serialized_end=3190 + _globals['_AGENTEXTENSION']._serialized_start=3193 + _globals['_AGENTEXTENSION']._serialized_end=3338 + _globals['_AGENTSKILL']._serialized_start=3341 + _globals['_AGENTSKILL']._serialized_end=3539 + _globals['_TASKPUSHNOTIFICATIONCONFIG']._serialized_start=3542 + _globals['_TASKPUSHNOTIFICATIONCONFIG']._serialized_end=3680 + _globals['_STRINGLIST']._serialized_start=3682 + _globals['_STRINGLIST']._serialized_end=3714 + _globals['_SECURITY']._serialized_start=3717 + _globals['_SECURITY']._serialized_end=3864 + _globals['_SECURITY_SCHEMESENTRY']._serialized_start=3786 + _globals['_SECURITY_SCHEMESENTRY']._serialized_end=3864 + _globals['_SECURITYSCHEME']._serialized_start=3867 + _globals['_SECURITYSCHEME']._serialized_end=4268 + _globals['_APIKEYSECURITYSCHEME']._serialized_start=4270 + _globals['_APIKEYSECURITYSCHEME']._serialized_end=4374 + _globals['_HTTPAUTHSECURITYSCHEME']._serialized_start=4376 + _globals['_HTTPAUTHSECURITYSCHEME']._serialized_end=4495 + _globals['_OAUTH2SECURITYSCHEME']._serialized_start=4497 + _globals['_OAUTH2SECURITYSCHEME']._serialized_end=4595 + _globals['_OPENIDCONNECTSECURITYSCHEME']._serialized_start=4597 + _globals['_OPENIDCONNECTSECURITYSCHEME']._serialized_end=4707 + _globals['_OAUTHFLOWS']._serialized_start=4710 + _globals['_OAUTHFLOWS']._serialized_end=5014 + _globals['_AUTHORIZATIONCODEOAUTHFLOW']._serialized_start=5017 + _globals['_AUTHORIZATIONCODEOAUTHFLOW']._serialized_end=5283 + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_AUTHORIZATIONCODEOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_CLIENTCREDENTIALSOAUTHFLOW']._serialized_start=5286 + _globals['_CLIENTCREDENTIALSOAUTHFLOW']._serialized_end=5507 + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_CLIENTCREDENTIALSOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_IMPLICITOAUTHFLOW']._serialized_start=5510 + _globals['_IMPLICITOAUTHFLOW']._serialized_end=5729 + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_IMPLICITOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_PASSWORDOAUTHFLOW']._serialized_start=5732 + _globals['_PASSWORDOAUTHFLOW']._serialized_end=5935 + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._serialized_start=5226 + _globals['_PASSWORDOAUTHFLOW_SCOPESENTRY']._serialized_end=5283 + _globals['_SENDMESSAGEREQUEST']._serialized_start=5938 + _globals['_SENDMESSAGEREQUEST']._serialized_end=6131 + _globals['_GETTASKREQUEST']._serialized_start=6133 + _globals['_GETTASKREQUEST']._serialized_end=6213 + _globals['_CANCELTASKREQUEST']._serialized_start=6215 + _globals['_CANCELTASKREQUEST']._serialized_end=6254 + _globals['_GETTASKPUSHNOTIFICATIONREQUEST']._serialized_start=6256 + _globals['_GETTASKPUSHNOTIFICATIONREQUEST']._serialized_end=6308 + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST']._serialized_start=6311 + _globals['_CREATETASKPUSHNOTIFICATIONREQUEST']._serialized_end=6474 + _globals['_TASKSUBSCRIPTIONREQUEST']._serialized_start=6476 + _globals['_TASKSUBSCRIPTIONREQUEST']._serialized_end=6521 + _globals['_LISTTASKPUSHNOTIFICATIONREQUEST']._serialized_start=6523 + _globals['_LISTTASKPUSHNOTIFICATIONREQUEST']._serialized_end=6640 + _globals['_GETAGENTCARDREQUEST']._serialized_start=6642 + _globals['_GETAGENTCARDREQUEST']._serialized_end=6663 + _globals['_SENDMESSAGERESPONSE']._serialized_start=6665 + _globals['_SENDMESSAGERESPONSE']._serialized_end=6770 + _globals['_STREAMRESPONSE']._serialized_start=6773 + _globals['_STREAMRESPONSE']._serialized_end=7019 + _globals['_LISTTASKPUSHNOTIFICATIONRESPONSE']._serialized_start=7022 + _globals['_LISTTASKPUSHNOTIFICATIONRESPONSE']._serialized_end=7158 + _globals['_A2ASERVICE']._serialized_start=7475 + _globals['_A2ASERVICE']._serialized_end=8579 +# @@protoc_insertion_point(module_scope) diff --git a/src/a2a/grpc/a2a_pb2.pyi b/src/a2a/grpc/a2a_pb2.pyi new file mode 100644 index 00000000..8d2fad9b --- /dev/null +++ b/src/a2a/grpc/a2a_pb2.pyi @@ -0,0 +1,520 @@ +from google.api import annotations_pb2 as _annotations_pb2 +from google.api import client_pb2 as _client_pb2 +from google.api import field_behavior_pb2 as _field_behavior_pb2 +from google.protobuf import struct_pb2 as _struct_pb2 +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class TaskState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + TASK_STATE_UNSPECIFIED: _ClassVar[TaskState] + TASK_STATE_SUBMITTED: _ClassVar[TaskState] + TASK_STATE_WORKING: _ClassVar[TaskState] + TASK_STATE_COMPLETED: _ClassVar[TaskState] + TASK_STATE_FAILED: _ClassVar[TaskState] + TASK_STATE_CANCELLED: _ClassVar[TaskState] + TASK_STATE_INPUT_REQUIRED: _ClassVar[TaskState] + TASK_STATE_REJECTED: _ClassVar[TaskState] + TASK_STATE_AUTH_REQUIRED: _ClassVar[TaskState] + +class Role(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + ROLE_UNSPECIFIED: _ClassVar[Role] + ROLE_USER: _ClassVar[Role] + ROLE_AGENT: _ClassVar[Role] +TASK_STATE_UNSPECIFIED: TaskState +TASK_STATE_SUBMITTED: TaskState +TASK_STATE_WORKING: TaskState +TASK_STATE_COMPLETED: TaskState +TASK_STATE_FAILED: TaskState +TASK_STATE_CANCELLED: TaskState +TASK_STATE_INPUT_REQUIRED: TaskState +TASK_STATE_REJECTED: TaskState +TASK_STATE_AUTH_REQUIRED: TaskState +ROLE_UNSPECIFIED: Role +ROLE_USER: Role +ROLE_AGENT: Role + +class SendMessageConfiguration(_message.Message): + __slots__ = ("accepted_output_modes", "push_notification", "history_length", "blocking") + ACCEPTED_OUTPUT_MODES_FIELD_NUMBER: _ClassVar[int] + PUSH_NOTIFICATION_FIELD_NUMBER: _ClassVar[int] + HISTORY_LENGTH_FIELD_NUMBER: _ClassVar[int] + BLOCKING_FIELD_NUMBER: _ClassVar[int] + accepted_output_modes: _containers.RepeatedScalarFieldContainer[str] + push_notification: PushNotificationConfig + history_length: int + blocking: bool + def __init__(self, accepted_output_modes: _Optional[_Iterable[str]] = ..., push_notification: _Optional[_Union[PushNotificationConfig, _Mapping]] = ..., history_length: _Optional[int] = ..., blocking: bool = ...) -> None: ... + +class Task(_message.Message): + __slots__ = ("id", "context_id", "status", "artifacts", "history", "metadata") + ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + ARTIFACTS_FIELD_NUMBER: _ClassVar[int] + HISTORY_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + id: str + context_id: str + status: TaskStatus + artifacts: _containers.RepeatedCompositeFieldContainer[Artifact] + history: _containers.RepeatedCompositeFieldContainer[Message] + metadata: _struct_pb2.Struct + def __init__(self, id: _Optional[str] = ..., context_id: _Optional[str] = ..., status: _Optional[_Union[TaskStatus, _Mapping]] = ..., artifacts: _Optional[_Iterable[_Union[Artifact, _Mapping]]] = ..., history: _Optional[_Iterable[_Union[Message, _Mapping]]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class TaskStatus(_message.Message): + __slots__ = ("state", "update", "timestamp") + STATE_FIELD_NUMBER: _ClassVar[int] + UPDATE_FIELD_NUMBER: _ClassVar[int] + TIMESTAMP_FIELD_NUMBER: _ClassVar[int] + state: TaskState + update: Message + timestamp: _timestamp_pb2.Timestamp + def __init__(self, state: _Optional[_Union[TaskState, str]] = ..., update: _Optional[_Union[Message, _Mapping]] = ..., timestamp: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ... + +class Part(_message.Message): + __slots__ = ("text", "file", "data") + TEXT_FIELD_NUMBER: _ClassVar[int] + FILE_FIELD_NUMBER: _ClassVar[int] + DATA_FIELD_NUMBER: _ClassVar[int] + text: str + file: FilePart + data: DataPart + def __init__(self, text: _Optional[str] = ..., file: _Optional[_Union[FilePart, _Mapping]] = ..., data: _Optional[_Union[DataPart, _Mapping]] = ...) -> None: ... + +class FilePart(_message.Message): + __slots__ = ("file_with_uri", "file_with_bytes", "mime_type") + FILE_WITH_URI_FIELD_NUMBER: _ClassVar[int] + FILE_WITH_BYTES_FIELD_NUMBER: _ClassVar[int] + MIME_TYPE_FIELD_NUMBER: _ClassVar[int] + file_with_uri: str + file_with_bytes: bytes + mime_type: str + def __init__(self, file_with_uri: _Optional[str] = ..., file_with_bytes: _Optional[bytes] = ..., mime_type: _Optional[str] = ...) -> None: ... + +class DataPart(_message.Message): + __slots__ = ("data",) + DATA_FIELD_NUMBER: _ClassVar[int] + data: _struct_pb2.Struct + def __init__(self, data: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class Message(_message.Message): + __slots__ = ("message_id", "context_id", "task_id", "role", "content", "metadata", "extensions") + MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + TASK_ID_FIELD_NUMBER: _ClassVar[int] + ROLE_FIELD_NUMBER: _ClassVar[int] + CONTENT_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + EXTENSIONS_FIELD_NUMBER: _ClassVar[int] + message_id: str + context_id: str + task_id: str + role: Role + content: _containers.RepeatedCompositeFieldContainer[Part] + metadata: _struct_pb2.Struct + extensions: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, message_id: _Optional[str] = ..., context_id: _Optional[str] = ..., task_id: _Optional[str] = ..., role: _Optional[_Union[Role, str]] = ..., content: _Optional[_Iterable[_Union[Part, _Mapping]]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., extensions: _Optional[_Iterable[str]] = ...) -> None: ... + +class Artifact(_message.Message): + __slots__ = ("artifact_id", "name", "description", "parts", "metadata", "extensions") + ARTIFACT_ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + PARTS_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + EXTENSIONS_FIELD_NUMBER: _ClassVar[int] + artifact_id: str + name: str + description: str + parts: _containers.RepeatedCompositeFieldContainer[Part] + metadata: _struct_pb2.Struct + extensions: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, artifact_id: _Optional[str] = ..., name: _Optional[str] = ..., description: _Optional[str] = ..., parts: _Optional[_Iterable[_Union[Part, _Mapping]]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., extensions: _Optional[_Iterable[str]] = ...) -> None: ... + +class TaskStatusUpdateEvent(_message.Message): + __slots__ = ("task_id", "context_id", "status", "final", "metadata") + TASK_ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + STATUS_FIELD_NUMBER: _ClassVar[int] + FINAL_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + task_id: str + context_id: str + status: TaskStatus + final: bool + metadata: _struct_pb2.Struct + def __init__(self, task_id: _Optional[str] = ..., context_id: _Optional[str] = ..., status: _Optional[_Union[TaskStatus, _Mapping]] = ..., final: bool = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class TaskArtifactUpdateEvent(_message.Message): + __slots__ = ("task_id", "context_id", "artifact", "append", "last_chunk", "metadata") + TASK_ID_FIELD_NUMBER: _ClassVar[int] + CONTEXT_ID_FIELD_NUMBER: _ClassVar[int] + ARTIFACT_FIELD_NUMBER: _ClassVar[int] + APPEND_FIELD_NUMBER: _ClassVar[int] + LAST_CHUNK_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + task_id: str + context_id: str + artifact: Artifact + append: bool + last_chunk: bool + metadata: _struct_pb2.Struct + def __init__(self, task_id: _Optional[str] = ..., context_id: _Optional[str] = ..., artifact: _Optional[_Union[Artifact, _Mapping]] = ..., append: bool = ..., last_chunk: bool = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class PushNotificationConfig(_message.Message): + __slots__ = ("id", "url", "token", "authentication") + ID_FIELD_NUMBER: _ClassVar[int] + URL_FIELD_NUMBER: _ClassVar[int] + TOKEN_FIELD_NUMBER: _ClassVar[int] + AUTHENTICATION_FIELD_NUMBER: _ClassVar[int] + id: str + url: str + token: str + authentication: AuthenticationInfo + def __init__(self, id: _Optional[str] = ..., url: _Optional[str] = ..., token: _Optional[str] = ..., authentication: _Optional[_Union[AuthenticationInfo, _Mapping]] = ...) -> None: ... + +class AuthenticationInfo(_message.Message): + __slots__ = ("schemes", "credentials") + SCHEMES_FIELD_NUMBER: _ClassVar[int] + CREDENTIALS_FIELD_NUMBER: _ClassVar[int] + schemes: _containers.RepeatedScalarFieldContainer[str] + credentials: str + def __init__(self, schemes: _Optional[_Iterable[str]] = ..., credentials: _Optional[str] = ...) -> None: ... + +class AgentCard(_message.Message): + __slots__ = ("name", "description", "url", "provider", "version", "documentation_url", "capabilities", "security_schemes", "security", "default_input_modes", "default_output_modes", "skills", "supports_authenticated_extended_card") + class SecuritySchemesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: SecurityScheme + def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[SecurityScheme, _Mapping]] = ...) -> None: ... + NAME_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + URL_FIELD_NUMBER: _ClassVar[int] + PROVIDER_FIELD_NUMBER: _ClassVar[int] + VERSION_FIELD_NUMBER: _ClassVar[int] + DOCUMENTATION_URL_FIELD_NUMBER: _ClassVar[int] + CAPABILITIES_FIELD_NUMBER: _ClassVar[int] + SECURITY_SCHEMES_FIELD_NUMBER: _ClassVar[int] + SECURITY_FIELD_NUMBER: _ClassVar[int] + DEFAULT_INPUT_MODES_FIELD_NUMBER: _ClassVar[int] + DEFAULT_OUTPUT_MODES_FIELD_NUMBER: _ClassVar[int] + SKILLS_FIELD_NUMBER: _ClassVar[int] + SUPPORTS_AUTHENTICATED_EXTENDED_CARD_FIELD_NUMBER: _ClassVar[int] + name: str + description: str + url: str + provider: AgentProvider + version: str + documentation_url: str + capabilities: AgentCapabilities + security_schemes: _containers.MessageMap[str, SecurityScheme] + security: _containers.RepeatedCompositeFieldContainer[Security] + default_input_modes: _containers.RepeatedScalarFieldContainer[str] + default_output_modes: _containers.RepeatedScalarFieldContainer[str] + skills: _containers.RepeatedCompositeFieldContainer[AgentSkill] + supports_authenticated_extended_card: bool + def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., url: _Optional[str] = ..., provider: _Optional[_Union[AgentProvider, _Mapping]] = ..., version: _Optional[str] = ..., documentation_url: _Optional[str] = ..., capabilities: _Optional[_Union[AgentCapabilities, _Mapping]] = ..., security_schemes: _Optional[_Mapping[str, SecurityScheme]] = ..., security: _Optional[_Iterable[_Union[Security, _Mapping]]] = ..., default_input_modes: _Optional[_Iterable[str]] = ..., default_output_modes: _Optional[_Iterable[str]] = ..., skills: _Optional[_Iterable[_Union[AgentSkill, _Mapping]]] = ..., supports_authenticated_extended_card: bool = ...) -> None: ... + +class AgentProvider(_message.Message): + __slots__ = ("url", "organization") + URL_FIELD_NUMBER: _ClassVar[int] + ORGANIZATION_FIELD_NUMBER: _ClassVar[int] + url: str + organization: str + def __init__(self, url: _Optional[str] = ..., organization: _Optional[str] = ...) -> None: ... + +class AgentCapabilities(_message.Message): + __slots__ = ("streaming", "push_notifications", "extensions") + STREAMING_FIELD_NUMBER: _ClassVar[int] + PUSH_NOTIFICATIONS_FIELD_NUMBER: _ClassVar[int] + EXTENSIONS_FIELD_NUMBER: _ClassVar[int] + streaming: bool + push_notifications: bool + extensions: _containers.RepeatedCompositeFieldContainer[AgentExtension] + def __init__(self, streaming: bool = ..., push_notifications: bool = ..., extensions: _Optional[_Iterable[_Union[AgentExtension, _Mapping]]] = ...) -> None: ... + +class AgentExtension(_message.Message): + __slots__ = ("uri", "description", "required", "params") + URI_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + REQUIRED_FIELD_NUMBER: _ClassVar[int] + PARAMS_FIELD_NUMBER: _ClassVar[int] + uri: str + description: str + required: bool + params: _struct_pb2.Struct + def __init__(self, uri: _Optional[str] = ..., description: _Optional[str] = ..., required: bool = ..., params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class AgentSkill(_message.Message): + __slots__ = ("id", "name", "description", "tags", "examples", "input_modes", "output_modes") + ID_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + TAGS_FIELD_NUMBER: _ClassVar[int] + EXAMPLES_FIELD_NUMBER: _ClassVar[int] + INPUT_MODES_FIELD_NUMBER: _ClassVar[int] + OUTPUT_MODES_FIELD_NUMBER: _ClassVar[int] + id: str + name: str + description: str + tags: _containers.RepeatedScalarFieldContainer[str] + examples: _containers.RepeatedScalarFieldContainer[str] + input_modes: _containers.RepeatedScalarFieldContainer[str] + output_modes: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, id: _Optional[str] = ..., name: _Optional[str] = ..., description: _Optional[str] = ..., tags: _Optional[_Iterable[str]] = ..., examples: _Optional[_Iterable[str]] = ..., input_modes: _Optional[_Iterable[str]] = ..., output_modes: _Optional[_Iterable[str]] = ...) -> None: ... + +class TaskPushNotificationConfig(_message.Message): + __slots__ = ("name", "push_notification_config") + NAME_FIELD_NUMBER: _ClassVar[int] + PUSH_NOTIFICATION_CONFIG_FIELD_NUMBER: _ClassVar[int] + name: str + push_notification_config: PushNotificationConfig + def __init__(self, name: _Optional[str] = ..., push_notification_config: _Optional[_Union[PushNotificationConfig, _Mapping]] = ...) -> None: ... + +class StringList(_message.Message): + __slots__ = ("list",) + LIST_FIELD_NUMBER: _ClassVar[int] + list: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, list: _Optional[_Iterable[str]] = ...) -> None: ... + +class Security(_message.Message): + __slots__ = ("schemes",) + class SchemesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: StringList + def __init__(self, key: _Optional[str] = ..., value: _Optional[_Union[StringList, _Mapping]] = ...) -> None: ... + SCHEMES_FIELD_NUMBER: _ClassVar[int] + schemes: _containers.MessageMap[str, StringList] + def __init__(self, schemes: _Optional[_Mapping[str, StringList]] = ...) -> None: ... + +class SecurityScheme(_message.Message): + __slots__ = ("api_key_security_scheme", "http_auth_security_scheme", "oauth2_security_scheme", "open_id_connect_security_scheme") + API_KEY_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + HTTP_AUTH_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + OAUTH2_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + OPEN_ID_CONNECT_SECURITY_SCHEME_FIELD_NUMBER: _ClassVar[int] + api_key_security_scheme: APIKeySecurityScheme + http_auth_security_scheme: HTTPAuthSecurityScheme + oauth2_security_scheme: OAuth2SecurityScheme + open_id_connect_security_scheme: OpenIdConnectSecurityScheme + def __init__(self, api_key_security_scheme: _Optional[_Union[APIKeySecurityScheme, _Mapping]] = ..., http_auth_security_scheme: _Optional[_Union[HTTPAuthSecurityScheme, _Mapping]] = ..., oauth2_security_scheme: _Optional[_Union[OAuth2SecurityScheme, _Mapping]] = ..., open_id_connect_security_scheme: _Optional[_Union[OpenIdConnectSecurityScheme, _Mapping]] = ...) -> None: ... + +class APIKeySecurityScheme(_message.Message): + __slots__ = ("description", "location", "name") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + LOCATION_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + description: str + location: str + name: str + def __init__(self, description: _Optional[str] = ..., location: _Optional[str] = ..., name: _Optional[str] = ...) -> None: ... + +class HTTPAuthSecurityScheme(_message.Message): + __slots__ = ("description", "scheme", "bearer_format") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + SCHEME_FIELD_NUMBER: _ClassVar[int] + BEARER_FORMAT_FIELD_NUMBER: _ClassVar[int] + description: str + scheme: str + bearer_format: str + def __init__(self, description: _Optional[str] = ..., scheme: _Optional[str] = ..., bearer_format: _Optional[str] = ...) -> None: ... + +class OAuth2SecurityScheme(_message.Message): + __slots__ = ("description", "flows") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + FLOWS_FIELD_NUMBER: _ClassVar[int] + description: str + flows: OAuthFlows + def __init__(self, description: _Optional[str] = ..., flows: _Optional[_Union[OAuthFlows, _Mapping]] = ...) -> None: ... + +class OpenIdConnectSecurityScheme(_message.Message): + __slots__ = ("description", "open_id_connect_url") + DESCRIPTION_FIELD_NUMBER: _ClassVar[int] + OPEN_ID_CONNECT_URL_FIELD_NUMBER: _ClassVar[int] + description: str + open_id_connect_url: str + def __init__(self, description: _Optional[str] = ..., open_id_connect_url: _Optional[str] = ...) -> None: ... + +class OAuthFlows(_message.Message): + __slots__ = ("authorization_code", "client_credentials", "implicit", "password") + AUTHORIZATION_CODE_FIELD_NUMBER: _ClassVar[int] + CLIENT_CREDENTIALS_FIELD_NUMBER: _ClassVar[int] + IMPLICIT_FIELD_NUMBER: _ClassVar[int] + PASSWORD_FIELD_NUMBER: _ClassVar[int] + authorization_code: AuthorizationCodeOAuthFlow + client_credentials: ClientCredentialsOAuthFlow + implicit: ImplicitOAuthFlow + password: PasswordOAuthFlow + def __init__(self, authorization_code: _Optional[_Union[AuthorizationCodeOAuthFlow, _Mapping]] = ..., client_credentials: _Optional[_Union[ClientCredentialsOAuthFlow, _Mapping]] = ..., implicit: _Optional[_Union[ImplicitOAuthFlow, _Mapping]] = ..., password: _Optional[_Union[PasswordOAuthFlow, _Mapping]] = ...) -> None: ... + +class AuthorizationCodeOAuthFlow(_message.Message): + __slots__ = ("authorization_url", "token_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + AUTHORIZATION_URL_FIELD_NUMBER: _ClassVar[int] + TOKEN_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + authorization_url: str + token_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, authorization_url: _Optional[str] = ..., token_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class ClientCredentialsOAuthFlow(_message.Message): + __slots__ = ("token_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + TOKEN_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + token_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, token_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class ImplicitOAuthFlow(_message.Message): + __slots__ = ("authorization_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + AUTHORIZATION_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + authorization_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, authorization_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class PasswordOAuthFlow(_message.Message): + __slots__ = ("token_url", "refresh_url", "scopes") + class ScopesEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + TOKEN_URL_FIELD_NUMBER: _ClassVar[int] + REFRESH_URL_FIELD_NUMBER: _ClassVar[int] + SCOPES_FIELD_NUMBER: _ClassVar[int] + token_url: str + refresh_url: str + scopes: _containers.ScalarMap[str, str] + def __init__(self, token_url: _Optional[str] = ..., refresh_url: _Optional[str] = ..., scopes: _Optional[_Mapping[str, str]] = ...) -> None: ... + +class SendMessageRequest(_message.Message): + __slots__ = ("request", "configuration", "metadata") + REQUEST_FIELD_NUMBER: _ClassVar[int] + CONFIGURATION_FIELD_NUMBER: _ClassVar[int] + METADATA_FIELD_NUMBER: _ClassVar[int] + request: Message + configuration: SendMessageConfiguration + metadata: _struct_pb2.Struct + def __init__(self, request: _Optional[_Union[Message, _Mapping]] = ..., configuration: _Optional[_Union[SendMessageConfiguration, _Mapping]] = ..., metadata: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class GetTaskRequest(_message.Message): + __slots__ = ("name", "history_length") + NAME_FIELD_NUMBER: _ClassVar[int] + HISTORY_LENGTH_FIELD_NUMBER: _ClassVar[int] + name: str + history_length: int + def __init__(self, name: _Optional[str] = ..., history_length: _Optional[int] = ...) -> None: ... + +class CancelTaskRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class GetTaskPushNotificationRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class CreateTaskPushNotificationRequest(_message.Message): + __slots__ = ("parent", "config_id", "config") + PARENT_FIELD_NUMBER: _ClassVar[int] + CONFIG_ID_FIELD_NUMBER: _ClassVar[int] + CONFIG_FIELD_NUMBER: _ClassVar[int] + parent: str + config_id: str + config: TaskPushNotificationConfig + def __init__(self, parent: _Optional[str] = ..., config_id: _Optional[str] = ..., config: _Optional[_Union[TaskPushNotificationConfig, _Mapping]] = ...) -> None: ... + +class TaskSubscriptionRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class ListTaskPushNotificationRequest(_message.Message): + __slots__ = ("parent", "page_size", "page_token") + PARENT_FIELD_NUMBER: _ClassVar[int] + PAGE_SIZE_FIELD_NUMBER: _ClassVar[int] + PAGE_TOKEN_FIELD_NUMBER: _ClassVar[int] + parent: str + page_size: int + page_token: str + def __init__(self, parent: _Optional[str] = ..., page_size: _Optional[int] = ..., page_token: _Optional[str] = ...) -> None: ... + +class GetAgentCardRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class SendMessageResponse(_message.Message): + __slots__ = ("task", "msg") + TASK_FIELD_NUMBER: _ClassVar[int] + MSG_FIELD_NUMBER: _ClassVar[int] + task: Task + msg: Message + def __init__(self, task: _Optional[_Union[Task, _Mapping]] = ..., msg: _Optional[_Union[Message, _Mapping]] = ...) -> None: ... + +class StreamResponse(_message.Message): + __slots__ = ("task", "msg", "status_update", "artifact_update") + TASK_FIELD_NUMBER: _ClassVar[int] + MSG_FIELD_NUMBER: _ClassVar[int] + STATUS_UPDATE_FIELD_NUMBER: _ClassVar[int] + ARTIFACT_UPDATE_FIELD_NUMBER: _ClassVar[int] + task: Task + msg: Message + status_update: TaskStatusUpdateEvent + artifact_update: TaskArtifactUpdateEvent + def __init__(self, task: _Optional[_Union[Task, _Mapping]] = ..., msg: _Optional[_Union[Message, _Mapping]] = ..., status_update: _Optional[_Union[TaskStatusUpdateEvent, _Mapping]] = ..., artifact_update: _Optional[_Union[TaskArtifactUpdateEvent, _Mapping]] = ...) -> None: ... + +class ListTaskPushNotificationResponse(_message.Message): + __slots__ = ("configs", "next_page_token") + CONFIGS_FIELD_NUMBER: _ClassVar[int] + NEXT_PAGE_TOKEN_FIELD_NUMBER: _ClassVar[int] + configs: _containers.RepeatedCompositeFieldContainer[TaskPushNotificationConfig] + next_page_token: str + def __init__(self, configs: _Optional[_Iterable[_Union[TaskPushNotificationConfig, _Mapping]]] = ..., next_page_token: _Optional[str] = ...) -> None: ... diff --git a/src/a2a/grpc/a2a_pb2_grpc.py b/src/a2a/grpc/a2a_pb2_grpc.py new file mode 100644 index 00000000..01a28373 --- /dev/null +++ b/src/a2a/grpc/a2a_pb2_grpc.py @@ -0,0 +1,478 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . import a2a_pb2 as a2a__pb2 + + +class A2AServiceStub(object): + """A2AService defines the gRPC version of the A2A protocol. This has a slightly + different shape than the JSONRPC version to better conform to AIP-127, + where appropriate. The nouns are AgentCard, Message, Task and + TaskPushNotification. + - Messages are not a standard resource so there is no get/delete/update/list + interface, only a send and stream custom methods. + - Tasks have a get interface and custom cancel and subscribe methods. + - TaskPushNotification are a resource whose parent is a task. They have get, + list and create methods. + - AgentCard is a static resource with only a get method. + Of particular note of deviation from JSONRPC approach the request metadata + fields are not present as they don't comply with AIP rules, and the + optional history_length on the get task method is not present as it also + violates AIP-127 and AIP-131. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendMessage = channel.unary_unary( + '/a2a.v1.A2AService/SendMessage', + request_serializer=a2a__pb2.SendMessageRequest.SerializeToString, + response_deserializer=a2a__pb2.SendMessageResponse.FromString, + _registered_method=True) + self.SendStreamingMessage = channel.unary_stream( + '/a2a.v1.A2AService/SendStreamingMessage', + request_serializer=a2a__pb2.SendMessageRequest.SerializeToString, + response_deserializer=a2a__pb2.StreamResponse.FromString, + _registered_method=True) + self.GetTask = channel.unary_unary( + '/a2a.v1.A2AService/GetTask', + request_serializer=a2a__pb2.GetTaskRequest.SerializeToString, + response_deserializer=a2a__pb2.Task.FromString, + _registered_method=True) + self.CancelTask = channel.unary_unary( + '/a2a.v1.A2AService/CancelTask', + request_serializer=a2a__pb2.CancelTaskRequest.SerializeToString, + response_deserializer=a2a__pb2.Task.FromString, + _registered_method=True) + self.TaskSubscription = channel.unary_stream( + '/a2a.v1.A2AService/TaskSubscription', + request_serializer=a2a__pb2.TaskSubscriptionRequest.SerializeToString, + response_deserializer=a2a__pb2.StreamResponse.FromString, + _registered_method=True) + self.CreateTaskPushNotification = channel.unary_unary( + '/a2a.v1.A2AService/CreateTaskPushNotification', + request_serializer=a2a__pb2.CreateTaskPushNotificationRequest.SerializeToString, + response_deserializer=a2a__pb2.TaskPushNotificationConfig.FromString, + _registered_method=True) + self.GetTaskPushNotification = channel.unary_unary( + '/a2a.v1.A2AService/GetTaskPushNotification', + request_serializer=a2a__pb2.GetTaskPushNotificationRequest.SerializeToString, + response_deserializer=a2a__pb2.TaskPushNotificationConfig.FromString, + _registered_method=True) + self.ListTaskPushNotification = channel.unary_unary( + '/a2a.v1.A2AService/ListTaskPushNotification', + request_serializer=a2a__pb2.ListTaskPushNotificationRequest.SerializeToString, + response_deserializer=a2a__pb2.ListTaskPushNotificationResponse.FromString, + _registered_method=True) + self.GetAgentCard = channel.unary_unary( + '/a2a.v1.A2AService/GetAgentCard', + request_serializer=a2a__pb2.GetAgentCardRequest.SerializeToString, + response_deserializer=a2a__pb2.AgentCard.FromString, + _registered_method=True) + + +class A2AServiceServicer(object): + """A2AService defines the gRPC version of the A2A protocol. This has a slightly + different shape than the JSONRPC version to better conform to AIP-127, + where appropriate. The nouns are AgentCard, Message, Task and + TaskPushNotification. + - Messages are not a standard resource so there is no get/delete/update/list + interface, only a send and stream custom methods. + - Tasks have a get interface and custom cancel and subscribe methods. + - TaskPushNotification are a resource whose parent is a task. They have get, + list and create methods. + - AgentCard is a static resource with only a get method. + Of particular note of deviation from JSONRPC approach the request metadata + fields are not present as they don't comply with AIP rules, and the + optional history_length on the get task method is not present as it also + violates AIP-127 and AIP-131. + """ + + def SendMessage(self, request, context): + """Send a message to the agent. This is a blocking call that will return the + task once it is completed, or a LRO if requested. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendStreamingMessage(self, request, context): + """SendStreamingMessage is a streaming call that will return a stream of + task update events until the Task is in an interrupted or terminal state. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetTask(self, request, context): + """Get the current state of a task from the agent. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CancelTask(self, request, context): + """Cancel a task from the agent. If supported one should expect no + more task updates for the task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TaskSubscription(self, request, context): + """TaskSubscription is a streaming call that will return a stream of task + update events. This attaches the stream to an existing in process task. + If the task is complete the stream will return the completed task (like + GetTask) and close the stream. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def CreateTaskPushNotification(self, request, context): + """Set a push notification config for a task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetTaskPushNotification(self, request, context): + """Get a push notification config for a task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ListTaskPushNotification(self, request, context): + """Get a list of push notifications configured for a task. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetAgentCard(self, request, context): + """GetAgentCard returns the agent card for the agent. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_A2AServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SendMessage': grpc.unary_unary_rpc_method_handler( + servicer.SendMessage, + request_deserializer=a2a__pb2.SendMessageRequest.FromString, + response_serializer=a2a__pb2.SendMessageResponse.SerializeToString, + ), + 'SendStreamingMessage': grpc.unary_stream_rpc_method_handler( + servicer.SendStreamingMessage, + request_deserializer=a2a__pb2.SendMessageRequest.FromString, + response_serializer=a2a__pb2.StreamResponse.SerializeToString, + ), + 'GetTask': grpc.unary_unary_rpc_method_handler( + servicer.GetTask, + request_deserializer=a2a__pb2.GetTaskRequest.FromString, + response_serializer=a2a__pb2.Task.SerializeToString, + ), + 'CancelTask': grpc.unary_unary_rpc_method_handler( + servicer.CancelTask, + request_deserializer=a2a__pb2.CancelTaskRequest.FromString, + response_serializer=a2a__pb2.Task.SerializeToString, + ), + 'TaskSubscription': grpc.unary_stream_rpc_method_handler( + servicer.TaskSubscription, + request_deserializer=a2a__pb2.TaskSubscriptionRequest.FromString, + response_serializer=a2a__pb2.StreamResponse.SerializeToString, + ), + 'CreateTaskPushNotification': grpc.unary_unary_rpc_method_handler( + servicer.CreateTaskPushNotification, + request_deserializer=a2a__pb2.CreateTaskPushNotificationRequest.FromString, + response_serializer=a2a__pb2.TaskPushNotificationConfig.SerializeToString, + ), + 'GetTaskPushNotification': grpc.unary_unary_rpc_method_handler( + servicer.GetTaskPushNotification, + request_deserializer=a2a__pb2.GetTaskPushNotificationRequest.FromString, + response_serializer=a2a__pb2.TaskPushNotificationConfig.SerializeToString, + ), + 'ListTaskPushNotification': grpc.unary_unary_rpc_method_handler( + servicer.ListTaskPushNotification, + request_deserializer=a2a__pb2.ListTaskPushNotificationRequest.FromString, + response_serializer=a2a__pb2.ListTaskPushNotificationResponse.SerializeToString, + ), + 'GetAgentCard': grpc.unary_unary_rpc_method_handler( + servicer.GetAgentCard, + request_deserializer=a2a__pb2.GetAgentCardRequest.FromString, + response_serializer=a2a__pb2.AgentCard.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'a2a.v1.A2AService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('a2a.v1.A2AService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class A2AService(object): + """A2AService defines the gRPC version of the A2A protocol. This has a slightly + different shape than the JSONRPC version to better conform to AIP-127, + where appropriate. The nouns are AgentCard, Message, Task and + TaskPushNotification. + - Messages are not a standard resource so there is no get/delete/update/list + interface, only a send and stream custom methods. + - Tasks have a get interface and custom cancel and subscribe methods. + - TaskPushNotification are a resource whose parent is a task. They have get, + list and create methods. + - AgentCard is a static resource with only a get method. + Of particular note of deviation from JSONRPC approach the request metadata + fields are not present as they don't comply with AIP rules, and the + optional history_length on the get task method is not present as it also + violates AIP-127 and AIP-131. + """ + + @staticmethod + def SendMessage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/SendMessage', + a2a__pb2.SendMessageRequest.SerializeToString, + a2a__pb2.SendMessageResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendStreamingMessage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/a2a.v1.A2AService/SendStreamingMessage', + a2a__pb2.SendMessageRequest.SerializeToString, + a2a__pb2.StreamResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetTask(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/GetTask', + a2a__pb2.GetTaskRequest.SerializeToString, + a2a__pb2.Task.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def CancelTask(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/CancelTask', + a2a__pb2.CancelTaskRequest.SerializeToString, + a2a__pb2.Task.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def TaskSubscription(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/a2a.v1.A2AService/TaskSubscription', + a2a__pb2.TaskSubscriptionRequest.SerializeToString, + a2a__pb2.StreamResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def CreateTaskPushNotification(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/CreateTaskPushNotification', + a2a__pb2.CreateTaskPushNotificationRequest.SerializeToString, + a2a__pb2.TaskPushNotificationConfig.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetTaskPushNotification(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/GetTaskPushNotification', + a2a__pb2.GetTaskPushNotificationRequest.SerializeToString, + a2a__pb2.TaskPushNotificationConfig.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def ListTaskPushNotification(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/ListTaskPushNotification', + a2a__pb2.ListTaskPushNotificationRequest.SerializeToString, + a2a__pb2.ListTaskPushNotificationResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetAgentCard(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/a2a.v1.A2AService/GetAgentCard', + a2a__pb2.GetAgentCardRequest.SerializeToString, + a2a__pb2.AgentCard.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index f0d2667d..8cf2fe8c 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -3,6 +3,7 @@ from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) +from a2a.server.request_handlers.grpc_handler import GrpcHandler from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( @@ -13,6 +14,7 @@ __all__ = [ 'DefaultRequestHandler', + 'GrpcHandler', 'JSONRPCHandler', 'RequestHandler', 'build_error_response', diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 97d90fe1..2c81da96 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -1,5 +1,6 @@ import asyncio import logging +import uuid from collections.abc import AsyncGenerator from typing import cast @@ -364,6 +365,8 @@ async def on_set_task_push_notification_config( if not task: raise ServerError(error=TaskNotFoundError()) + # Generate a unique id for the notification + params.pushNotificationConfig.id = str(uuid.uuid4()) await self._push_notifier.set_info( params.taskId, params.pushNotificationConfig, diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py new file mode 100644 index 00000000..b8c21070 --- /dev/null +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -0,0 +1,367 @@ +# ruff: noqa: N802 +import contextlib +import logging + +from abc import ABC, abstractmethod +from collections.abc import AsyncIterable + +import grpc + +import a2a.grpc.a2a_pb2_grpc as a2a_grpc + +from a2a import types +from a2a.auth.user import UnauthenticatedUser +from a2a.grpc import a2a_pb2 +from a2a.server.context import ServerCallContext +from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.types import ( + AgentCard, + TaskNotFoundError, +) +from a2a.utils import proto_utils +from a2a.utils.errors import ServerError +from a2a.utils.helpers import validate, validate_async_generator + + +logger = logging.getLogger(__name__) + +# For now we use a trivial wrapper on the grpc context object + + +class CallContextBuilder(ABC): + """A class for building ServerCallContexts using the Starlette Request.""" + + @abstractmethod + def build(self, context: grpc.ServicerContext) -> ServerCallContext: + """Builds a ServerCallContext from a gRPC Request.""" + + +class DefaultCallContextBuilder(CallContextBuilder): + """A default implementation of CallContextBuilder.""" + + def build(self, context: grpc.ServicerContext) -> ServerCallContext: + """Builds the ServerCallContext.""" + user = UnauthenticatedUser() + state = {} + with contextlib.suppress(Exception): + state['grpc_context'] = context + return ServerCallContext(user=user, state=state) + + +class GrpcHandler(a2a_grpc.A2AServiceServicer): + """Maps incoming gRPC requests to the appropriate request handler method.""" + + def __init__( + self, + agent_card: AgentCard, + request_handler: RequestHandler, + context_builder: CallContextBuilder | None = None + ): + """Initializes the GrpcHandler. + + Args: + agent_card: The AgentCard describing the agent's capabilities. + request_handler: The underlying `RequestHandler` instance to + delegate requests to. + context_builder: The CallContextBuilder object. If none the + DefaultCallContextBuilder is used. + """ + self.agent_card = agent_card + self.request_handler = request_handler + self.context_builder = context_builder or DefaultCallContextBuilder() + + async def SendMessage( + self, + request: a2a_pb2.SendMessageRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.SendMessageResponse: + """Handles the 'SendMessage' gRPC method. + + Args: + request: The incoming `SendMessageRequest` object. + context: Context provided by the server. + + Returns: + A `SendMessageResponse` object containing the result (Task or + Message) or throws an error response if a `ServerError` is raised + by the handler. + """ + try: + # Construct the server context object + server_context = self.context_builder.build(context) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + request, + ) + task_or_message = await self.request_handler.on_message_send( + a2a_request, server_context + ) + return proto_utils.ToProto.task_or_message(task_or_message) + except ServerError as e: + await self.abort_context(e, context) + return a2a_pb2.SendMessageResponse() + + @validate_async_generator( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def SendStreamingMessage( + self, + request: a2a_pb2.SendMessageRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterable[a2a_pb2.StreamResponse]: + """Handles the 'StreamMessage' gRPC method. + + Yields response objects as they are produced by the underlying handler's + stream. + + Args: + request: The incoming `SendMessageRequest` object. + context: Context provided by the server. + + Yields: + `StreamResponse` objects containing streaming events + (Task, Message, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) + or gRPC error responses if a `ServerError` is raised. + """ + server_context = self.context_builder.build(context) + # Transform the proto object to the python internal objects + a2a_request = proto_utils.FromProto.message_send_params( + request, + ) + try: + async for event in self.request_handler.on_message_send_stream( + a2a_request, server_context + ): + yield proto_utils.ToProto.stream_response(event) + except ServerError as e: + await self.abort_context(e, context) + return + + async def CancelTask( + self, + request: a2a_pb2.CancelTaskRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.Task: + """Handles the 'CancelTask' gRPC method. + + Args: + request: The incoming `CancelTaskRequest` object. + context: Context provided by the server. + + Returns: + A `Task` object containing the updated Task or a gRPC error. + """ + try: + server_context = self.context_builder.build(context) + task_id_params = proto_utils.FromProto.task_id_params(request) + task = await self.request_handler.on_cancel_task( + task_id_params, server_context + ) + if task: + return proto_utils.ToProto.task(task) + await self.abort_context( + ServerError(error=TaskNotFoundError()), context + ) + except ServerError as e: + await self.abort_context(e, context) + return a2a_pb2.Task() + + @validate_async_generator( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) + async def TaskSubscription( + self, + request: a2a_pb2.TaskSubscriptionRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterable[a2a_pb2.StreamResponse]: + """Handles the 'TaskSubscription' gRPC method. + + Yields response objects as they are produced by the underlying handler's + stream. + + Args: + request: The incoming `TaskSubscriptionRequest` object. + context: Context provided by the server. + + Yields: + `StreamResponse` objects containing streaming events + """ + try: + server_context = self.context_builder.build(context) + async for event in self.request_handler.on_resubscribe_to_task( + proto_utils.FromProto.task_id_params(request), + server_context, + ): + yield proto_utils.ToProto.stream_response(event) + except ServerError as e: + await self.abort_context(e, context) + + async def GetTaskPushNotification( + self, + request: a2a_pb2.GetTaskPushNotificationRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.TaskPushNotificationConfig: + """Handles the 'GetTaskPushNotification' gRPC method. + + Args: + request: The incoming `GetTaskPushNotificationConfigRequest` object. + context: Context provided by the server. + + Returns: + A `TaskPushNotificationConfig` object containing the config. + """ + try: + server_context = self.context_builder.build(context) + config = ( + await self.request_handler.on_get_task_push_notification_config( + proto_utils.FromProto.task_id_params(request), + server_context, + ) + ) + return proto_utils.ToProto.task_push_notification_config(config) + except ServerError as e: + await self.abort_context(e, context) + return a2a_pb2.TaskPushNotificationConfig() + + @validate( + lambda self: self.agent_card.capabilities.pushNotifications, + 'Push notifications are not supported by the agent', + ) + async def CreateTaskPushNotification( + self, + request: a2a_pb2.CreateTaskPushNotificationRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.TaskPushNotificationConfig: + """Handles the 'CreateTaskPushNotification' gRPC method. + + Requires the agent to support push notifications. + + Args: + request: The incoming `CreateTaskPushNotificationRequest` object. + context: Context provided by the server. + + Returns: + A `TaskPushNotificationConfig` object + + Raises: + ServerError: If push notifications are not supported by the agent + (due to the `@validate` decorator). + """ + try: + server_context = self.context_builder.build(context) + config = ( + await self.request_handler.on_set_task_push_notification_config( + proto_utils.FromProto.task_push_notification_config( + request, + ), + server_context, + ) + ) + return proto_utils.ToProto.task_push_notification_config(config) + except ServerError as e: + await self.abort_context(e, context) + return a2a_pb2.TaskPushNotificationConfig() + + async def GetTask( + self, + request: a2a_pb2.GetTaskRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.Task: + """Handles the 'GetTask' gRPC method. + + Args: + request: The incoming `GetTaskRequest` object. + context: Context provided by the server. + + Returns: + A `Task` object. + """ + try: + server_context = self.context_builder.build(context) + task = await self.request_handler.on_get_task( + proto_utils.FromProto.task_query_params(request), server_context + ) + if task: + return proto_utils.ToProto.task(task) + await self.abort_context( + ServerError(error=TaskNotFoundError()), context + ) + except ServerError as e: + await self.abort_context(e, context) + return a2a_pb2.Task() + + async def GetAgentCard( + self, + request: a2a_pb2.GetAgentCardRequest, + context: grpc.aio.ServicerContext, + ) -> a2a_pb2.AgentCard: + """Get the agent card for the agent served.""" + return proto_utils.ToProto.agent_card(self.agent_card) + + async def abort_context( + self, error: ServerError, context: grpc.ServicerContext + ) -> None: + """Sets the grpc errors appropriately in the context.""" + match error.error: + case types.JSONParseError(): + await context.abort( + grpc.StatusCode.INTERNAL, + f'JSONParseError: {error.error.message}', + ) + case types.InvalidRequestError(): + await context.abort( + grpc.StatusCode.INVALID_ARGUMENT, + f'InvalidRequestError: {error.error.message}', + ) + case types.MethodNotFoundError(): + await context.abort( + grpc.StatusCode.NOT_FOUND, + f'MethodNotFoundError: {error.error.message}', + ) + case types.InvalidParamsError(): + await context.abort( + grpc.StatusCode.INVALID_ARGUMENT, + f'InvalidParamsError: {error.error.message}', + ) + case types.InternalError(): + await context.abort( + grpc.StatusCode.INTERNAL, + f'InternalError: {error.error.message}', + ) + case types.TaskNotFoundError(): + await context.abort( + grpc.StatusCode.NOT_FOUND, + f'TaskNotFoundError: {error.error.message}', + ) + case types.TaskNotCancelableError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'TaskNotCancelableError: {error.error.message}', + ) + case types.PushNotificationNotSupportedError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'PushNotificationNotSupportedError: {error.error.message}', + ) + case types.UnsupportedOperationError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'UnsupportedOperationError: {error.error.message}', + ) + case types.ContentTypeNotSupportedError(): + await context.abort( + grpc.StatusCode.UNIMPLEMENTED, + f'ContentTypeNotSupportedError: {error.error.message}', + ) + case types.InvalidAgentResponseError(): + await context.abort( + grpc.StatusCode.INTERNAL, + f'InvalidAgentResponseError: {error.error.message}', + ) + case _: + await context.abort( + grpc.StatusCode.UNKNOWN, + f'Unknown error type: {error.error}', + ) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 40c6fae2..a1cc43ec 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -1,5 +1,6 @@ """General utility functions for the A2A Python SDK.""" +import functools import logging from collections.abc import Callable @@ -148,6 +149,39 @@ def wrapper(self: Any, *args, **kwargs) -> Any: return decorator +def validate_async_generator( + expression: Callable[[Any], bool], error_message: str | None = None +): + """Decorator that validates if a given expression evaluates to True. + + Typically used on class methods to check capabilities or configuration + before executing the method's logic. If the expression is False, + a `ServerError` with an `UnsupportedOperationError` is raised. + + Args: + expression: A callable that takes the instance (`self`) as its argument + and returns a boolean. + error_message: An optional custom error message for the `UnsupportedOperationError`. + If None, the string representation of the expression will be used. + """ + + def decorator(function): + @functools.wraps(function) + async def wrapper(self, *args, **kwargs): + if not expression(self): + final_message = error_message or str(expression) + logger.error(f'Unsupported Operation: {final_message}') + raise ServerError( + UnsupportedOperationError(message=final_message) + ) + async for i in function(self, *args, **kwargs): + yield i + + return wrapper + + return decorator + + def are_modalities_compatible( server_output_modes: list[str] | None, client_output_modes: list[str] | None ) -> bool: diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py new file mode 100644 index 00000000..e1dddc39 --- /dev/null +++ b/src/a2a/utils/proto_utils.py @@ -0,0 +1,803 @@ +# mypy: disable-error-code="arg-type" +"""Utils for converting between proto and Python types.""" + +import json +import re + +from typing import Any + +from google.protobuf import json_format, struct_pb2 + +from a2a import types +from a2a.grpc import a2a_pb2 +from a2a.utils.errors import ServerError + + +# Regexp patterns for matching +_TASK_NAME_MATCH = r'tasks/(\w+)' +_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotifications/(\w+)' + + +class ToProto: + """Converts Python types to proto types.""" + + @classmethod + def message(cls, message: types.Message | None) -> a2a_pb2.Message | None: + if message is None: + return None + return a2a_pb2.Message( + message_id=message.messageId, + content=[ToProto.part(p) for p in message.parts], + context_id=message.contextId, + task_id=message.taskId, + role=cls.role(message.role), + metadata=ToProto.metadata(message.metadata), + ) + + @classmethod + def metadata( + cls, metadata: dict[str, Any] | None + ) -> struct_pb2.Struct | None: + if metadata is None: + return None + return struct_pb2.Struct( + # TODO: Add support for other types. + fields={ + key: struct_pb2.Value(string_value=value) + for key, value in metadata.items() + if isinstance(value, str) + } + ) + + @classmethod + def part(cls, part: types.Part) -> a2a_pb2.Part: + if isinstance(part.root, types.TextPart): + return a2a_pb2.Part(text=part.root.text) + if isinstance(part.root, types.FilePart): + return a2a_pb2.Part(file=ToProto.file(part.root.file)) + if isinstance(part.root, types.DataPart): + return a2a_pb2.Part(data=ToProto.data(part.root.data)) + raise ValueError(f'Unsupported part type: {part.root}') + + @classmethod + def data(cls, data: dict[str, Any]) -> a2a_pb2.DataPart: + json_data = json.dumps(data) + return a2a_pb2.DataPart( + data=json_format.Parse( + json_data, + struct_pb2.Struct(), + ) + ) + + @classmethod + def file( + cls, file: types.FileWithUri | types.FileWithBytes + ) -> a2a_pb2.FilePart: + if isinstance(file, types.FileWithUri): + return a2a_pb2.FilePart(file_with_uri=file.uri) + return a2a_pb2.FilePart(file_with_bytes=file.bytes.encode('utf-8')) + + @classmethod + def task(cls, task: types.Task) -> a2a_pb2.Task: + return a2a_pb2.Task( + id=task.id, + context_id=task.contextId, + status=ToProto.task_status(task.status), + artifacts=( + [ToProto.artifact(a) for a in task.artifacts] + if task.artifacts + else None + ), + history=( + [ToProto.message(h) for h in task.history] # type: ignore[misc] + if task.history + else None + ), + ) + + @classmethod + def task_status(cls, status: types.TaskStatus) -> a2a_pb2.TaskStatus: + return a2a_pb2.TaskStatus( + state=ToProto.task_state(status.state), + update=ToProto.message(status.message), + ) + + @classmethod + def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState: + match state: + case types.TaskState.submitted: + return a2a_pb2.TaskState.TASK_STATE_SUBMITTED + case types.TaskState.working: + return a2a_pb2.TaskState.TASK_STATE_WORKING + case types.TaskState.completed: + return a2a_pb2.TaskState.TASK_STATE_COMPLETED + case types.TaskState.canceled: + return a2a_pb2.TaskState.TASK_STATE_CANCELLED + case types.TaskState.failed: + return a2a_pb2.TaskState.TASK_STATE_FAILED + case types.TaskState.input_required: + return a2a_pb2.TaskState.TASK_STATE_INPUT_REQUIRED + case _: + return a2a_pb2.TaskState.TASK_STATE_UNSPECIFIED + + @classmethod + def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact: + return a2a_pb2.Artifact( + artifact_id=artifact.artifactId, + description=artifact.description, + metadata=ToProto.metadata(artifact.metadata), + name=artifact.name, + parts=[ToProto.part(p) for p in artifact.parts], + ) + + @classmethod + def authentication_info( + cls, info: types.PushNotificationAuthenticationInfo + ) -> a2a_pb2.AuthenticationInfo: + return a2a_pb2.AuthenticationInfo( + schemes=info.schemes, + credentials=info.credentials, + ) + + @classmethod + def push_notification_config( + cls, config: types.PushNotificationConfig + ) -> a2a_pb2.PushNotificationConfig: + return a2a_pb2.PushNotificationConfig( + id=config.id or '', + url=config.url, + token=config.token, + authentication=ToProto.authentication_info(config.authentication), + ) + + @classmethod + def task_artifact_update_event( + cls, event: types.TaskArtifactUpdateEvent + ) -> a2a_pb2.TaskArtifactUpdateEvent: + return a2a_pb2.TaskArtifactUpdateEvent( + task_id=event.taskId, + context_id=event.contextId, + artifact=ToProto.artifact(event.artifact), + metadata=ToProto.metadata(event.metadata), + append=event.append or False, + last_chunk=event.lastChunk or False, + ) + + @classmethod + def task_status_update_event( + cls, event: types.TaskStatusUpdateEvent + ) -> a2a_pb2.TaskStatusUpdateEvent: + return a2a_pb2.TaskStatusUpdateEvent( + task_id=event.taskId, + context_id=event.contextId, + status=ToProto.task_status(event.status), + metadata=ToProto.metadata(event.metadata), + final=event.final, + ) + + @classmethod + def message_send_configuration( + cls, config: types.MessageSendConfiguration | None + ) -> a2a_pb2.SendMessageConfiguration: + if not config: + return a2a_pb2.SendMessageConfiguration() + return a2a_pb2.SendMessageConfiguration( + accepted_output_modes=list(config.acceptedOutputModes), + push_notification=ToProto.push_notification_config( + config.pushNotificationConfig + ), + history_length=config.historyLength, + blocking=config.blocking or False, + ) + + @classmethod + def update_event( + cls, + event: types.Task + | types.Message + | types.TaskStatusUpdateEvent + | types.TaskArtifactUpdateEvent, + ) -> a2a_pb2.StreamResponse: + """Converts a task, message, or task update event to a StreamResponse.""" + if isinstance(event, types.TaskStatusUpdateEvent): + return a2a_pb2.StreamResponse( + status_update=ToProto.task_status_update_event(event) + ) + if isinstance(event, types.TaskArtifactUpdateEvent): + return a2a_pb2.StreamResponse( + artifact_update=ToProto.task_artifact_update_event(event) + ) + if isinstance(event, types.Message): + return a2a_pb2.StreamResponse(msg=ToProto.message(event)) + if isinstance(event, types.Task): + return a2a_pb2.StreamResponse(task=ToProto.task(event)) + raise ValueError(f'Unsupported event type: {type(event)}') + + @classmethod + def task_or_message( + cls, event: types.Task | types.Message + ) -> a2a_pb2.SendMessageResponse: + if isinstance(event, types.Message): + return a2a_pb2.SendMessageResponse( + msg=cls.message(event), + ) + return a2a_pb2.SendMessageResponse( + task=cls.task(event), + ) + + @classmethod + def stream_response( + cls, + event: ( + types.Message + | types.Task + | types.TaskStatusUpdateEvent + | types.TaskArtifactUpdateEvent + ), + ) -> a2a_pb2.StreamResponse: + if isinstance(event, types.Message): + return a2a_pb2.StreamResponse(msg=cls.message(event)) + if isinstance(event, types.Task): + return a2a_pb2.StreamResponse(task=cls.task(event)) + if isinstance(event, types.TaskStatusUpdateEvent): + return a2a_pb2.StreamResponse( + status_update=cls.task_status_update_event(event), + ) + return a2a_pb2.StreamResponse( + artifact_update=cls.task_artifact_update_event(event), + ) + + @classmethod + def task_push_notification_config( + cls, config: types.TaskPushNotificationConfig + ) -> a2a_pb2.TaskPushNotificationConfig: + return a2a_pb2.TaskPushNotificationConfig( + name=f'tasks/{config.taskId}/pushNotifications/{config.taskId}', + push_notification_config=cls.push_notification_config( + config.pushNotificationConfig, + ), + ) + + @classmethod + def agent_card( + cls, + card: types.AgentCard, + ) -> a2a_pb2.AgentCard: + return a2a_pb2.AgentCard( + capabilities=cls.capabilities(card.capabilities), + default_input_modes=list(card.defaultInputModes), + default_output_modes=list(card.defaultOutputModes), + description=card.description, + documentation_url=card.documentationUrl, + name=card.name, + provider=cls.provider(card.provider), + security=cls.security(card.security), + security_schemes=cls.security_schemes(card.securitySchemes), + skills=[cls.skill(x) for x in card.skills] if card.skills else [], + url=card.url, + version=card.version, + supports_authenticated_extended_card=card.supportsAuthenticatedExtendedCard, + ) + + @classmethod + def capabilities( + cls, capabilities: types.AgentCapabilities + ) -> a2a_pb2.AgentCapabilities: + return a2a_pb2.AgentCapabilities( + streaming=capabilities.streaming, + push_notifications=capabilities.pushNotifications, + ) + + @classmethod + def provider( + cls, provider: types.AgentProvider | None + ) -> a2a_pb2.AgentProvider | None: + if not provider: + return None + return a2a_pb2.AgentProvider( + organization=provider.organization, + url=provider.url, + ) + + @classmethod + def security( + cls, + security: list[dict[str, list[str]]] | None, + ) -> list[a2a_pb2.Security] | None: + if not security: + return None + rval: list[a2a_pb2.Security] = [] + for s in security: + rval.append( + a2a_pb2.Security( + schemes={ + k: a2a_pb2.StringList(list=v) for (k, v) in s.items() + } + ) + ) + return rval + + @classmethod + def security_schemes( + cls, + schemes: dict[str, types.SecurityScheme] | None, + ) -> dict[str, a2a_pb2.SecurityScheme] | None: + if not schemes: + return None + return {k: cls.security_scheme(v) for (k, v) in schemes.items()} + + @classmethod + def security_scheme( + cls, + scheme: types.SecurityScheme, + ) -> a2a_pb2.SecurityScheme: + if isinstance(scheme.root, types.APIKeySecurityScheme): + return a2a_pb2.SecurityScheme( + api_key_security_scheme=a2a_pb2.APIKeySecurityScheme( + description=scheme.root.description, + location=scheme.root.in_, + name=scheme.root.name, + ) + ) + if isinstance(scheme.root, types.HTTPAuthSecurityScheme): + return a2a_pb2.SecurityScheme( + http_auth_security_scheme=a2a_pb2.HTTPAuthSecurityScheme( + description=scheme.root.description, + scheme=scheme.root.scheme, + bearer_format=scheme.root.bearerFormat, + ) + ) + if isinstance(scheme.root, types.OAuth2SecurityScheme): + return a2a_pb2.SecurityScheme( + oauth2_security_scheme=a2a_pb2.OAuth2SecurityScheme( + description=scheme.root.description, + flows=cls.oauth2_flows(scheme.root.flows), + ) + ) + return a2a_pb2.SecurityScheme( + open_id_connect_security_scheme=a2a_pb2.OpenIdConnectSecurityScheme( + description=scheme.root.description, + open_id_connect_url=scheme.root.openIdConnectUrl, + ) + ) + + @classmethod + def oauth2_flows(cls, flows: types.OAuthFlows) -> a2a_pb2.OAuthFlows: + if flows.authorizationCode: + return a2a_pb2.OAuthFlows( + authorization_code=a2a_pb2.AuthorizationCodeOAuthFlow( + authorization_url=flows.authorizationCode.authorizationUrl, + refresh_url=flows.authorizationCode.refreshUrl, + scopes=dict(flows.authorizationCode.scopes.items()), + token_url=flows.authorizationCode.tokenUrl, + ), + ) + if flows.clientCredentials: + return a2a_pb2.OAuthFlows( + client_credentials=a2a_pb2.ClientCredentialsOAuthFlow( + refresh_url=flows.clientCredentials.refreshUrl, + scopes=dict(flows.clientCredentials.scopes.items()), + token_url=flows.clientCredentials.tokenUrl, + ), + ) + if flows.implicit: + return a2a_pb2.OAuthFlows( + implicit=a2a_pb2.ImplicitOAuthFlow( + authorization_url=flows.implicit.authorizationUrl, + refresh_url=flows.implicit.refreshUrl, + scopes=dict(flows.implicit.scopes.items()), + ), + ) + if flows.password: + return a2a_pb2.OAuthFlows( + password=a2a_pb2.PasswordOAuthFlow( + refresh_url=flows.password.refreshUrl, + scopes=dict(flows.password.scopes.items()), + token_url=flows.password.tokenUrl, + ), + ) + raise ValueError('Unknown oauth flow definition') + + @classmethod + def skill(cls, skill: types.AgentSkill) -> a2a_pb2.AgentSkill: + return a2a_pb2.AgentSkill( + id=skill.id, + name=skill.name, + description=skill.description, + tags=skill.tags, + examples=skill.examples, + input_modes=skill.inputModes, + output_modes=skill.outputModes, + ) + + @classmethod + def role(cls, role: types.Role) -> a2a_pb2.Role: + match role: + case types.Role.user: + return a2a_pb2.Role.ROLE_USER + case types.Role.agent: + return a2a_pb2.Role.ROLE_AGENT + case _: + return a2a_pb2.Role.ROLE_UNSPECIFIED + + +class FromProto: + """Converts proto types to Python types.""" + + @classmethod + def message(cls, message: a2a_pb2.Message) -> types.Message: + return types.Message( + messageId=message.message_id, + parts=[FromProto.part(p) for p in message.content], + contextId=message.context_id, + taskId=message.task_id, + role=FromProto.role(message.role), + metadata=FromProto.metadata(message.metadata), + ) + + @classmethod + def metadata(cls, metadata: struct_pb2.Struct) -> dict[str, Any]: + return { + key: value.string_value + for key, value in metadata.fields.items() + if value.string_value + } + + @classmethod + def part(cls, part: a2a_pb2.Part) -> types.Part: + if part.HasField('text'): + return types.Part(root=types.TextPart(text=part.text)) + if part.HasField('file'): + return types.Part( + root=types.FilePart(file=FromProto.file(part.file)) + ) + if part.HasField('data'): + return types.Part( + root=types.DataPart(data=FromProto.data(part.data)) + ) + raise ValueError(f'Unsupported part type: {part}') + + @classmethod + def data(cls, data: a2a_pb2.DataPart) -> dict[str, Any]: + json_data = json_format.MessageToJson(data.data) + return json.loads(json_data) + + @classmethod + def file( + cls, file: a2a_pb2.FilePart + ) -> types.FileWithUri | types.FileWithBytes: + if file.HasField('file_with_uri'): + return types.FileWithUri(uri=file.file_with_uri) + return types.FileWithBytes(bytes=file.file_with_bytes.decode('utf-8')) + + @classmethod + def task(cls, task: a2a_pb2.Task) -> types.Task: + return types.Task( + id=task.id, + contextId=task.context_id, + status=FromProto.task_status(task.status), + artifacts=[FromProto.artifact(a) for a in task.artifacts], + history=[FromProto.message(h) for h in task.history], + ) + + @classmethod + def task_status(cls, status: a2a_pb2.TaskStatus) -> types.TaskStatus: + return types.TaskStatus( + state=FromProto.task_state(status.state), + message=FromProto.message(status.update), + ) + + @classmethod + def task_state(cls, state: a2a_pb2.TaskState) -> types.TaskState: + match state: + case a2a_pb2.TaskState.TASK_STATE_SUBMITTED: + return types.TaskState.submitted + case a2a_pb2.TaskState.TASK_STATE_WORKING: + return types.TaskState.working + case a2a_pb2.TaskState.TASK_STATE_COMPLETED: + return types.TaskState.completed + case a2a_pb2.TaskState.TASK_STATE_CANCELLED: + return types.TaskState.canceled + case a2a_pb2.TaskState.TASK_STATE_FAILED: + return types.TaskState.failed + case a2a_pb2.TaskState.TASK_STATE_INPUT_REQUIRED: + return types.TaskState.input_required + case _: + return types.TaskState.unknown + + @classmethod + def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact: + return types.Artifact( + artifactId=artifact.artifact_id, + description=artifact.description, + metadata=FromProto.metadata(artifact.metadata), + name=artifact.name, + parts=[FromProto.part(p) for p in artifact.parts], + ) + + @classmethod + def task_artifact_update_event( + cls, event: a2a_pb2.TaskArtifactUpdateEvent + ) -> types.TaskArtifactUpdateEvent: + return types.TaskArtifactUpdateEvent( + taskId=event.task_id, + contextId=event.context_id, + artifact=FromProto.artifact(event.artifact), + metadata=FromProto.metadata(event.metadata), + append=event.append, + lastChunk=event.last_chunk, + ) + + @classmethod + def task_status_update_event( + cls, event: a2a_pb2.TaskStatusUpdateEvent + ) -> types.TaskStatusUpdateEvent: + return types.TaskStatusUpdateEvent( + taskId=event.task_id, + contextId=event.context_id, + status=FromProto.task_status(event.status), + metadata=FromProto.metadata(event.metadata), + final=event.final, + ) + + @classmethod + def push_notification_config( + cls, config: a2a_pb2.PushNotificationConfig + ) -> types.PushNotificationConfig: + return types.PushNotificationConfig( + id=config.id, + url=config.url, + token=config.token, + authentication=FromProto.authentication_info(config.authentication), + ) + + @classmethod + def authentication_info( + cls, info: a2a_pb2.AuthenticationInfo + ) -> types.PushNotificationAuthenticationInfo: + return types.PushNotificationAuthenticationInfo( + schemes=list(info.schemes), + credentials=info.credentials, + ) + + @classmethod + def message_send_configuration( + cls, config: a2a_pb2.SendMessageConfiguration + ) -> types.MessageSendConfiguration: + return types.MessageSendConfiguration( + acceptedOutputModes=list(config.accepted_output_modes), + pushNotificationConfig=FromProto.push_notification_config( + config.push_notification + ), + historyLength=config.history_length, + blocking=config.blocking, + ) + + @classmethod + def message_send_params( + cls, request: a2a_pb2.SendMessageRequest + ) -> types.MessageSendParams: + return types.MessageSendParams( + configuration=cls.message_send_configuration(request.configuration), + message=cls.message(request.request), + metadata=cls.metadata(request.metadata), + ) + + @classmethod + def task_id_params( + cls, + request: ( + a2a_pb2.CancelTaskRequest + | a2a_pb2.TaskSubscriptionRequest + | a2a_pb2.GetTaskPushNotificationRequest + ), + ) -> types.TaskIdParams: + # This is currently incomplete until the core sdk supports multiple + # configs for a single task. + if isinstance(request, a2a_pb2.GetTaskPushNotificationRequest): + m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, request.name) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f'No task for {request.name}' + ) + ) + return types.TaskIdParams(id=m.group(1)) + m = re.match(_TASK_NAME_MATCH, request.name) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f'No task for {request.name}' + ) + ) + return types.TaskIdParams(id=m.group(1)) + + @classmethod + def task_push_notification_config( + cls, + request: a2a_pb2.CreateTaskPushNotificationRequest, + ) -> types.TaskPushNotificationConfig: + m = re.match(_TASK_NAME_MATCH, request.parent) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f'No task for {request.parent}' + ) + ) + return types.TaskPushNotificationConfig( + pushNotificationConfig=cls.push_notification_config( + request.config.push_notification_config, + ), + taskId=m.group(1), + ) + + @classmethod + def agent_card( + cls, + card: a2a_pb2.AgentCard, + ) -> types.AgentCard: + return types.AgentCard( + capabilities=cls.capabilities(card.capabilities), + defaultInputModes=list(card.default_input_modes), + defaultOutputModes=list(card.default_output_modes), + description=card.description, + documentationUrl=card.documentation_url, + name=card.name, + provider=cls.provider(card.provider), + security=cls.security(list(card.security)), + securitySchemes=cls.security_schemes(dict(card.security_schemes)), + skills=[cls.skill(x) for x in card.skills] if card.skills else [], + url=card.url, + version=card.version, + supportsAuthenticatedExtendedCard=card.supports_authenticated_extended_card, + ) + + @classmethod + def task_query_params( + cls, + request: a2a_pb2.GetTaskRequest, + ) -> types.TaskQueryParams: + m = re.match(_TASK_NAME_MATCH, request.name) + if not m: + raise ServerError( + error=types.InvalidParamsError( + message=f'No task for {request.name}' + ) + ) + return types.TaskQueryParams( + historyLength=request.history_length + if request.history_length + else None, + id=m.group(1), + metadata=None, + ) + + @classmethod + def capabilities( + cls, capabilities: a2a_pb2.AgentCapabilities + ) -> types.AgentCapabilities: + return types.AgentCapabilities( + streaming=capabilities.streaming, + pushNotifications=capabilities.push_notifications, + ) + + @classmethod + def security( + cls, + security: list[a2a_pb2.Security] | None, + ) -> list[dict[str, list[str]]] | None: + if not security: + return None + rval: list[dict[str, list[str]]] = [] + for s in security: + rval.append({k: list(v.list) for (k, v) in s.schemes.items()}) + return rval + + @classmethod + def provider( + cls, provider: a2a_pb2.AgentProvider | None + ) -> types.AgentProvider | None: + if not provider: + return None + return types.AgentProvider( + organization=provider.organization, + url=provider.url, + ) + + @classmethod + def security_schemes( + cls, schemes: dict[str, a2a_pb2.SecurityScheme] + ) -> dict[str, types.SecurityScheme]: + return {k: cls.security_scheme(v) for (k, v) in schemes.items()} + + @classmethod + def security_scheme( + cls, + scheme: a2a_pb2.SecurityScheme, + ) -> types.SecurityScheme: + if scheme.HasField('api_key_security_scheme'): + return types.SecurityScheme( + root=types.APIKeySecurityScheme( + description=scheme.api_key_security_scheme.description, + name=scheme.api_key_security_scheme.name, + in_=scheme.api_key_security_scheme.location, # type: ignore[call-arg] + ) + ) + if scheme.HasField('http_auth_security_scheme'): + return types.SecurityScheme( + root=types.HTTPAuthSecurityScheme( + description=scheme.http_auth_security_scheme.description, + scheme=scheme.http_auth_security_scheme.scheme, + bearerFormat=scheme.http_auth_security_scheme.bearer_format, + ) + ) + if scheme.HasField('oauth2_security_scheme'): + return types.SecurityScheme( + root=types.OAuth2SecurityScheme( + description=scheme.oauth2_security_scheme.description, + flows=cls.oauth2_flows(scheme.oauth2_security_scheme.flows), + ) + ) + return types.SecurityScheme( + root=types.OpenIdConnectSecurityScheme( + description=scheme.open_id_connect_security_scheme.description, + openIdConnectUrl=scheme.open_id_connect_security_scheme.open_id_connect_url, + ) + ) + + @classmethod + def oauth2_flows(cls, flows: a2a_pb2.OAuthFlows) -> types.OAuthFlows: + if flows.HasField('authorization_code'): + return types.OAuthFlows( + authorizationCode=types.AuthorizationCodeOAuthFlow( + authorizationUrl=flows.authorization_code.authorization_url, + refreshUrl=flows.authorization_code.refresh_url, + scopes=dict(flows.authorization_code.scopes.items()), + tokenUrl=flows.authorization_code.token_url, + ), + ) + if flows.HasField('client_credentials'): + return types.OAuthFlows( + clientCredentials=types.ClientCredentialsOAuthFlow( + refreshUrl=flows.client_credentials.refresh_url, + scopes=dict(flows.client_credentials.scopes.items()), + tokenUrl=flows.client_credentials.token_url, + ), + ) + if flows.HasField('implicit'): + return types.OAuthFlows( + implicit=types.ImplicitOAuthFlow( + authorizationUrl=flows.implicit.authorization_url, + refreshUrl=flows.implicit.refresh_url, + scopes=dict(flows.implicit.scopes.items()), + ), + ) + return types.OAuthFlows( + password=types.PasswordOAuthFlow( + refreshUrl=flows.password.refresh_url, + scopes=dict(flows.password.scopes.items()), + tokenUrl=flows.password.token_url, + ), + ) + + @classmethod + def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill: + return types.AgentSkill( + id=skill.id, + name=skill.name, + description=skill.description, + tags=list(skill.tags), + examples=list(skill.examples), + inputModes=list(skill.input_modes), + outputModes=list(skill.output_modes), + ) + + @classmethod + def role(cls, role: a2a_pb2.Role) -> types.Role: + match role: + case a2a_pb2.Role.ROLE_USER: + return types.Role.user + case a2a_pb2.Role.ROLE_AGENT: + return types.Role.agent + case _: + return types.Role.agent