diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 14b4fe11..5f4c150c 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -45,6 +45,7 @@ opensource protoc pyi pyversions +respx resub socio sse diff --git a/.vscode/launch.json b/.vscode/launch.json index 37651238..6adb30d5 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,7 +12,12 @@ "PYTHONPATH": "${workspaceFolder}" }, "cwd": "${workspaceFolder}/examples/helloworld", - "args": ["--host", "localhost", "--port", "9999"] + "args": [ + "--host", + "localhost", + "--port", + "9999" + ] }, { "name": "Debug Currency Agent", @@ -25,7 +30,24 @@ "PYTHONPATH": "${workspaceFolder}" }, "cwd": "${workspaceFolder}/examples/langgraph", - "args": ["--host", "localhost", "--port", "10000"] + "args": [ + "--host", + "localhost", + "--port", + "10000" + ] + }, + { + "name": "Pytest All", + "type": "debugpy", + "request": "launch", + "module": "pytest", + "args": [ + "-v", + "-s" + ], + "console": "integratedTerminal", + "justMyCode": true } ] -} +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 56fb1e34..c400faa3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,7 @@ dev = [ "pytest-asyncio>=0.26.0", "pytest-cov>=6.1.1", "pytest-mock>=3.14.0", + "respx>=0.20.2", "ruff>=0.11.6", "uv-dynamic-versioning>=0.8.2", "types-protobuf", diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index e91c9eb7..3d673b31 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -1,5 +1,10 @@ """Client-side components for interacting with an A2A agent.""" +from a2a.client.auth import ( + AuthInterceptor, + CredentialService, + InMemoryContextCredentialStore, +) from a2a.client.client import A2ACardResolver, A2AClient from a2a.client.errors import ( A2AClientError, @@ -8,6 +13,7 @@ ) from a2a.client.grpc_client import A2AGrpcClient from a2a.client.helpers import create_text_message_object +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor __all__ = [ @@ -17,5 +23,10 @@ 'A2AClientHTTPError', 'A2AClientJSONError', 'A2AGrpcClient', + 'AuthInterceptor', + 'ClientCallContext', + 'ClientCallInterceptor', + 'CredentialService', + 'InMemoryContextCredentialStore', 'create_text_message_object', ] diff --git a/src/a2a/client/auth/__init__.py b/src/a2a/client/auth/__init__.py new file mode 100644 index 00000000..8efe65fc --- /dev/null +++ b/src/a2a/client/auth/__init__.py @@ -0,0 +1,14 @@ +"""Client-side authentication components for the A2A Python SDK.""" + +from a2a.client.auth.credentials import ( + CredentialService, + InMemoryContextCredentialStore, +) +from a2a.client.auth.interceptor import AuthInterceptor + + +__all__ = [ + 'AuthInterceptor', + 'CredentialService', + 'InMemoryContextCredentialStore', +] diff --git a/src/a2a/client/auth/credentials.py b/src/a2a/client/auth/credentials.py new file mode 100644 index 00000000..11f32370 --- /dev/null +++ b/src/a2a/client/auth/credentials.py @@ -0,0 +1,55 @@ +from abc import ABC, abstractmethod + +from a2a.client.middleware import ClientCallContext + + +class CredentialService(ABC): + """An abstract service for retrieving credentials.""" + + @abstractmethod + async def get_credentials( + self, + security_scheme_name: str, + context: ClientCallContext | None, + ) -> str | None: + """ + Retrieves a credential (e.g., token) for a security scheme. + """ + + +class InMemoryContextCredentialStore(CredentialService): + """A simple in-memory store for session-keyed credentials. + + This class uses the 'sessionId' from the ClientCallContext state to + store and retrieve credentials... + """ + + def __init__(self) -> None: + self._store: dict[str, dict[str, str]] = {} + + async def get_credentials( + self, + security_scheme_name: str, + context: ClientCallContext | None, + ) -> str | None: + """Retrieves credentials from the in-memory store. + + Args: + security_scheme_name: The name of the security scheme. + context: The client call context. + + Returns: + The credential string, or None if not found. + """ + if not context or 'sessionId' not in context.state: + return None + session_id = context.state['sessionId'] + return self._store.get(session_id, {}).get(security_scheme_name) + + async def set_credentials( + self, session_id: str, security_scheme_name: str, credential: str + ) -> None: + """Method to populate the store.""" + if session_id not in self._store: + self._store[session_id] = {} + self._store[session_id][security_scheme_name] = credential diff --git a/src/a2a/client/auth/interceptor.py b/src/a2a/client/auth/interceptor.py new file mode 100644 index 00000000..a164f135 --- /dev/null +++ b/src/a2a/client/auth/interceptor.py @@ -0,0 +1,93 @@ +import logging # noqa: I001 +from typing import Any + +from a2a.client.auth.credentials import CredentialService +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.types import ( + AgentCard, + APIKeySecurityScheme, + HTTPAuthSecurityScheme, + In, + OAuth2SecurityScheme, + OpenIdConnectSecurityScheme, +) + +logger = logging.getLogger(__name__) + + +class AuthInterceptor(ClientCallInterceptor): + """An interceptor that automatically adds authentication details to requests. + + Based on the agent's security schemes. + """ + + def __init__(self, credential_service: CredentialService): + self._credential_service = credential_service + + async def intercept( + self, + method_name: str, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any], + agent_card: AgentCard | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Applies authentication headers to the request if credentials are available.""" + if ( + agent_card is None + or agent_card.security is None + or agent_card.securitySchemes is None + ): + return request_payload, http_kwargs + + for requirement in agent_card.security: + for scheme_name in requirement: + credential = await self._credential_service.get_credentials( + scheme_name, context + ) + if credential and scheme_name in agent_card.securitySchemes: + scheme_def_union = agent_card.securitySchemes.get( + scheme_name + ) + if not scheme_def_union: + continue + scheme_def = scheme_def_union.root + + headers = http_kwargs.get('headers', {}) + + match scheme_def: + # Case 1a: HTTP Bearer scheme with an if guard + case HTTPAuthSecurityScheme() if ( + scheme_def.scheme.lower() == 'bearer' + ): + headers['Authorization'] = f'Bearer {credential}' + logger.debug( + f"Added Bearer token for scheme '{scheme_name}' (type: {scheme_def.type})." + ) + http_kwargs['headers'] = headers + return request_payload, http_kwargs + + # Case 1b: OAuth2 and OIDC schemes, which are implicitly Bearer + case ( + OAuth2SecurityScheme() + | OpenIdConnectSecurityScheme() + ): + headers['Authorization'] = f'Bearer {credential}' + logger.debug( + f"Added Bearer token for scheme '{scheme_name}' (type: {scheme_def.type})." + ) + http_kwargs['headers'] = headers + return request_payload, http_kwargs + + # Case 2: API Key in Header + case APIKeySecurityScheme(in_=In.header): + headers[scheme_def.name] = credential + logger.debug( + f"Added API Key Header for scheme '{scheme_name}'." + ) + http_kwargs['headers'] = headers + return request_payload, http_kwargs + + # Note: Other cases like API keys in query/cookie are not handled and will be skipped. + + return request_payload, http_kwargs diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 5bcc36f1..e29ef8a7 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -11,6 +11,7 @@ from pydantic import ValidationError from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError +from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.types import ( AgentCard, CancelTaskRequest, @@ -129,6 +130,7 @@ def __init__( httpx_client: httpx.AsyncClient, agent_card: AgentCard | None = None, url: str | None = None, + interceptors: list[ClientCallInterceptor] | None = None, ): """Initializes the A2AClient. @@ -138,6 +140,7 @@ def __init__( httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient). agent_card: The agent card object. If provided, `url` is taken from `agent_card.url`. url: The direct URL to the agent's A2A RPC endpoint. Required if `agent_card` is None. + interceptors: An optional list of client call interceptors to apply to requests. Raises: ValueError: If neither `agent_card` nor `url` is provided. @@ -150,6 +153,32 @@ def __init__( raise ValueError('Must provide either agent_card or url') self.httpx_client = httpx_client + self.agent_card = agent_card + self.interceptors = interceptors or [] + + async def _apply_interceptors( + self, + method_name: str, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any] | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Applies all registered interceptors to the request.""" + final_http_kwargs = http_kwargs or {} + final_request_payload = request_payload + + for interceptor in self.interceptors: + ( + final_request_payload, + final_http_kwargs, + ) = await interceptor.intercept( + method_name, + final_request_payload, + final_http_kwargs, + self.agent_card, + context, + ) + return final_request_payload, final_http_kwargs @staticmethod async def get_client_from_agent_card_url( @@ -191,6 +220,7 @@ async def send_message( request: SendMessageRequest, *, http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, ) -> SendMessageResponse: """Sends a non-streaming message request to the agent. @@ -198,6 +228,7 @@ async def send_message( request: The `SendMessageRequest` object containing the message and configuration. http_kwargs: Optional dictionary of keyword arguments to pass to the underlying httpx.post request. + context: The client call context. Returns: A `SendMessageResponse` object containing the agent's response (Task or Message) or an error. @@ -209,18 +240,22 @@ async def send_message( if not request.id: request.id = str(uuid4()) - return SendMessageResponse( - **await self._send_request( - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - ) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'message/send', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, ) + response_data = await self._send_request(payload, modified_kwargs) + return SendMessageResponse.model_validate(response_data) async def send_message_streaming( self, request: SendStreamingMessageRequest, *, http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, ) -> AsyncGenerator[SendStreamingMessageResponse]: """Sends a streaming message request to the agent and yields responses as they arrive. @@ -230,6 +265,7 @@ async def send_message_streaming( request: The `SendStreamingMessageRequest` object containing the message and configuration. http_kwargs: Optional dictionary of keyword arguments to pass to the underlying httpx.post request. A default `timeout=None` is set but can be overridden. + context: The client call context. Yields: `SendStreamingMessageResponse` objects as they are received in the SSE stream. @@ -242,22 +278,28 @@ async def send_message_streaming( if not request.id: request.id = str(uuid4()) - # Default to no timeout for streaming, can be overridden by http_kwargs - http_kwargs_with_timeout: dict[str, Any] = { - 'timeout': None, - **(http_kwargs or {}), - } + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'message/stream', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + + modified_kwargs.setdefault('timeout', None) async with aconnect_sse( self.httpx_client, 'POST', self.url, - json=request.model_dump(mode='json', exclude_none=True), - **http_kwargs_with_timeout, + json=payload, + **modified_kwargs, ) as event_source: try: async for sse in event_source.aiter_sse(): - yield SendStreamingMessageResponse(**json.loads(sse.data)) + yield SendStreamingMessageResponse.model_validate( + json.loads(sse.data) + ) except SSEError as e: raise A2AClientHTTPError( 400, @@ -309,6 +351,7 @@ async def get_task( request: GetTaskRequest, *, http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, ) -> GetTaskResponse: """Retrieves the current state and history of a specific task. @@ -316,6 +359,7 @@ async def get_task( request: The `GetTaskRequest` object specifying the task ID and history length. http_kwargs: Optional dictionary of keyword arguments to pass to the underlying httpx.post request. + context: The client call context. Returns: A `GetTaskResponse` object containing the Task or an error. @@ -327,18 +371,22 @@ async def get_task( if not request.id: request.id = str(uuid4()) - return GetTaskResponse( - **await self._send_request( - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - ) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/get', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, ) + response_data = await self._send_request(payload, modified_kwargs) + return GetTaskResponse.model_validate(response_data) async def cancel_task( self, request: CancelTaskRequest, *, http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, ) -> CancelTaskResponse: """Requests the agent to cancel a specific task. @@ -346,6 +394,7 @@ async def cancel_task( request: The `CancelTaskRequest` object specifying the task ID. http_kwargs: Optional dictionary of keyword arguments to pass to the underlying httpx.post request. + context: The client call context. Returns: A `CancelTaskResponse` object containing the updated Task with canceled status or an error. @@ -357,18 +406,22 @@ async def cancel_task( if not request.id: request.id = str(uuid4()) - return CancelTaskResponse( - **await self._send_request( - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - ) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/cancel', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, ) + response_data = await self._send_request(payload, modified_kwargs) + return CancelTaskResponse.model_validate(response_data) async def set_task_callback( self, request: SetTaskPushNotificationConfigRequest, *, http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, ) -> SetTaskPushNotificationConfigResponse: """Sets or updates the push notification configuration for a specific task. @@ -376,6 +429,7 @@ async def set_task_callback( request: The `SetTaskPushNotificationConfigRequest` object specifying the task ID and configuration. http_kwargs: Optional dictionary of keyword arguments to pass to the underlying httpx.post request. + context: The client call context. Returns: A `SetTaskPushNotificationConfigResponse` object containing the confirmation or an error. @@ -387,11 +441,16 @@ async def set_task_callback( if not request.id: request.id = str(uuid4()) - return SetTaskPushNotificationConfigResponse( - **await self._send_request( - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - ) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/pushNotificationConfig/set', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + return SetTaskPushNotificationConfigResponse.model_validate( + response_data ) async def get_task_callback( @@ -399,6 +458,7 @@ async def get_task_callback( request: GetTaskPushNotificationConfigRequest, *, http_kwargs: dict[str, Any] | None = None, + context: ClientCallContext | None = None, ) -> GetTaskPushNotificationConfigResponse: """Retrieves the push notification configuration for a specific task. @@ -406,6 +466,7 @@ async def get_task_callback( request: The `GetTaskPushNotificationConfigRequest` object specifying the task ID. http_kwargs: Optional dictionary of keyword arguments to pass to the underlying httpx.post request. + context: The client call context. Returns: A `GetTaskPushNotificationConfigResponse` object containing the configuration or an error. @@ -417,9 +478,14 @@ async def get_task_callback( if not request.id: request.id = str(uuid4()) - return GetTaskPushNotificationConfigResponse( - **await self._send_request( - request.model_dump(mode='json', exclude_none=True), - http_kwargs, - ) + # Apply interceptors before sending + payload, modified_kwargs = await self._apply_interceptors( + 'tasks/pushNotificationConfig/get', + request.model_dump(mode='json', exclude_none=True), + http_kwargs, + context, + ) + response_data = await self._send_request(payload, modified_kwargs) + return GetTaskPushNotificationConfigResponse.model_validate( + response_data ) diff --git a/src/a2a/client/middleware.py b/src/a2a/client/middleware.py new file mode 100644 index 00000000..73ada982 --- /dev/null +++ b/src/a2a/client/middleware.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import MutableMapping # noqa: TC003 +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, Field + + +if TYPE_CHECKING: + from a2a.types import AgentCard + + +class ClientCallContext(BaseModel): + """A context passed with each client call, allowing for call-specific. + + configuration and data passing. Such as authentication details or + request deadlines. + """ + + state: MutableMapping[str, Any] = Field(default_factory=dict) + + +class ClientCallInterceptor(ABC): + """An abstract base class for client-side call interceptors. + + Interceptors can inspect and modify requests before they are sent, + which is ideal for concerns like authentication, logging, or tracing. + """ + + @abstractmethod + async def intercept( + self, + method_name: str, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any], + agent_card: AgentCard | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Intercepts a client call before the request is sent. + + Args: + method_name: The name of the RPC method (e.g., 'message/send'). + request_payload: The JSON RPC request payload dictionary. + http_kwargs: The keyword arguments for the httpx request. + agent_card: The AgentCard associated with the client. + context: The ClientCallContext for this specific call. + + Returns: + A tuple containing the (potentially modified) request_payload + and http_kwargs. + """ diff --git a/tests/client/test_auth_middleware.py b/tests/client/test_auth_middleware.py new file mode 100644 index 00000000..e03b9759 --- /dev/null +++ b/tests/client/test_auth_middleware.py @@ -0,0 +1,384 @@ +from typing import Any + +import httpx +import pytest +import respx + +from a2a.client import A2AClient, ClientCallContext, ClientCallInterceptor +from a2a.client.auth import AuthInterceptor, InMemoryContextCredentialStore +from a2a.types import ( + APIKeySecurityScheme, + AgentCapabilities, + AgentCard, + AuthorizationCodeOAuthFlow, + In, + OAuth2SecurityScheme, + OAuthFlows, + OpenIdConnectSecurityScheme, + SecurityScheme, + SendMessageRequest, +) + + +# A simple mock interceptor for testing basic middleware functionality +class HeaderInterceptor(ClientCallInterceptor): + def __init__(self, header_name: str, header_value: str): + self.header_name = header_name + self.header_value = header_value + + async def intercept( + self, + method_name: str, + request_payload: dict[str, Any], + http_kwargs: dict[str, Any], + agent_card: AgentCard | None, + context: ClientCallContext | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + headers = http_kwargs.get('headers', {}) + headers[self.header_name] = self.header_value + http_kwargs['headers'] = headers + return request_payload, http_kwargs + + +@pytest.mark.asyncio +@respx.mock +async def test_client_with_simple_interceptor(): + """ + Tests that a basic interceptor is called and successfully + modifies the outgoing request headers. + """ + # Arrange + test_url = 'http://fake-agent.com/rpc' + header_interceptor = HeaderInterceptor('X-Test-Header', 'Test-Value-123') + async with httpx.AsyncClient() as http_client: + client = A2AClient( + httpx_client=http_client, + url=test_url, + interceptors=[header_interceptor], + ) + + # Mock the HTTP response with a minimal valid success response + minimal_success_response = { + 'jsonrpc': '2.0', + 'id': '1', + 'result': { + 'kind': 'message', + 'messageId': 'response-msg', + 'role': 'agent', + 'parts': [], + }, + } + respx.post(test_url).mock( + return_value=httpx.Response(200, json=minimal_success_response) + ) + + # Act + await client.send_message( + request=SendMessageRequest( + id='1', + params={ + 'message': { + 'messageId': 'msg1', + 'role': 'user', + 'parts': [], + } + }, + ) + ) + + # Assert + assert len(respx.calls) == 1 + request = respx.calls.last.request + assert 'x-test-header' in request.headers + assert request.headers['x-test-header'] == 'Test-Value-123' + + +@pytest.mark.asyncio +async def test_in_memory_context_credential_store(): + """ + Tests the functionality of the InMemoryContextCredentialStore to ensure + it correctly stores and retrieves credentials based on sessionId. + """ + # Arrange + store = InMemoryContextCredentialStore() + session_id = 'test-session-123' + scheme_name = 'test-scheme' + credential = 'test-token' + + # Act + await store.set_credentials(session_id, scheme_name, credential) + + # Assert: Successful retrieval + context = ClientCallContext(state={'sessionId': session_id}) + retrieved_credential = await store.get_credentials(scheme_name, context) + assert retrieved_credential == credential + + # Assert: Retrieval with wrong session ID returns None + wrong_context = ClientCallContext(state={'sessionId': 'wrong-session'}) + retrieved_credential_wrong = await store.get_credentials( + scheme_name, wrong_context + ) + assert retrieved_credential_wrong is None + + # Assert: Retrieval with no context returns None + retrieved_credential_none = await store.get_credentials(scheme_name, None) + assert retrieved_credential_none is None + + # Assert: Retrieval with context but no sessionId returns None + empty_context = ClientCallContext(state={}) + retrieved_credential_empty = await store.get_credentials( + scheme_name, empty_context + ) + assert retrieved_credential_empty is None + + +@pytest.mark.asyncio +@respx.mock +async def test_auth_interceptor_with_api_key(): + """ + Tests the authentication flow with an API key in the header. + """ + # Arrange + test_url = 'http://apikey-agent.com/rpc' + session_id = 'user-session-2' + scheme_name = 'apiKeyAuth' + api_key = 'secret-api-key' + + cred_store = InMemoryContextCredentialStore() + await cred_store.set_credentials(session_id, scheme_name, api_key) + + auth_interceptor = AuthInterceptor(credential_service=cred_store) + + api_key_scheme_params = { + 'type': 'apiKey', + 'name': 'X-API-Key', + 'in': In.header, + } + + agent_card = AgentCard( + url=test_url, + name='ApiKeyBot', + description='A bot that requires an API Key', + version='1.0', + defaultInputModes=[], + defaultOutputModes=[], + skills=[], + capabilities=AgentCapabilities(), + security=[{scheme_name: []}], + securitySchemes={ + scheme_name: SecurityScheme( + root=APIKeySecurityScheme(**api_key_scheme_params) + ) + }, + ) + + async with httpx.AsyncClient() as http_client: + client = A2AClient( + httpx_client=http_client, + agent_card=agent_card, + interceptors=[auth_interceptor], + ) + + minimal_success_response = { + 'jsonrpc': '2.0', + 'id': '1', + 'result': { + 'kind': 'message', + 'messageId': 'response-msg', + 'role': 'agent', + 'parts': [], + }, + } + respx.post(test_url).mock( + return_value=httpx.Response(200, json=minimal_success_response) + ) + + # Act + context = ClientCallContext(state={'sessionId': session_id}) + await client.send_message( + request=SendMessageRequest( + id='1', + params={ + 'message': { + 'messageId': 'msg1', + 'role': 'user', + 'parts': [], + } + }, + ), + context=context, + ) + + # Assert + assert len(respx.calls) == 1 + request = respx.calls.last.request + assert 'x-api-key' in request.headers + assert request.headers['x-api-key'] == api_key + + +@pytest.mark.asyncio +@respx.mock +async def test_auth_interceptor_with_oauth2_scheme(): + """ + Tests the AuthInterceptor with an OAuth2 security scheme defined in AgentCard. + Ensures it correctly sets the Authorization: Bearer header. + """ + test_url = 'http://oauth-agent.com/rpc' + session_id = 'user-session-oauth' + scheme_name = 'myOAuthScheme' + access_token = 'secret-oauth-access-token' + + cred_store = InMemoryContextCredentialStore() + await cred_store.set_credentials(session_id, scheme_name, access_token) + + auth_interceptor = AuthInterceptor(credential_service=cred_store) + + oauth_flows = OAuthFlows( + authorizationCode=AuthorizationCodeOAuthFlow( + authorizationUrl='http://provider.com/auth', + tokenUrl='http://provider.com/token', + scopes={'read': 'Read scope'}, + ) + ) + + agent_card = AgentCard( + url=test_url, + name='OAuthBot', + description='A bot that uses OAuth2', + version='1.0', + defaultInputModes=[], + defaultOutputModes=[], + skills=[], + capabilities=AgentCapabilities(), + security=[{scheme_name: ['read']}], + securitySchemes={ + scheme_name: SecurityScheme( + root=OAuth2SecurityScheme(type='oauth2', flows=oauth_flows) + ) + }, + ) + + async with httpx.AsyncClient() as http_client: + client = A2AClient( + httpx_client=http_client, + agent_card=agent_card, + interceptors=[auth_interceptor], + ) + + minimal_success_response = { + 'jsonrpc': '2.0', + 'id': 'oauth_test_1', + 'result': { + 'kind': 'message', + 'messageId': 'response-msg-oauth', + 'role': 'agent', + 'parts': [], + }, + } + respx.post(test_url).mock( + return_value=httpx.Response(200, json=minimal_success_response) + ) + + # Act + context = ClientCallContext(state={'sessionId': session_id}) + await client.send_message( + request=SendMessageRequest( + id='oauth_test_1', + params={ + 'message': { + 'messageId': 'msg-oauth', + 'role': 'user', + 'parts': [], + } + }, + ), + context=context, + ) + + # Assert + assert len(respx.calls) == 1 + request_sent = respx.calls.last.request + assert 'Authorization' in request_sent.headers + assert request_sent.headers['Authorization'] == f'Bearer {access_token}' + + +@pytest.mark.asyncio +@respx.mock +async def test_auth_interceptor_with_oidc_scheme(): + """ + Tests the AuthInterceptor with an OpenIdConnectSecurityScheme. + Ensures it correctly sets the Authorization: Bearer header. + """ + # Arrange + test_url = 'http://oidc-agent.com/rpc' + session_id = 'user-session-oidc' + scheme_name = 'myOidcScheme' + id_token = 'secret-oidc-id-token' + + cred_store = InMemoryContextCredentialStore() + await cred_store.set_credentials(session_id, scheme_name, id_token) + + auth_interceptor = AuthInterceptor(credential_service=cred_store) + + agent_card = AgentCard( + url=test_url, + name='OidcBot', + description='A bot that uses OpenID Connect', + version='1.0', + defaultInputModes=[], + defaultOutputModes=[], + skills=[], + capabilities=AgentCapabilities(), + security=[{scheme_name: []}], + securitySchemes={ + scheme_name: SecurityScheme( + root=OpenIdConnectSecurityScheme( + type='openIdConnect', + openIdConnectUrl='http://provider.com/.well-known/openid-configuration', + ) + ) + }, + ) + + async with httpx.AsyncClient() as http_client: + client = A2AClient( + httpx_client=http_client, + agent_card=agent_card, + interceptors=[auth_interceptor], + ) + + minimal_success_response = { + 'jsonrpc': '2.0', + 'id': 'oidc_test_1', + 'result': { + 'kind': 'message', + 'messageId': 'response-msg-oidc', + 'role': 'agent', + 'parts': [], + }, + } + respx.post(test_url).mock( + return_value=httpx.Response(200, json=minimal_success_response) + ) + + # Act + context = ClientCallContext(state={'sessionId': session_id}) + await client.send_message( + request=SendMessageRequest( + id='oidc_test_1', + params={ + 'message': { + 'messageId': 'msg-oidc', + 'role': 'user', + 'parts': [], + } + }, + ), + context=context, + ) + + # Assert + assert len(respx.calls) == 1 + request_sent = respx.calls.last.request + assert 'Authorization' in request_sent.headers + assert request_sent.headers['Authorization'] == f'Bearer {id_token}'