Skip to content

Commit d94df5a

Browse files
committed
Fix a typo and update tests with a rename
1 parent 246a719 commit d94df5a

File tree

3 files changed

+36
-110
lines changed

3 files changed

+36
-110
lines changed

src/a2a/client/auth/credentials.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ class InMemoryContextCredentialStore(CredentialService):
2525
"""
2626

2727
def __init__(self) -> None:
28-
# {session_id: {scheme_name: credential}}
2928
self._store: dict[str, dict[str, str]] = {}
3029

3130
async def get_credentials(
@@ -38,7 +37,7 @@ async def get_credentials(
3837
session_id = context.state['sessionId']
3938
return self._store.get(session_id, {}).get(security_scheme_name)
4039

41-
async def set_credential(
40+
async def set_credentials(
4241
self, session_id: str, security_scheme_name: str, credential: str
4342
) -> None:
4443
"""Method to populate the store."""

src/a2a/client/middleware.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
from __future__ import annotations
2+
13
from abc import ABC, abstractmethod
24
from collections.abc import MutableMapping
35
from typing import TYPE_CHECKING, Any
46

57
from pydantic import BaseModel, Field
68

7-
89
if TYPE_CHECKING:
910
from a2a.types import AgentCard
1011

@@ -31,8 +32,8 @@ async def intercept(
3132
method_name: str,
3233
request_payload: dict[str, Any],
3334
http_kwargs: dict[str, Any],
34-
agent_card: 'AgentCard | None',
35-
context: 'ClientCallContext | None',
35+
agent_card: AgentCard | None,
36+
context: ClientCallContext | None,
3637
) -> tuple[dict[str, Any], dict[str, Any]]:
3738
"""
3839
Intercepts a client call before the request is sent.

tests/client/test_auth_middleware.py

Lines changed: 31 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,10 @@
66

77
from a2a.client import A2AClient, ClientCallContext, ClientCallInterceptor
88
from a2a.client.auth import AuthInterceptor, InMemoryContextCredentialStore
9-
from a2a.types import (
10-
APIKeySecurityScheme,
11-
AgentCapabilities,
12-
AgentCard,
13-
AuthorizationCodeOAuthFlow,
14-
In,
15-
OAuth2SecurityScheme,
16-
OAuthFlows,
17-
OpenIdConnectSecurityScheme,
18-
SecurityScheme,
19-
SendMessageRequest,
20-
)
9+
from a2a.types import (AgentCapabilities, AgentCard, APIKeySecurityScheme,
10+
AuthorizationCodeOAuthFlow, In, OAuth2SecurityScheme,
11+
OAuthFlows, OpenIdConnectSecurityScheme, SecurityScheme,
12+
SendMessageRequest)
2113

2214

