Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"]
postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
signing = ["python-jose>=3.0.0"]

sql = ["a2a-sdk[postgresql,mysql,sqlite]"]

Expand All @@ -45,6 +46,7 @@ all = [
"a2a-sdk[encryption]",
"a2a-sdk[grpc]",
"a2a-sdk[telemetry]",
"a2a-sdk[signing]",
]

[project.urls]
Expand Down
9 changes: 7 additions & 2 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Callable
from typing import Any

from a2a.client.client import (
Expand Down Expand Up @@ -256,11 +256,12 @@
):
yield await self._process_response(tracker, event)

async def get_card(

Check failure on line 259 in src/a2a/client/base_client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (D417)

src/a2a/client/base_client.py:259:15: D417 Missing argument description in the docstring for `get_card`: `signature_verifier`
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card.

Expand All @@ -270,12 +271,16 @@
Args:
context: The client call context.
extensions: List of extensions to be activated.
key_provider: A callable that takes key-id (kid) and JSON web key url (jku)

Check failure on line 274 in src/a2a/client/base_client.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jku` is not a recognized word. (unrecognized-spelling)
and returns the verification key for signature verification.

Returns:
The `AgentCard` for the agent.
"""
card = await self._transport.get_card(
context=context, extensions=extensions
context=context,
extensions=extensions,
signature_verifier=signature_verifier,
)
self._card = card
return card
Expand Down
1 change: 1 addition & 0 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,17 +176,18 @@
extensions: list[str] | None = None,
) -> AsyncIterator[ClientEvent]:
"""Resubscribes to a task's event stream."""
return
yield

@abstractmethod
async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""

Check notice on line 190 in src/a2a/client/client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/base.py (97-108)

async def add_event_consumer(self, consumer: Consumer) -> None:
"""Attaches additional consumers to the `Client`."""
Expand Down
3 changes: 2 additions & 1 deletion src/a2a/client/transports/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable

from a2a.client.middleware import ClientCallContext
from a2a.types import (
Expand Down Expand Up @@ -94,17 +94,18 @@
Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
]:
"""Reconnects to get task updates."""
return
yield

@abstractmethod
async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the AgentCard."""

Check notice on line 108 in src/a2a/client/transports/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/client.py (179-190)

@abstractmethod
async def close(self) -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable


try:
Expand Down Expand Up @@ -223,6 +223,7 @@ async def get_card(
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
card = self.agent_card
Expand All @@ -236,6 +237,9 @@ async def get_card(
metadata=self._get_grpc_metadata(extensions),
)
card = proto_utils.FromProto.agent_card(card_pb)
if signature_verifier is not None:
signature_verifier(card)

self.agent_card = card
self._needs_extended_card = False
return card
Expand Down
14 changes: 11 additions & 3 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging

from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from typing import Any
from uuid import uuid4

Expand Down Expand Up @@ -363,41 +363,45 @@
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
card = self.agent_card

if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
if signature_verifier is not None:
signature_verifier(card)
self._needs_extended_card = (
card.supports_authenticated_extended_card
)
self.agent_card = card

if not self._needs_extended_card:
return card

request = GetAuthenticatedExtendedCardRequest(id=str(uuid4()))

Check notice on line 404 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (162-396)
payload, modified_kwargs = await self._apply_interceptors(
request.method,
request.model_dump(mode='json', exclude_none=True),
Expand All @@ -413,9 +417,13 @@
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
self.agent_card = response.root.result
card = response.root.result
if signature_verifier is not None:
signature_verifier(card)

self.agent_card = card
self._needs_extended_card = False
return self.agent_card
return card

async def close(self) -> None:
"""Closes the httpx client."""
Expand Down
9 changes: 8 additions & 1 deletion src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging

from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from typing import Any

import httpx
Expand Down Expand Up @@ -159,237 +159,241 @@
yield proto_utils.FromProto.stream_response(event)
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_request(self, request: httpx.Request) -> dict[str, Any]:
try:
response = await self.httpx_client.send(request)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_post_request(
self,
target: str,
rpc_request_payload: dict[str, Any],
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
return await self._send_request(
self.httpx_client.build_request(
'POST',
f'{self.url}{target}',
json=rpc_request_payload,
**(http_kwargs or {}),
)
)

async def _send_get_request(
self,
target: str,
query_params: dict[str, str],
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
return await self._send_request(
self.httpx_client.build_request(
'GET',
f'{self.url}{target}',
params=query_params,
**(http_kwargs or {}),
)
)

async def get_task(
self,
request: TaskQueryParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
_payload, modified_kwargs = await self._apply_interceptors(
request.model_dump(mode='json', exclude_none=True),
modified_kwargs,
context,
)
response_data = await self._send_get_request(
f'/v1/tasks/{request.id}',
{'historyLength': str(request.history_length)}
if request.history_length is not None
else {},
modified_kwargs,
)
task = a2a_pb2.Task()
ParseDict(response_data, task)
return proto_utils.FromProto.task(task)

async def cancel_task(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""
pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}')
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload,
modified_kwargs,
context,
)
response_data = await self._send_post_request(
f'/v1/tasks/{request.id}:cancel', payload, modified_kwargs
)
task = a2a_pb2.Task()
ParseDict(response_data, task)
return proto_utils.FromProto.task(task)

async def set_task_callback(
self,
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""
pb = a2a_pb2.CreateTaskPushNotificationConfigRequest(
parent=f'tasks/{request.task_id}',
config_id=request.push_notification_config.id,
config=proto_utils.ToProto.task_push_notification_config(request),
)
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload, modified_kwargs, context
)
response_data = await self._send_post_request(
f'/v1/tasks/{request.task_id}/pushNotificationConfigs',
payload,
modified_kwargs,
)
config = a2a_pb2.TaskPushNotificationConfig()
ParseDict(response_data, config)
return proto_utils.FromProto.task_push_notification_config(config)

async def get_task_callback(
self,
request: GetTaskPushNotificationConfigParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""
pb = a2a_pb2.GetTaskPushNotificationConfigRequest(
name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
)
payload = MessageToDict(pb)
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
payload, modified_kwargs = await self._apply_interceptors(
payload,
modified_kwargs,
context,
)
response_data = await self._send_get_request(
f'/v1/tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}',
{},
modified_kwargs,
)
config = a2a_pb2.TaskPushNotificationConfig()
ParseDict(response_data, config)
return proto_utils.FromProto.task_push_notification_config(config)

async def resubscribe(
self,
request: TaskIdParams,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
) -> AsyncGenerator[
Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message
]:
"""Reconnects to get task updates."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
modified_kwargs.setdefault('timeout', None)

async with aconnect_sse(
self.httpx_client,
'GET',
f'{self.url}/v1/tasks/{request.id}:subscribe',
**modified_kwargs,
) as event_source:
try:
async for sse in event_source.aiter_sse():
event = a2a_pb2.StreamResponse()
Parse(sse.data, event)
yield proto_utils.FromProto.stream_response(event)
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def get_card(
self,
*,
context: ClientCallContext | None = None,
extensions: list[str] | None = None,
signature_verifier: Callable[[AgentCard], None] | None = None,
) -> AgentCard:
"""Retrieves the agent's card."""
modified_kwargs = update_extension_header(
self._get_http_args(context),
extensions if extensions is not None else self.extensions,
)
card = self.agent_card

if not card:
resolver = A2ACardResolver(self.httpx_client, self.url)
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
if signature_verifier is not None:
signature_verifier(card)
self._needs_extended_card = (
card.supports_authenticated_extended_card
)
self.agent_card = card

if not self._needs_extended_card:
return card

_, modified_kwargs = await self._apply_interceptors(

Check notice on line 396 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (366-404)
{},
modified_kwargs,
context,
Expand All @@ -398,6 +402,9 @@
'/v1/card', {}, modified_kwargs
)
card = AgentCard.model_validate(response_data)
if signature_verifier is not None:
signature_verifier(card)

self.agent_card = card
self._needs_extended_card = False
return card
Expand Down
28 changes: 28 additions & 0 deletions src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,21 @@ def agent_card(
]
if card.additional_interfaces
else None,
signatures=[cls.agent_card_signature(x) for x in card.signatures]
if card.signatures
else None,
)

@classmethod
def agent_card_signature(
cls, signature: types.AgentCardSignature
) -> a2a_pb2.AgentCardSignature:
return a2a_pb2.AgentCardSignature(
protected=signature.protected,
signature=signature.signature,
header=dict_to_struct(signature.header)
if signature.header is not None
else None,
)

@classmethod
Expand Down Expand Up @@ -865,6 +880,19 @@ def agent_card(
]
if card.additional_interfaces
else None,
signatures=[cls.agent_card_signature(x) for x in card.signatures]
if card.signatures
else None,
)

@classmethod
def agent_card_signature(
cls, signature: a2a_pb2.AgentCardSignature
) -> types.AgentCardSignature:
return types.AgentCardSignature(
protected=signature.protected,
signature=signature.signature,
header=json_format.MessageToDict(signature.header),
)

@classmethod
Expand Down
161 changes: 161 additions & 0 deletions src/a2a/utils/signing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import json

from collections.abc import Callable
from typing import Any


try:
from jose import jws

Check failure on line 8 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jws` is not a recognized word. (unrecognized-spelling)
from jose.backends.base import Key
from jose.exceptions import JOSEError
from jose.utils import base64url_decode, base64url_encode
except ImportError as e:
raise ImportError(
'A2AUtilsSigning requires python-jose to be installed. '

Check failure on line 14 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`AUtils` is not a recognized word. (unrecognized-spelling)
'Install with: '
"'pip install a2a-sdk[signing]'"
) from e

from a2a.types import AgentCard, AgentCardSignature


def clean_empty(d: Any) -> Any:
"""Recursively remove empty lists, dicts, strings, and None values from a dictionary."""
if isinstance(d, dict):
cleaned = {k: clean_empty(v) for k, v in d.items()}
return {
k: v
for k, v in cleaned.items()
if v is not None and (isinstance(v, (bool, int, float)) or v)
}
if isinstance(d, list):
cleaned = [clean_empty(v) for v in d]
return [
v
for v in cleaned
if v is not None and (isinstance(v, (bool, int, float)) or v)
]
return d if d not in [None, '', [], {}] else None


def canonicalize_agent_card(agent_card: AgentCard) -> str:
"""Canonicalizes the Agent Card JSON according to RFC 8785 (JCS)."""
card_dict = agent_card.model_dump(
exclude={'signatures'},
exclude_defaults=True,
by_alias=True,
)
# Ensure 'protocol_version' is always included
protocol_version_alias = (
AgentCard.model_fields['protocol_version'].alias or 'protocol_version'
)
if protocol_version_alias not in card_dict:
card_dict[protocol_version_alias] = agent_card.protocol_version

# Recursively remove empty/None values
cleaned_dict = clean_empty(card_dict)

return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True)


