Skip to content

Commit 9729217

Browse files
committed
tests working
1 parent 5e7d418 commit 9729217

File tree

10 files changed

+1000
-516
lines changed

10 files changed

+1000
-516
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ dev = [
7070
"pytest-asyncio>=0.26.0",
7171
"pytest-cov>=6.1.1",
7272
"pytest-mock>=3.14.0",
73+
"respx>=0.20.2",
7374
"ruff>=0.11.6",
7475
"uv-dynamic-versioning>=0.8.2",
7576
]

src/a2a/client/__init__.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
1+
# a2a/client/__init__.py
2+
13
"""Client-side components for interacting with an A2A agent."""
24

5+
from a2a.client.auth import (AuthInterceptor, CredentialService,
6+
InMemoryContextCredentialStore)
37
from a2a.client.client import A2ACardResolver, A2AClient
4-
from a2a.client.errors import (
5-
A2AClientError,
6-
A2AClientHTTPError,
7-
A2AClientJSONError,
8-
)
8+
from a2a.client.errors import (A2AClientError, A2AClientHTTPError,
9+
A2AClientJSONError)
910
from a2a.client.helpers import create_text_message_object
10-
11+
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1112

1213
__all__ = [
1314
'A2ACardResolver',
1415
'A2AClient',
1516
'A2AClientError',
1617
'A2AClientHTTPError',
1718
'A2AClientJSONError',
19+
'AuthInterceptor',
20+
'ClientCallContext',
21+
'ClientCallInterceptor',
22+
'CredentialService',
23+
'InMemoryContextCredentialStore',
1824
'create_text_message_object',
19-
]
25+
]

src/a2a/client/auth/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Client-side authentication components for the A2A Python SDK."""
2+
3+
from .credentials import CredentialService, InMemoryContextCredentialStore
4+
from .interceptor import AuthInterceptor
5+
6+
__all__ = [
7+
'CredentialService',
8+
'InMemoryContextCredentialStore',
9+
'AuthInterceptor',
10+
]

src/a2a/client/auth/credentials.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# a2a/client/auth/credentials.py
2+
3+
from abc import ABC, abstractmethod
4+
5+
from a2a.client.middleware import ClientCallContext
6+
7+
8+
class CredentialService(ABC):
9+
"""An abstract service for retrieving credentials."""
10+
11+
@abstractmethod
12+
async def get_credentials(
13+
self,
14+
security_scheme_name: str,
15+
context: ClientCallContext | None,
16+
) -> str | None:
17+
"""
18+
Retrieves a credential (e.g., token) for a security scheme.
19+
"""
20+
pass
21+
22+
23+
class InMemoryContextCredentialStore(CredentialService):
24+
"""
25+
A simple in-memory store for context-keyed credentials.
26+
27+
This class uses the 'contextId' from the ClientCallContext state to
28+
store and retrieve credentials, providing a simple, user-specific
29+
credential mechanism without requiring a full user authentication system.
30+
"""
31+
32+
def __init__(self):
33+
# {context_id: {scheme_name: credential}}
34+
self._store: dict[str, dict[str, str]] = {}
35+
36+
async def get_credentials(
37+
self,
38+
security_scheme_name: str,
39+
context: ClientCallContext | None,
40+
) -> str | None:
41+
if not context or 'contextId' not in context.state:
42+
return None
43+
context_id = context.state['contextId']
44+
return self._store.get(context_id, {}).get(security_scheme_name)
45+
46+
async def set_credentials(
47+
self, context_id: str, security_scheme_name: str, credential: str
48+
):
49+
"""Method to populate the store."""
50+
if context_id not in self._store:
51+
self._store[context_id] = {}
52+
self._store[context_id][security_scheme_name] = credential

src/a2a/client/auth/interceptor.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# a2a/client/auth/interceptor.py
2+
3+
import logging
4+
from typing import Any
5+
6+
from a2a.client.auth.credentials import CredentialService
7+
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
8+
from a2a.types import AgentCard, APIKeySecurityScheme, HTTPAuthSecurityScheme
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class AuthInterceptor(ClientCallInterceptor):
14+
"""
15+
An interceptor that automatically adds authentication details to requests
16+
based on the agent's security schemes.
17+
"""
18+
19+
def __init__(self, credential_service: CredentialService):
20+
self._credential_service = credential_service
21+
22+
async def intercept(
23+
self,
24+
method_name: str,
25+
request_payload: dict[str, Any],
26+
http_kwargs: dict[str, Any],
27+
agent_card: AgentCard | None,
28+
context: ClientCallContext | None,
29+
) -> tuple[dict[str, Any], dict[str, Any]]:
30+
"""
31+
Adds authentication headers to the request if credentials can be found.
32+
"""
33+
if not agent_card or not agent_card.security or not agent_card.securitySchemes:
34+
return request_payload, http_kwargs
35+
36+
for requirement in agent_card.security:
37+
for scheme_name in requirement:
38+
credential = await self._credential_service.get_credentials(
39+
scheme_name, context
40+
)
41+
if credential and scheme_name in agent_card.securitySchemes:
42+
scheme_def = agent_card.securitySchemes[scheme_name].root
43+
headers = http_kwargs.get('headers', {})
44+
45+
if isinstance(scheme_def, HTTPAuthSecurityScheme):
46+
headers['Authorization'] = f"{scheme_def.scheme} {credential}"
47+
http_kwargs['headers'] = headers
48+
logger.debug(f"Added HTTP Auth for scheme '{scheme_name}'.")
49+
return request_payload, http_kwargs
50+
elif isinstance(scheme_def, APIKeySecurityScheme):
51+
if scheme_def.in_ == 'header':
52+
headers[scheme_def.name] = credential
53+
http_kwargs['headers'] = headers
54+
logger.debug(f"Added API Key Header for scheme '{scheme_name}'.")
55+
return request_payload, http_kwargs
56+
else:
57+
logger.warning(
58+
f"API Key in '{scheme_def.in_}' not supported by this interceptor."
59+
)
60+
61+
return request_payload, http_kwargs

src/a2a/client/auth/user.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Authenticated user information."""
2+
3+
from abc import ABC, abstractmethod
4+
5+
6+
class User(ABC):
7+
"""A representation of an authenticated user."""
8+
9+
@property
10+
@abstractmethod
11+
def is_authenticated(self) -> bool:
12+
"""Returns whether the current user is authenticated."""
13+
14+
@property
15+
@abstractmethod
16+
def user_name(self) -> str:
17+
"""Returns the user name of the current user."""
18+
19+
20+
class UnauthenticatedUser(User):
21+
"""A representation that no user has been authenticated in the request."""
22+
23+
@property
24+
def is_authenticated(self):
25+
return False
26+
27+
@property
28+
def user_name(self) -> str:
29+
return ''

0 commit comments

Comments
 (0)