Skip to content

Commit 6fa6a6c

Browse files
authored
refactor: Move agent card signature verification into A2ACardResolver (#593)
# Description Previously, the `JSON-RPC` and `REST` protocols verified agent card signatures after calling `A2ACardResolver.get_agent_card`. This change moves the signature verification logic inside the `A2ACardResolver.get_agent_card` method and adds a unit test to test_card_resolver.py Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [x] Appropriate docs were updated (if necessary)
1 parent 3deecc4 commit 6fa6a6c

File tree

7 files changed

+44
-10
lines changed

7 files changed

+44
-10
lines changed

src/a2a/client/card_resolver.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import logging
33

4+
from collections.abc import Callable
45
from typing import Any
56

67
import httpx
@@ -44,6 +45,7 @@ async def get_agent_card(
4445
self,
4546
relative_card_path: str | None = None,
4647
http_kwargs: dict[str, Any] | None = None,
48+
signature_verifier: Callable[[AgentCard], None] | None = None,
4749
) -> AgentCard:
4850
"""Fetches an agent card from a specified path relative to the base_url.
4951
@@ -56,6 +58,7 @@ async def get_agent_card(
5658
agent card path. Use `'/'` for an empty path.
5759
http_kwargs: Optional dictionary of keyword arguments to pass to the
5860
underlying httpx.get request.
61+
signature_verifier: A callable used to verify the agent card's signatures.
5962
6063
Returns:
6164
An `AgentCard` object representing the agent's capabilities.
@@ -86,6 +89,8 @@ async def get_agent_card(
8689
agent_card_data,
8790
)
8891
agent_card = AgentCard.model_validate(agent_card_data)
92+
if signature_verifier:
93+
signature_verifier(agent_card)
8994
except httpx.HTTPStatusError as e:
9095
raise A2AClientHTTPError(
9196
e.response.status_code,

src/a2a/client/client_factory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ async def connect( # noqa: PLR0913
116116
resolver_http_kwargs: dict[str, Any] | None = None,
117117
extra_transports: dict[str, TransportProducer] | None = None,
118118
extensions: list[str] | None = None,
119+
signature_verifier: Callable[[AgentCard], None] | None = None,
119120
) -> Client:
120121
"""Convenience method for constructing a client.
121122
@@ -146,6 +147,7 @@ async def connect( # noqa: PLR0913
146147
extra_transports: Additional transport protocols to enable when
147148
constructing the client.
148149
extensions: List of extensions to be activated.
150+
signature_verifier: A callable used to verify the agent card's signatures.
149151
150152
Returns:
151153
A `Client` object.
@@ -158,12 +160,14 @@ async def connect( # noqa: PLR0913
158160
card = await resolver.get_agent_card(
159161
relative_card_path=relative_card_path,
160162
http_kwargs=resolver_http_kwargs,
163+
signature_verifier=signature_verifier,
161164
)
162165
else:
163166
resolver = A2ACardResolver(client_config.httpx_client, agent)
164167
card = await resolver.get_agent_card(
165168
relative_card_path=relative_card_path,
166169
http_kwargs=resolver_http_kwargs,
170+
signature_verifier=signature_verifier,
167171
)
168172
else:
169173
card = agent

src/a2a/client/transports/grpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ async def get_card(
237237
metadata=self._get_grpc_metadata(extensions),
238238
)
239239
card = proto_utils.FromProto.agent_card(card_pb)
240-
if signature_verifier is not None:
240+
if signature_verifier:
241241
signature_verifier(card)
242242

243243
self.agent_card = card

src/a2a/client/transports/jsonrpc.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,9 +390,10 @@ async def get_card(
390390

391391
if not card:
392392
resolver = A2ACardResolver(self.httpx_client, self.url)
393-
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
394-
if signature_verifier is not None:
395-
signature_verifier(card)
393+
card = await resolver.get_agent_card(
394+
http_kwargs=modified_kwargs,
395+
signature_verifier=signature_verifier,
396+
)
396397
self._needs_extended_card = (
397398
card.supports_authenticated_extended_card
398399
)
@@ -418,7 +419,7 @@ async def get_card(
418419
if isinstance(response.root, JSONRPCErrorResponse):
419420
raise A2AClientJSONRPCError(response.root)
420421
card = response.root.result
421-
if signature_verifier is not None:
422+
if signature_verifier:
422423
signature_verifier(card)
423424

424425
self.agent_card = card

src/a2a/client/transports/rest.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,10 @@ async def get_card(
382382

383383
if not card:
384384
resolver = A2ACardResolver(self.httpx_client, self.url)
385-
card = await resolver.get_agent_card(http_kwargs=modified_kwargs)
386-
if signature_verifier is not None:
387-
signature_verifier(card)
385+
card = await resolver.get_agent_card(
386+
http_kwargs=modified_kwargs,
387+
signature_verifier=signature_verifier,
388+
)
388389
self._needs_extended_card = (
389390
card.supports_authenticated_extended_card
390391
)
@@ -402,7 +403,7 @@ async def get_card(
402403
'/v1/card', {}, modified_kwargs
403404
)
404405
card = AgentCard.model_validate(response_data)
405-
if signature_verifier is not None:
406+
if signature_verifier:
406407
signature_verifier(card)
407408

408409
self.agent_card = card

tests/client/test_card_resolver.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import logging
33

4-
from unittest.mock import AsyncMock, Mock, patch
4+
from unittest.mock import AsyncMock, MagicMock, Mock, patch
55

66
import httpx
77
import pytest
@@ -371,9 +371,30 @@ async def test_get_agent_card_returns_agent_card_instance(
371371
self, resolver, mock_httpx_client, mock_response, valid_agent_card_data
372372
):
373373
"""Test that get_agent_card returns an AgentCard instance."""
374+
mock_response.json.return_value = valid_agent_card_data
375+
mock_httpx_client.get.return_value = mock_response
374376
mock_agent_card = Mock(spec=AgentCard)
377+
375378
with patch.object(
376379
AgentCard, 'model_validate', return_value=mock_agent_card
377380
):
378381
result = await resolver.get_agent_card()
379382
assert result == mock_agent_card
383+
mock_response.raise_for_status.assert_called_once()
384+
385+
@pytest.mark.asyncio
386+
async def test_get_agent_card_with_signature_verifier(
387+
self, resolver, mock_httpx_client, valid_agent_card_data
388+
):
389+
"""Test that the signature verifier is called if provided."""
390+
mock_verifier = MagicMock()
391+
392+
mock_response = MagicMock(spec=httpx.Response)
393+
mock_response.json.return_value = valid_agent_card_data
394+
mock_httpx_client.get.return_value = mock_response
395+
396+
agent_card = await resolver.get_agent_card(
397+
signature_verifier=mock_verifier
398+
)
399+
400+
mock_verifier.assert_called_once_with(agent_card)

tests/client/test_client_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ async def test_client_factory_connect_with_resolver_args(
190190
mock_resolver.return_value.get_agent_card.assert_awaited_once_with(
191191
relative_card_path=relative_path,
192192
http_kwargs=http_kwargs,
193+
signature_verifier=None,
193194
)
194195

195196

@@ -216,6 +217,7 @@ async def test_client_factory_connect_resolver_args_without_client(
216217
mock_resolver.return_value.get_agent_card.assert_awaited_once_with(
217218
relative_card_path=relative_path,
218219
http_kwargs=http_kwargs,
220+
signature_verifier=None,
219221
)
220222

221223

0 commit comments

Comments
 (0)