Skip to content

Commit 18088e1

Browse files
committed
OAuth test
1 parent 9729217 commit 18088e1

File tree

2 files changed

+93
-16
lines changed

2 files changed

+93
-16
lines changed

src/a2a/client/auth/interceptor.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from a2a.client.auth.credentials import CredentialService
77
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
8-
from a2a.types import AgentCard, APIKeySecurityScheme, HTTPAuthSecurityScheme
8+
from a2a.types import AgentCard, APIKeySecurityScheme, HTTPAuthSecurityScheme, In, OAuth2SecurityScheme
99

1010
logger = logging.getLogger(__name__)
1111

@@ -27,35 +27,40 @@ async def intercept(
2727
agent_card: AgentCard | None,
2828
context: ClientCallContext | None,
2929
) -> tuple[dict[str, Any], dict[str, Any]]:
30-
"""
31-
Adds authentication headers to the request if credentials can be found.
32-
"""
3330
if not agent_card or not agent_card.security or not agent_card.securitySchemes:
3431
return request_payload, http_kwargs
3532

3633
for requirement in agent_card.security:
37-
for scheme_name in requirement:
34+
for scheme_name in requirement: # Iterate through scheme names in the requirement
3835
credential = await self._credential_service.get_credentials(
3936
scheme_name, context
4037
)
4138
if credential and scheme_name in agent_card.securitySchemes:
42-
scheme_def = agent_card.securitySchemes[scheme_name].root
39+
scheme_def_union = agent_card.securitySchemes[scheme_name]
40+
if not scheme_def_union:
41+
continue
42+
scheme_def = scheme_def_union.root # SecurityScheme is a RootModel
43+
4344
headers = http_kwargs.get('headers', {})
4445

4546
if isinstance(scheme_def, HTTPAuthSecurityScheme):
46-
headers['Authorization'] = f"{scheme_def.scheme} {credential}"
47+
if scheme_def.scheme.lower() == 'bearer':
48+
headers['Authorization'] = f"Bearer {credential}"
49+
logger.debug(f"Added HTTP Bearer Auth for scheme '{scheme_name}'.")
50+
http_kwargs['headers'] = headers
51+
return request_payload, http_kwargs
52+
elif isinstance(scheme_def, OAuth2SecurityScheme): # New condition for OAuth2
53+
# For OAuth2, the credential obtained is the access token, used as a Bearer token.
54+
headers['Authorization'] = f"Bearer {credential}"
55+
logger.debug(f"Added OAuth2 Bearer token for scheme '{scheme_name}'.")
4756
http_kwargs['headers'] = headers
48-
logger.debug(f"Added HTTP Auth for scheme '{scheme_name}'.")
4957
return request_payload, http_kwargs
5058
elif isinstance(scheme_def, APIKeySecurityScheme):
51-
if scheme_def.in_ == 'header':
59+
if scheme_def.in_ == In.header: # Use In.header enum member
5260
headers[scheme_def.name] = credential
53-
http_kwargs['headers'] = headers
5461
logger.debug(f"Added API Key Header for scheme '{scheme_name}'.")
62+
http_kwargs['headers'] = headers
5563
return request_payload, http_kwargs
56-
else:
57-
logger.warning(
58-
f"API Key in '{scheme_def.in_}' not supported by this interceptor."
59-
)
60-
64+
# Note: API keys in query or cookie are not handled by this interceptor modification.
65+
6166
return request_payload, http_kwargs

tests/client/test_auth_middleware.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@
2020
AgentCard,
2121
AgentCapabilities,
2222
APIKeySecurityScheme,
23+
AuthorizationCodeOAuthFlow,
2324
HTTPAuthSecurityScheme,
2425
In,
26+
OAuth2SecurityScheme,
27+
OAuthFlows,
2528
SecurityScheme,
2629
SendMessageRequest,
2730
)
@@ -230,4 +233,73 @@ async def test_auth_interceptor_with_api_key():
230233
assert len(respx.calls) == 1
231234
request = respx.calls.last.request
232235
assert "x-api-key" in request.headers
233-
assert request.headers["x-api-key"] == api_key
236+
assert request.headers["x-api-key"] == api_key
237+
238+
239+
@pytest.mark.asyncio
240+
@respx.mock
241+
async def test_auth_interceptor_with_oauth2_scheme():
242+
"""
243+
Tests the AuthInterceptor with an OAuth2 security scheme defined in AgentCard.
244+
Ensures it correctly sets the Authorization: Bearer <token> header.
245+
"""
246+
test_url = "http://oauth-agent.com/rpc"
247+
context_id = "user-session-oauth"
248+
scheme_name = "myOAuthScheme"
249+
access_token = "secret-oauth-access-token"
250+
251+
cred_store = InMemoryContextCredentialStore()
252+
await cred_store.set_credentials(context_id, scheme_name, access_token)
253+
254+
auth_interceptor = AuthInterceptor(credential_service=cred_store)
255+
256+
# Define a minimal OAuth2 flow
257+
oauth_flows = OAuthFlows(
258+
authorizationCode=AuthorizationCodeOAuthFlow(
259+
authorizationUrl="http://provider.com/auth",
260+
tokenUrl="http://provider.com/token",
261+
scopes={"read": "Read scope"}
262+
)
263+
)
264+
265+
agent_card = AgentCard(
266+
url=test_url,
267+
name="OAuthBot",
268+
description="A bot that uses OAuth2",
269+
version="1.0",
270+
defaultInputModes=[],
271+
defaultOutputModes=[],
272+
skills=[],
273+
capabilities=AgentCapabilities(),
274+
security=[{scheme_name: ["read"]}],
275+
securitySchemes={
276+
scheme_name: SecurityScheme(root=OAuth2SecurityScheme(type="oauth2", flows=oauth_flows))
277+
},
278+
)
279+
280+
async with httpx.AsyncClient() as http_client:
281+
client = A2AClient(
282+
httpx_client=http_client,
283+
agent_card=agent_card,
284+
interceptors=[auth_interceptor]
285+
)
286+
287+
minimal_success_response = {
288+
"jsonrpc": "2.0",
289+
"id": "oauth_test_1",
290+
"result": {"kind": "message", "messageId": "response-msg-oauth", "role": "agent", "parts": []}
291+
}
292+
respx.post(test_url).mock(return_value=httpx.Response(200, json=minimal_success_response))
293+
294+
# Act
295+
context = ClientCallContext(state={"contextId": context_id})
296+
await client.send_message(
297+
request=SendMessageRequest(id="oauth_test_1", params={"message": {"messageId": "msg-oauth", "role": "user", "parts": []}}),
298+
context=context
299+
)
300+
301+
# Assert
302+
assert len(respx.calls) == 1
303+
request_sent = respx.calls.last.request
304+
assert "Authorization" in request_sent.headers
305+
assert request_sent.headers["Authorization"] == f"Bearer {access_token}"

0 commit comments

Comments
 (0)