Skip to content

Commit 701eb39

Browse files
committed
fix: replace python-jose library with PyJWK. Add signature verification errors. Add class ProtectedHeader(TypedDict) to signing.py.
1 parent 57e4a68 commit 701eb39

File tree

4 files changed

+148
-80
lines changed

4 files changed

+148
-80
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"]
3636
postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
3737
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
3838
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
39-
signing = ["python-jose>=3.0.0"]
39+
signing = ["PyJWT>=2.0.0"]
4040

4141
sql = ["a2a-sdk[postgresql,mysql,sqlite]"]
4242

src/a2a/utils/signing.py

Lines changed: 70 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +1,111 @@
11
import json
22

33
from collections.abc import Callable
4-
from typing import Any
4+
from typing import Any, TypedDict
55

66

77
try:
8-
from jose import jws
9-
from jose.backends.base import Key
10-
from jose.exceptions import JOSEError
11-
from jose.utils import base64url_decode, base64url_encode
8+
import jwt
9+
10+
from jwt.api_jwk import PyJWK
11+
from jwt.exceptions import PyJWTError
12+
from jwt.utils import base64url_decode, base64url_encode
1213
except ImportError as e:
1314
raise ImportError(
14-
'A2A Signing requires python-jose to be installed. '
15+
'A2A Signing requires PyJWT to be installed. '
1516
'Install with: '
1617
"'pip install a2a-sdk[signing]'"
1718
) from e
1819

1920
from a2a.types import AgentCard, AgentCardSignature
2021

2122

23+
class SignatureVerificationError(Exception):
24+
"""Base exception for signature verification errors."""
25+
26+
27+
class NoSignatureError(SignatureVerificationError):
28+
"""Exception raised when no signature is found on an AgentCard."""
29+
30+
31+
class InvalidSignaturesError(SignatureVerificationError):
32+
"""Exception raised when all signatures are invalid."""
33+
34+
35+
class ProtectedHeader(TypedDict):
36+
"""Protected header parameters for JWS (JSON Web Signature)."""
37+
38+
kid: str
39+
""" Key identifier. """
40+
alg: str | None
41+
""" Algorithm used for signing. """
42+
jku: str | None
43+
""" JSON Web Key Set URL. """
44+
typ: str | None
45+
""" Token type.
46+
47+
Best practice: SHOULD be "JOSE" for JWS tokens.
48+
"""
49+
50+
2251
def clean_empty(d: Any) -> Any:
23-
"""Recursively remove empty strings, lists, dicts, and None values from a dictionary."""
52+
"""Recursively remove empty strings, lists and dicts from a dictionary."""
2453
if isinstance(d, dict):
2554
cleaned_dict: dict[Any, Any] = {k: clean_empty(v) for k, v in d.items()}
26-
return {
27-
k: v
28-
for k, v in cleaned_dict.items()
29-
if v is not None and (isinstance(v, (bool, int, float)) or v)
30-
}
55+
return {k: v for k, v in cleaned_dict.items() if v}
3156
if isinstance(d, list):
3257
cleaned_list: list[Any] = [clean_empty(v) for v in d]
33-
return [
34-
v
35-
for v in cleaned_list
36-
if v is not None and (isinstance(v, (bool, int, float)) or v)
37-
]
38-
return d if d not in ['', [], {}, None] else None
58+
return [v for v in cleaned_list if v]
59+
return d if d not in ['', [], {}] else None
3960

4061

4162
def canonicalize_agent_card(agent_card: AgentCard) -> str:
4263
"""Canonicalizes the Agent Card JSON according to RFC 8785 (JCS)."""
4364
card_dict = agent_card.model_dump(
4465
exclude={'signatures'},
4566
exclude_defaults=True,
67+
exclude_none=True,
4668
by_alias=True,
4769
)
48-
# Ensure 'protocol_version' is always included
49-
protocol_version_alias = (
50-
AgentCard.model_fields['protocol_version'].alias or 'protocol_version'
51-
)
52-
if protocol_version_alias not in card_dict:
53-
card_dict[protocol_version_alias] = agent_card.protocol_version
54-
55-
# Recursively remove empty/None values
70+
# Recursively remove empty values
5671
cleaned_dict = clean_empty(card_dict)
57-
5872
return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True)
5973