2315
# A simple mock interceptor for testing basic middleware functionality
@@ -66,7 +58,7 @@ async def test_client_with_simple_interceptor():
6658
'messageId': 'response-msg',
6759
'role': 'agent',
6860
'parts': [],
69-
}, # Example result
61+
},
7062
}
7163
respx.post(test_url).mock(
7264
return_value=httpx.Response(200, json=minimal_success_response)
@@ -97,24 +89,24 @@ async def test_client_with_simple_interceptor():
9789
async def test_in_memory_context_credential_store():
9890
"""
9991
Tests the functionality of the InMemoryContextCredentialStore to ensure
100-
it correctly stores and retrieves credentials based on contextId.
92+
it correctly stores and retrieves credentials based on sessionId.
10193
"""
10294
# Arrange
10395
store = InMemoryContextCredentialStore()
104-
context_id = 'test-context-123'
96+
session_id = 'test-session-123'
10597
scheme_name = 'test-scheme'
10698
credential = 'test-token'
10799

108100
# Act
109-
await store.set_credentials(context_id, scheme_name, credential)
101+
await store.set_credentials(session_id, scheme_name, credential)
110102

111103
# Assert: Successful retrieval
112-
context = ClientCallContext(state={'contextId': context_id})
104+
context = ClientCallContext(state={'sessionId': session_id})
113105
retrieved_credential = await store.get_credentials(scheme_name, context)
114106
assert retrieved_credential == credential
115107

116-
# Assert: Retrieval with wrong context returns None
117-
wrong_context = ClientCallContext(state={'contextId': 'wrong-context'})
108+
# Assert: Retrieval with wrong session ID returns None
109+
wrong_context = ClientCallContext(state={'sessionId': 'wrong-session'})
118110
retrieved_credential_wrong = await store.get_credentials(
119111
scheme_name, wrong_context
120112
)
@@ -124,7 +116,7 @@ async def test_in_memory_context_credential_store():
124116
retrieved_credential_none = await store.get_credentials(scheme_name, None)
125117
assert retrieved_credential_none is None
126118

127-
# Assert: Retrieval with context but no contextId returns None
119+
# Assert: Retrieval with context but no sessionId returns None
128120
empty_context = ClientCallContext(state={})
129121
retrieved_credential_empty = await store.get_credentials(
130122
scheme_name, empty_context
@@ -140,15 +132,21 @@ async def test_auth_interceptor_with_api_key():
140132
"""
141133
# Arrange
142134
test_url = 'http://apikey-agent.com/rpc'
143-
context_id = 'user-session-2'
135+
session_id = 'user-session-2'
144136
scheme_name = 'apiKeyAuth'
145137
api_key = 'secret-api-key'
146138

147139
cred_store = InMemoryContextCredentialStore()
148-
await cred_store.set_credentials(context_id, scheme_name, api_key)
140+
await cred_store.set_credentials(session_id, scheme_name, api_key)
149141

150142
auth_interceptor = AuthInterceptor(credential_service=cred_store)
151143

144+
api_key_scheme_params = {
145+
'type': 'apiKey',
146+
'name': 'X-API-Key',
147+
'in': In.header,
148+
}
149+
152150
agent_card = AgentCard(
153151
url=test_url,
154152
name='ApiKeyBot',
@@ -161,82 +159,11 @@ async def test_auth_interceptor_with_api_key():
161159
security=[{scheme_name: []}],
162160
securitySchemes={
163161
scheme_name: SecurityScheme(
164-
root=APIKeySecurityScheme(
165-
name='X-API-Key', in_=In.header, type='apiKey'
166-
)
162+
root=APIKeySecurityScheme(**api_key_scheme_params)
167163
)
168164
},
169165
)
170166

171-
async with httpx.AsyncClient() as http_client:
172-
client = A2AClient(
173-
httpx_client=http_client,
174-
agent_card=agent_card,
175-
interceptors=[auth_interceptor],
176-
)
177-
178-
respx.post(test_url).mock(return_value=httpx.Response(200, json={}))
179-
180-
# Act
181-
context = ClientCallContext(state={'contextId': context_id})
182-
await client.send_message(
183-
request=SendMessageRequest(
184-
id='1',
185-
params={
186-
'message': {
187-
'messageId': 'msg1',
188-
'role': 'user',
189-
'parts': [],
190-
}
191-
},
192-
),
193-
context=context,
194-
)
195-
196-
# Assert
197-
assert len(respx.calls) == 1
198-
request = respx.calls.last.request
199-
assert 'x-api-key' in request.headers
200-
assert request.headers['x-api-key'] == api_key
201-
202-
203-
@pytest.mark.asyncio
204-
@respx.mock
205-
async def test_auth_interceptor_with_api_key():
206-
"""
207-
Tests the authentication flow with an API key in the header.
208-
"""
209-
# Arrange
210-
test_url = 'http://apikey-agent.com/rpc'
211-
context_id = 'user-session-2'
212-
scheme_name = 'apiKeyAuth'
213-
api_key = 'secret-api-key'
214-
215-
cred_store = InMemoryContextCredentialStore()
216-
await cred_store.set_credentials(context_id, scheme_name, api_key)
217-
218-
auth_interceptor = AuthInterceptor(credential_service=cred_store)
219-
220-
# Use the alias 'in' for instantiation, as confirmed by debug output
221-
api_key_scheme_instance = APIKeySecurityScheme(
222-
name='X-API-Key', **{'in': In.header}
223-
)
224-
225-
agent_card = AgentCard(
226-
url=test_url,
227-
name='ApiKeyBot',
228-
description='A bot that requires an API Key',
229-
version='1.0',
230-
defaultInputModes=[],
231-
defaultOutputModes=[],
232-
skills=[],
233-
capabilities=AgentCapabilities(),
234-
security=[{scheme_name: []}],
235-
securitySchemes={
236-
scheme_name: SecurityScheme(root=api_key_scheme_instance)
237-
},
238-
)
239-
240167
async with httpx.AsyncClient() as http_client:
241168
client = A2AClient(
242169
httpx_client=http_client,
@@ -259,7 +186,7 @@ async def test_auth_interceptor_with_api_key():
259186
)
260187

261188
# Act
262-
context = ClientCallContext(state={'contextId': context_id})
189+
context = ClientCallContext(state={'sessionId': session_id})
263190
await client.send_message(
264191
request=SendMessageRequest(
265192
id='1',
@@ -289,16 +216,15 @@ async def test_auth_interceptor_with_oauth2_scheme():
289216
Ensures it correctly sets the Authorization: Bearer <token> header.
290217
"""
291218
test_url = 'http://oauth-agent.com/rpc'
292-
context_id = 'user-session-oauth'
219+
session_id = 'user-session-oauth'
293220
scheme_name = 'myOAuthScheme'
294221
access_token = 'secret-oauth-access-token'
295222

296223
cred_store = InMemoryContextCredentialStore()
297-
await cred_store.set_credentials(context_id, scheme_name, access_token)
224+
await cred_store.set_credentials(session_id, scheme_name, access_token)
298225

299226
auth_interceptor = AuthInterceptor(credential_service=cred_store)
300227

301-
# Define a minimal OAuth2 flow
302228
oauth_flows = OAuthFlows(
303229
authorizationCode=AuthorizationCodeOAuthFlow(
304230
authorizationUrl='http://provider.com/auth',
@@ -346,7 +272,7 @@ async def test_auth_interceptor_with_oauth2_scheme():
346272
)
347273

348274
# Act
349-
context = ClientCallContext(state={'contextId': context_id})
275+
context = ClientCallContext(state={'sessionId': session_id})
350276
await client.send_message(
351277
request=SendMessageRequest(
352278
id='oauth_test_1',
@@ -377,12 +303,12 @@ async def test_auth_interceptor_with_oidc_scheme():
377303
"""
378304
# Arrange
379305
test_url = 'http://oidc-agent.com/rpc'
380-
context_id = 'user-session-oidc'
306+
session_id = 'user-session-oidc'
381307
scheme_name = 'myOidcScheme'
382-
id_token = 'secret-oidc-id-token' # Or access_token
308+
id_token = 'secret-oidc-id-token'
383309

384310
cred_store = InMemoryContextCredentialStore()
385-
await cred_store.set_credentials(context_id, scheme_name, id_token)
311+
await cred_store.set_credentials(session_id, scheme_name, id_token)
386312

387313
auth_interceptor = AuthInterceptor(credential_service=cred_store)
388314

@@ -397,7 +323,7 @@ async def test_auth_interceptor_with_oidc_scheme():
397323
capabilities=AgentCapabilities(),
398324
security=[
399325
{scheme_name: []}
400-
], # Security requirement referencing the scheme
326+
],
401327
securitySchemes={
402328
scheme_name: SecurityScheme(
403329
root=OpenIdConnectSecurityScheme(
@@ -430,7 +356,7 @@ async def test_auth_interceptor_with_oidc_scheme():
430356
)
431357

432358
# Act
433-
context = ClientCallContext(state={'contextId': context_id})
359+
context = ClientCallContext(state={'sessionId': session_id})
434360
await client.send_message(
435361
request=SendMessageRequest(
436362
id='oidc_test_1',
@@ -449,4 +375,4 @@ async def test_auth_interceptor_with_oidc_scheme():
449375
assert len(respx.calls) == 1
450376
request_sent = respx.calls.last.request
451377
assert 'Authorization' in request_sent.headers
452-
assert request_sent.headers['Authorization'] == f'Bearer {id_token}'
378+
assert request_sent.headers['Authorization'] == f'Bearer {id_token}'

0 commit comments

Comments
 (0)