def create_agent_card_signer(
signing_key: str | bytes | dict[str, Any] | Key,
kid: str,
alg: str = 'HS256',
jku: str | None = None,

Check failure on line 65 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jku` is not a recognized word. (unrecognized-spelling)

Check warning on line 65 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jku` is not a recognized word. (unrecognized-spelling)
) -> Callable[[AgentCard], AgentCard]:
"""Creates a function that signs an AgentCard and adds the signature.
Args:
signing_key: The private key for signing.
kid: Key ID for the signing key.
alg: The algorithm to use (e.g., "ES256", "RS256").
jku: Optional URL to the JWKS.

Check failure on line 73 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jku` is not a recognized word. (unrecognized-spelling)

Check failure on line 73 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`JWKS` is not a recognized word. (unrecognized-spelling)

Check warning on line 73 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jku` is not a recognized word. (unrecognized-spelling)
Returns:
A callable that takes an AgentCard and returns the modified AgentCard with a signature.
"""

def agent_card_signer(agent_card: AgentCard) -> AgentCard:
"""The actual card_modifier function."""
canonical_payload = canonicalize_agent_card(agent_card)

headers = {'kid': kid, 'typ': 'JOSE'}
if jku:

Check failure on line 84 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jku` is not a recognized word. (unrecognized-spelling)

Check warning on line 84 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jku` is not a recognized word. (unrecognized-spelling)
headers['jku'] = jku