6074

6175
def create_agent_card_signer(
62-
signing_key: str | bytes | dict[str, Any] | Key,
63-
kid: str,
64-
alg: str = 'HS256',
65-
jku: str | None = None,
76+
signing_key: PyJWK | str | bytes,
77+
protected_header: ProtectedHeader,
78+
header: dict[str, Any] | None = None,
6679
) -> Callable[[AgentCard], AgentCard]:
6780
"""Creates a function that signs an AgentCard and adds the signature.
6881
6982
Args:
7083
signing_key: The private key for signing.
71-
kid: Key ID for the signing key.
72-
alg: The algorithm to use (e.g., "ES256", "RS256").
73-
jku: Optional URL to the JSON Web Keys.
84+
protected_header: The protected header parameters.
85+
header: Unprotected header parameters.
7486
7587
Returns:
7688
A callable that takes an AgentCard and returns the modified AgentCard with a signature.
7789
"""
7890

7991
def agent_card_signer(agent_card: AgentCard) -> AgentCard:
80-
"""The actual card_modifier function."""
92+
"""Signs agent card."""
8193
canonical_payload = canonicalize_agent_card(agent_card)
94+
payload_dict = json.loads(canonical_payload)
8295

83-
headers = {'kid': kid, 'typ': 'JOSE'}
84-
if jku:
85-
headers['jku'] = jku
86-
87-
jws_string = jws.sign(
88-
payload=canonical_payload.encode('utf-8'),
96+
jws_string = jwt.encode(
97+
payload=payload_dict,
8998
key=signing_key,
90-
headers=headers,
91-
algorithm=alg,
99+
algorithm=protected_header.get('alg', 'HS256'),
100+
headers=protected_header,
92101
)
93102

94-
# The result of jws.sign is a compact serialization: HEADER.PAYLOAD.SIGNATURE
95-
protected_header, _, signature = jws_string.split('.')
103+
# The result of jwt.encode is a compact serialization: HEADER.PAYLOAD.SIGNATURE
104+
protected, _, signature = jws_string.split('.')
96105

97106
agent_card_signature = AgentCardSignature(
98-
protected=protected_header,
107+
header=header,
108+
protected=protected,
99109
signature=signature,
100110
)
101111

@@ -108,9 +118,7 @@ def agent_card_signer(agent_card: AgentCard) -> AgentCard:
108118

109119

110120
def create_signature_verifier(
111-
key_provider: Callable[
112-
[str | None, str | None], str | bytes | dict[str, Any] | Key
113-
],
121+
key_provider: Callable[[str | None, str | None], PyJWK | str | bytes],
114122
algorithms: list[str],
115123
) -> Callable[[AgentCard], None]:
116124
"""Creates a function that verifies AgentCard signatures.
@@ -126,14 +134,17 @@ def create_signature_verifier(
126134
def signature_verifier(
127135
agent_card: AgentCard,
128136
) -> None:
129-
"""The actual signature_verifier function."""
137+
"""Verifies agent card signatures.
138+
139+
Checks if at least one signature matches the key, otherwise raises an error.
140+
"""
130141
if not agent_card.signatures:
131-
raise JOSEError('No signatures found on AgentCard')
142+
raise NoSignatureError('AgentCard has no signatures to verify.')
132143

133144
last_error = None
134145
for agent_card_signature in agent_card.signatures:
135146
try:
136-
# fetch kid and jku from protected header
147+
# get verification key
137148
protected_header_json = base64url_decode(
138149
agent_card_signature.protected.encode('utf-8')
139150
).decode('utf-8')
@@ -146,20 +157,22 @@ def signature_verifier(
146157
encoded_payload = base64url_encode(
147158
canonical_payload.encode('utf-8')
148159
).decode('utf-8')
149-
token = f'{agent_card_signature.protected}.{encoded_payload}.{agent_card_signature.signature}'
150160

151-
jws.verify(
152-
token=token,
161+
token = f'{agent_card_signature.protected}.{encoded_payload}.{agent_card_signature.signature}'
162+
jwt.decode(
163+
jwt=token,
153164
key=verification_key,
154165
algorithms=algorithms,
155166
)
156167
# Found a valid signature, exit the loop and function
157168
break
158-
except JOSEError as e:
169+
except PyJWTError as e:
159170
last_error = e
160171
continue
161172
else:
162173
# This block runs only if the loop completes without a break
163-
raise JOSEError('No valid signature found') from last_error
174+
raise InvalidSignaturesError(
175+
'No valid signature found'
176+
) from last_error
164177

165178
return signature_verifier

tests/integration/test_client_server_integration.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import pytest
99
import pytest_asyncio
1010
from grpc.aio import Channel
11-
from jose.backends.base import Key
1211

12+
from jwt.api_jwk import PyJWK
1313
from a2a.client import ClientConfig
1414
from a2a.client.base_client import BaseClient
1515
from a2a.client.transports import JsonRpcTransport, RestTransport
@@ -89,7 +89,7 @@
8989
)
9090

9191

92-
def create_key_provider(verification_key: str | bytes | dict[str, Any] | Key):
92+
def create_key_provider(verification_key: PyJWK | str | bytes):
9393
"""Creates a key provider function for testing."""
9494

9595
def key_provider(kid: str | None, jku: str | None):
@@ -754,6 +754,7 @@ async def test_http_transport_get_authenticated_card(
754754
transport = RestTransport(httpx_client=httpx_client, agent_card=agent_card)
755755
result = await transport.get_card()
756756
assert result.name == extended_agent_card.name
757+
assert transport.agent_card is not None
757758
assert transport.agent_card.name == extended_agent_card.name
758759
assert transport._needs_extended_card is False
759760

@@ -776,6 +777,7 @@ def channel_factory(address: str) -> Channel:
776777
transport = GrpcTransport(channel=channel, agent_card=agent_card)
777778

778779
# The transport starts with a minimal card, get_card() fetches the full one
780+
assert transport.agent_card is not None
779781
transport.agent_card.supports_authenticated_extended_card = True
780782
result = await transport.get_card()
781783

@@ -845,7 +847,7 @@ async def test_json_transport_base_client_send_message_with_extensions(
845847

846848

847849
@pytest.mark.asyncio
848-
async def test_json_transport_get_signed_base_card_no_initial(
850+
async def test_json_transport_get_signed_base_card(
849851
jsonrpc_setup: TransportSetup, agent_card: AgentCard
850852
) -> None:
851853
"""Tests fetching and verifying a symmetrically signed AgentCard via JSON-RPC.
@@ -860,7 +862,13 @@ async def test_json_transport_get_signed_base_card_no_initial(
860862
# Setup signing on the server side
861863
key = 'key12345'
862864
signer = create_agent_card_signer(
863-
signing_key=key, alg='HS384', kid='testkey'
865+
signing_key=key,
866+
protected_header={
867+
'alg': 'HS384',
868+
'kid': 'testkey',
869+
'jku': None,
870+
'typ': 'JOSE',
871+
},
864872
)
865873

866874
app_builder = A2AFastAPIApplication(
@@ -885,6 +893,7 @@ async def test_json_transport_get_signed_base_card_no_initial(
885893
assert result.name == agent_card.name
886894
assert result.signatures is not None
887895
assert len(result.signatures) == 1
896+
assert transport.agent_card is not None
888897
assert transport.agent_card.name == agent_card.name
889898
assert transport._needs_extended_card is False
890899

@@ -911,7 +920,13 @@ async def test_json_transport_get_signed_extended_card(
911920
private_key = asymmetric.ec.generate_private_key(asymmetric.ec.SECP256R1())
912921
public_key = private_key.public_key()
913922
signer = create_agent_card_signer(
914-
signing_key=private_key, alg='ES256', kid='testkey'
923+
signing_key=private_key,
924+
protected_header={
925+
'alg': 'ES256',
926+
'kid': 'testkey',
927+
'jku': None,
928+
'typ': 'JOSE',
929+
},
915930
)
916931

917932
app_builder = A2AFastAPIApplication(
@@ -937,6 +952,7 @@ async def test_json_transport_get_signed_extended_card(
937952
assert result.name == extended_agent_card.name
938953
assert result.signatures is not None
939954
assert len(result.signatures) == 1
955+
assert transport.agent_card is not None
940956
assert transport.agent_card.name == extended_agent_card.name
941957
assert transport._needs_extended_card is False
942958

@@ -964,7 +980,13 @@ async def test_json_transport_get_signed_base_and_extended_cards(
964980
private_key = asymmetric.ec.generate_private_key(asymmetric.ec.SECP256R1())
965981
public_key = private_key.public_key()
966982
signer = create_agent_card_signer(
967-
signing_key=private_key, alg='ES256', kid='testkey'
983+
signing_key=private_key,
984+
protected_header={
985+
'alg': 'ES256',
986+
'kid': 'testkey',
987+
'jku': None,
988+
'typ': 'JOSE',
989+
},
968990
)
969991

970992
app_builder = A2AFastAPIApplication(
@@ -993,6 +1015,7 @@ async def test_json_transport_get_signed_base_and_extended_cards(
9931015
assert result.name == extended_agent_card.name
9941016
assert result.signatures is not None
9951017
assert len(result.signatures) == 1
1018+
assert transport.agent_card is not None
9961019
assert transport.agent_card.name == extended_agent_card.name
9971020
assert transport._needs_extended_card is False
9981021

@@ -1019,7 +1042,13 @@ async def test_rest_transport_get_signed_card(
10191042
private_key = asymmetric.ec.generate_private_key(asymmetric.ec.SECP256R1())
10201043
public_key = private_key.public_key()
10211044
signer = create_agent_card_signer(
1022-
signing_key=private_key, alg='ES256', kid='testkey'
1045+
signing_key=private_key,
1046+
protected_header={
1047+
'alg': 'ES256',
1048+
'kid': 'testkey',
1049+
'jku': None,
1050+
'typ': 'JOSE',
1051+
},
10231052
)
10241053

10251054
app_builder = A2ARESTFastAPIApplication(
@@ -1048,6 +1077,7 @@ async def test_rest_transport_get_signed_card(
10481077
assert result.name == extended_agent_card.name
10491078
assert result.signatures is not None
10501079
assert len(result.signatures) == 1
1080+
assert transport.agent_card is not None
10511081
assert transport.agent_card.name == extended_agent_card.name
10521082
assert transport._needs_extended_card is False
10531083

@@ -1066,7 +1096,13 @@ async def test_grpc_transport_get_signed_card(
10661096
private_key = asymmetric.ec.generate_private_key(asymmetric.ec.SECP256R1())
10671097
public_key = private_key.public_key()
10681098
signer = create_agent_card_signer(
1069-
signing_key=private_key, alg='ES256', kid='testkey'
1099+
signing_key=private_key,
1100+
protected_header={
1101+
'alg': 'ES256',
1102+
'kid': 'testkey',
1103+
'jku': None,
1104+
'typ': 'JOSE',
1105+
},
10701106
)
10711107

10721108
server = grpc.aio.server()

0 commit comments

Comments
 (0)