Skip to content

Commit ce5a6e1

Browse files
committed
common logic for all bearer based schemes
1 parent 18088e1 commit ce5a6e1

File tree

2 files changed

+82
-17
lines changed

2 files changed

+82
-17
lines changed

src/a2a/client/auth/interceptor.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from a2a.client.auth.credentials import CredentialService
77
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
8-
from a2a.types import AgentCard, APIKeySecurityScheme, HTTPAuthSecurityScheme, In, OAuth2SecurityScheme
8+
from a2a.types import AgentCard, APIKeySecurityScheme, HTTPAuthSecurityScheme, In, OAuth2SecurityScheme, OpenIdConnectSecurityScheme
99

1010
logger = logging.getLogger(__name__)
1111

@@ -31,36 +31,35 @@ async def intercept(
3131
return request_payload, http_kwargs
3232

3333
for requirement in agent_card.security:
34-
for scheme_name in requirement: # Iterate through scheme names in the requirement
34+
for scheme_name in requirement:
3535
credential = await self._credential_service.get_credentials(
3636
scheme_name, context
3737
)
3838
if credential and scheme_name in agent_card.securitySchemes:
3939
scheme_def_union = agent_card.securitySchemes[scheme_name]
40-
if not scheme_def_union:
41-
continue
42-
scheme_def = scheme_def_union.root # SecurityScheme is a RootModel
40+
if not scheme_def_union:
41+
continue
42+
scheme_def = scheme_def_union.root
4343

4444
headers = http_kwargs.get('headers', {})
4545

46-
if isinstance(scheme_def, HTTPAuthSecurityScheme):
47-
if scheme_def.scheme.lower() == 'bearer':
48-
headers['Authorization'] = f"Bearer {credential}"
49-
logger.debug(f"Added HTTP Bearer Auth for scheme '{scheme_name}'.")
50-
http_kwargs['headers'] = headers
51-
return request_payload, http_kwargs
52-
elif isinstance(scheme_def, OAuth2SecurityScheme): # New condition for OAuth2
53-
# For OAuth2, the credential obtained is the access token, used as a Bearer token.
46+
is_bearer_scheme = False
47+
if isinstance(scheme_def, HTTPAuthSecurityScheme) and scheme_def.scheme.lower() == 'bearer':
48+
is_bearer_scheme = True
49+
elif isinstance(scheme_def, (OAuth2SecurityScheme, OpenIdConnectSecurityScheme)):
50+
is_bearer_scheme = True
51+
52+
if is_bearer_scheme:
5453
headers['Authorization'] = f"Bearer {credential}"
55-
logger.debug(f"Added OAuth2 Bearer token for scheme '{scheme_name}'.")
54+
logger.debug(f"Added Bearer token for scheme '{scheme_name}' (type: {scheme_def.type}).")
5655
http_kwargs['headers'] = headers
5756
return request_payload, http_kwargs
5857
elif isinstance(scheme_def, APIKeySecurityScheme):
59-
if scheme_def.in_ == In.header: # Use In.header enum member
58+
if scheme_def.in_ == In.header:
6059
headers[scheme_def.name] = credential
6160
logger.debug(f"Added API Key Header for scheme '{scheme_name}'.")
6261
http_kwargs['headers'] = headers
6362
return request_payload, http_kwargs
64-
# Note: API keys in query or cookie are not handled by this interceptor modification.
63+
# Note: API keys in query or cookie are not handled here.
6564

6665
return request_payload, http_kwargs

tests/client/test_auth_middleware.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
In,
2626
OAuth2SecurityScheme,
2727
OAuthFlows,
28+
OpenIdConnectSecurityScheme,
2829
SecurityScheme,
2930
SendMessageRequest,
3031
)
@@ -302,4 +303,69 @@ async def test_auth_interceptor_with_oauth2_scheme():
302303
assert len(respx.calls) == 1
303304
request_sent = respx.calls.last.request
304305
assert "Authorization" in request_sent.headers
305-
assert request_sent.headers["Authorization"] == f"Bearer {access_token}"
306+
assert request_sent.headers["Authorization"] == f"Bearer {access_token}"
307+
308+
@pytest.mark.asyncio
309+
@respx.mock
310+
async def test_auth_interceptor_with_oidc_scheme():
311+
"""
312+
Tests the AuthInterceptor with an OpenIdConnectSecurityScheme.
313+
Ensures it correctly sets the Authorization: Bearer <token> header.
314+
"""
315+
# Arrange
316+
test_url = "http://oidc-agent.com/rpc"
317+
context_id = "user-session-oidc"
318+
scheme_name = "myOidcScheme"
319+
id_token = "secret-oidc-id-token" # Or access_token
320+
321+
cred_store = InMemoryContextCredentialStore()
322+
await cred_store.set_credentials(context_id, scheme_name, id_token)
323+
324+
auth_interceptor = AuthInterceptor(credential_service=cred_store)
325+
326+
agent_card = AgentCard(
327+
url=test_url,
328+
name="OidcBot",
329+
description="A bot that uses OpenID Connect",
330+
version="1.0",
331+
defaultInputModes=[],
332+
defaultOutputModes=[],
333+
skills=[],
334+
capabilities=AgentCapabilities(),
335+
security=[{scheme_name: []}], # Security requirement referencing the scheme
336+
securitySchemes={
337+
scheme_name: SecurityScheme(
338+
root=OpenIdConnectSecurityScheme(
339+
type="openIdConnect",
340+
openIdConnectUrl="http://provider.com/.well-known/openid-configuration"
341+
)
342+
)
343+
},
344+
)
345+
346+
async with httpx.AsyncClient() as http_client:
347+
client = A2AClient(
348+
httpx_client=http_client,
349+
agent_card=agent_card,
350+
interceptors=[auth_interceptor]
351+
)
352+
353+
minimal_success_response = {
354+
"jsonrpc": "2.0",
355+
"id": "oidc_test_1",
356+
"result": {"kind": "message", "messageId": "response-msg-oidc", "role": "agent", "parts": []}
357+
}
358+
respx.post(test_url).mock(return_value=httpx.Response(200, json=minimal_success_response))
359+
360+
# Act
361+
context = ClientCallContext(state={"contextId": context_id})
362+
await client.send_message(
363+
request=SendMessageRequest(id="oidc_test_1", params={"message": {"messageId": "msg-oidc", "role": "user", "parts": []}}),
364+
context=context
365+
)
366+
367+
# Assert
368+
assert len(respx.calls) == 1
369+
request_sent = respx.calls.last.request
370+
assert "Authorization" in request_sent.headers
371+
assert request_sent.headers["Authorization"] == f"Bearer {id_token}"

0 commit comments

Comments
 (0)