Check failure on line 85 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jku` is not a recognized word. (unrecognized-spelling)

Check warning on line 85 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jku` is not a recognized word. (unrecognized-spelling)

jws_string = jws.sign(

Check failure on line 87 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jws` is not a recognized word. (unrecognized-spelling)

Check failure on line 87 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jws` is not a recognized word. (unrecognized-spelling)

Check warning on line 87 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jws` is not a recognized word. (unrecognized-spelling)

Check warning on line 87 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jws` is not a recognized word. (unrecognized-spelling)
payload=canonical_payload.encode('utf-8'),
key=signing_key,
headers=headers,
algorithm=alg,
)

# The result of jws.sign is a compact serialization: HEADER.PAYLOAD.SIGNATURE

Check warning on line 94 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jws` is not a recognized word. (unrecognized-spelling)
protected_header, _, signature = jws_string.split('.')

Check warning on line 95 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`jws` is not a recognized word. (unrecognized-spelling)

agent_card_signature = AgentCardSignature(
protected=protected_header,
signature=signature,
)

agent_card.signatures = (agent_card.signatures or []) + [
agent_card_signature
]
return agent_card

return agent_card_signer


def create_signature_verifier(
key_provider: Callable[
[str | None, str | None], str | bytes | dict[str, Any] | Key
],
) -> Callable[[AgentCard], None]:
"""Creates a function that verifies AgentCard signatures.
Args:
key_provider: A callable that takes key-id (kid) and JSON web key url (jku) and returns the verification key.
Returns:
A callable that takes an AgentCard, and raises an error if none of the signatures are valid.
"""

def signature_verifier(
agent_card: AgentCard,
) -> None:
"""The actual signature_verifier function."""
if not agent_card.signatures:
raise JOSEError('No signatures found on AgentCard')

last_error = None
for agent_card_signature in agent_card.signatures:
try:
# fetch kid and jku from protected header
protected_header_json = base64url_decode(
agent_card_signature.protected.encode('utf-8')
).decode('utf-8')
protected_header = json.loads(protected_header_json)
kid = protected_header.get('kid')
jku = protected_header.get('jku')
verification_key = key_provider(kid, jku)

canonical_payload = canonicalize_agent_card(agent_card)
encoded_payload = base64url_encode(
canonical_payload.encode('utf-8')
).decode('utf-8')
token = f'{agent_card_signature.protected}.{encoded_payload}.{agent_card_signature.signature}'

jws.verify(
token=token,
key=verification_key,
algorithms=None,
)
return # Found a valid signature

Check failure on line 154 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (TRY300)

src/a2a/utils/signing.py:154:17: TRY300 Consider moving this statement to an `else` block

except JOSEError as e:
last_error = e
continue

Check failure on line 158 in src/a2a/utils/signing.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (PERF203)

src/a2a/utils/signing.py:156:13: PERF203 `try`-`except` within a loop incurs performance overhead
raise JOSEError('No valid signature found') from last_error

return signature_verifier
Loading
Loading