diff --git a/pyproject.toml b/pyproject.toml index 0d2cb75a..192e2151 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,7 @@ dev = [ "autoflake", "no_implicit_optional", "trio", + "uvicorn>=0.35.0", ] [[tool.uv.index]] diff --git a/tests/e2e/push_notifications/agent_app.py b/tests/e2e/push_notifications/agent_app.py new file mode 100644 index 00000000..1fa9bc54 --- /dev/null +++ b/tests/e2e/push_notifications/agent_app.py @@ -0,0 +1,145 @@ +import httpx + +from fastapi import FastAPI + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.apps import A2ARESTFastAPIApplication +from a2a.server.events import EventQueue +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import ( + BasePushNotificationSender, + InMemoryPushNotificationConfigStore, + InMemoryTaskStore, + TaskUpdater, +) +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, + InvalidParamsError, + Message, + Task, +) +from a2a.utils import ( + new_agent_text_message, + new_task, +) +from a2a.utils.errors import ServerError + + +def test_agent_card(url: str) -> AgentCard: + """Returns an agent card for the test agent.""" + return AgentCard( + name='Test Agent', + description='Just a test agent', + url=url, + version='1.0.0', + default_input_modes=['text'], + default_output_modes=['text'], + capabilities=AgentCapabilities(streaming=True, push_notifications=True), + skills=[ + AgentSkill( + id='greeting', + name='Greeting Agent', + description='just greets the user', + tags=['greeting'], + examples=['Hello Agent!', 'How are you?'], + ) + ], + supports_authenticated_extended_card=True, + ) + + +class TestAgent: + """Agent for push notification testing.""" + + async def invoke( + self, updater: TaskUpdater, msg: Message, task: Task + ) -> None: + # Fail for unsupported messages. + if ( + not msg.parts + or len(msg.parts) != 1 + or msg.parts[0].root.kind != 'text' + ): + await updater.failed( + new_agent_text_message( + 'Unsupported message.', task.context_id, task.id + ) + ) + return + text_message = msg.parts[0].root.text + + # Simple request-response flow. + if text_message == 'Hello Agent!': + await updater.complete( + new_agent_text_message('Hello User!', task.context_id, task.id) + ) + + # Flow with user input required: "How are you?" -> "Good! How are you?" -> "Good" -> "Amazing". + elif text_message == 'How are you?': + await updater.requires_input( + new_agent_text_message( + 'Good! How are you?', task.context_id, task.id + ) + ) + elif text_message == 'Good': + await updater.complete( + new_agent_text_message('Amazing', task.context_id, task.id) + ) + + # Fail for unsupported messages. + else: + await updater.failed( + new_agent_text_message( + 'Unsupported message.', task.context_id, task.id + ) + ) + + +class TestAgentExecutor(AgentExecutor): + """Test AgentExecutor implementation.""" + + def __init__(self) -> None: + self.agent = TestAgent() + + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + if not context.message: + raise ServerError(error=InvalidParamsError(message='No message')) + + task = context.current_task + if not task: + task = new_task(context.message) + await event_queue.enqueue_event(task) + updater = TaskUpdater(event_queue, task.id, task.context_id) + + await self.agent.invoke(updater, context.message, task) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + raise NotImplementedError('cancel not supported') + + +def create_agent_app( + url: str, notification_client: httpx.AsyncClient +) -> FastAPI: + """Creates a new HTTP+REST FastAPI application for the test agent.""" + push_config_store = InMemoryPushNotificationConfigStore() + app = A2ARESTFastAPIApplication( + agent_card=test_agent_card(url), + http_handler=DefaultRequestHandler( + agent_executor=TestAgentExecutor(), + task_store=InMemoryTaskStore(), + push_config_store=push_config_store, + push_sender=BasePushNotificationSender( + httpx_client=notification_client, + config_store=push_config_store, + ), + ), + ) + return app.build() diff --git a/tests/e2e/push_notifications/notifications_app.py b/tests/e2e/push_notifications/notifications_app.py new file mode 100644 index 00000000..ed032dcb --- /dev/null +++ b/tests/e2e/push_notifications/notifications_app.py @@ -0,0 +1,69 @@ +import asyncio + +from typing import Annotated + +from fastapi import FastAPI, HTTPException, Path, Request +from pydantic import BaseModel, ValidationError + +from a2a.types import Task + + +class Notification(BaseModel): + """Encapsulates default push notification data.""" + + task: Task + token: str + + +def create_notifications_app() -> FastAPI: + """Creates a simple push notification ingesting HTTP+REST application.""" + app = FastAPI() + store_lock = asyncio.Lock() + store: dict[str, list[Notification]] = {} + + @app.post('/notifications') + async def add_notification(request: Request): + """Endpoint for injesting notifications from agents. It receives a JSON + payload and stores it in-memory. + """ + token = request.headers.get('x-a2a-notification-token') + if not token: + raise HTTPException( + status_code=400, + detail='Missing "x-a2a-notification-token" header.', + ) + try: + task = Task.model_validate(await request.json()) + except ValidationError as e: + raise HTTPException(status_code=400, detail=str(e)) + + async with store_lock: + if task.id not in store: + store[task.id] = [] + store[task.id].append( + Notification( + task=task, + token=token, + ) + ) + return { + 'status': 'received', + } + + @app.get('/tasks/{task_id}/notifications') + async def list_notifications_by_task( + task_id: Annotated[ + str, Path(title='The ID of the task to list the notifications for.') + ], + ): + """Helper endpoint for retrieving injested notifications for a given task.""" + async with store_lock: + notifications = store.get(task_id, []) + return {'notifications': notifications} + + @app.get('/health') + def health_check(): + """Helper endpoint for checking if the server is up.""" + return {'status': 'ok'} + + return app diff --git a/tests/e2e/push_notifications/test_default_push_notification_support.py b/tests/e2e/push_notifications/test_default_push_notification_support.py new file mode 100644 index 00000000..775bd7fb --- /dev/null +++ b/tests/e2e/push_notifications/test_default_push_notification_support.py @@ -0,0 +1,244 @@ +import asyncio +import time +import uuid + +import httpx +import pytest +import pytest_asyncio + +from agent_app import create_agent_app +from notifications_app import Notification, create_notifications_app +from utils import ( + create_app_process, + find_free_port, + wait_for_server_ready, +) + +from a2a.client import ( + ClientConfig, + ClientFactory, + minimal_agent_card, +) +from a2a.types import ( + Message, + Part, + PushNotificationConfig, + Role, + Task, + TaskPushNotificationConfig, + TaskState, + TextPart, + TransportProtocol, +) + + +@pytest.fixture(scope='module') +def notifications_server(): + """ + Starts a simple push notifications injesting server and yields its URL. + """ + host = '127.0.0.1' + port = find_free_port() + url = f'http://{host}:{port}' + + process = create_app_process(create_notifications_app(), host, port) + process.start() + try: + wait_for_server_ready(f'{url}/health') + except TimeoutError as e: + process.terminate() + raise e + + yield url + + process.terminate() + process.join() + + +@pytest_asyncio.fixture(scope='module') +async def notifications_client(): + """An async client fixture for calling the notifications server.""" + async with httpx.AsyncClient() as client: + yield client + + +@pytest.fixture(scope='module') +def agent_server(notifications_client: httpx.AsyncClient): + """Starts a test agent server and yields its URL.""" + host = '127.0.0.1' + port = find_free_port() + url = f'http://{host}:{port}' + + process = create_app_process( + create_agent_app(url, notifications_client), host, port + ) + process.start() + try: + wait_for_server_ready(f'{url}/v1/card') + except TimeoutError as e: + process.terminate() + raise e + + yield url + + process.terminate() + process.join() + + +@pytest_asyncio.fixture(scope='function') +async def http_client(): + """An async client fixture for test functions.""" + async with httpx.AsyncClient() as client: + yield client + + +@pytest.mark.asyncio +async def test_notification_triggering_with_in_message_config_e2e( + notifications_server: str, + agent_server: str, + http_client: httpx.AsyncClient, +): + """ + Tests push notification triggering for in-message push notification config. + """ + # Create an A2A client with a push notification config. + token = uuid.uuid4().hex + a2a_client = ClientFactory( + ClientConfig( + supported_transports=[TransportProtocol.http_json], + push_notification_configs=[ + PushNotificationConfig( + id='in-message-config', + url=f'{notifications_server}/notifications', + token=token, + ) + ], + ) + ).create(minimal_agent_card(agent_server, [TransportProtocol.http_json])) + + # Send a message and extract the returned task. + responses = [ + response + async for response in a2a_client.send_message( + Message( + message_id='hello-agent', + parts=[Part(root=TextPart(text='Hello Agent!'))], + role=Role.user, + ) + ) + ] + assert len(responses) == 1 + assert isinstance(responses[0], tuple) + assert isinstance(responses[0][0], Task) + task = responses[0][0] + + # Verify a single notification was sent. + notifications = await wait_for_n_notifications( + http_client, + f'{notifications_server}/tasks/{task.id}/notifications', + n=1, + ) + assert notifications[0].token == token + assert notifications[0].task.id == task.id + assert notifications[0].task.status.state == 'completed' + + +@pytest.mark.asyncio +async def test_notification_triggering_after_config_change_e2e( + notifications_server: str, agent_server: str, http_client: httpx.AsyncClient +): + """ + Tests notification triggering after setting the push notificaiton config in a seperate call. + """ + # Configure an A2A client without a push notification config. + a2a_client = ClientFactory( + ClientConfig( + supported_transports=[TransportProtocol.http_json], + ) + ).create(minimal_agent_card(agent_server, [TransportProtocol.http_json])) + + # Send a message and extract the returned task. + responses = [ + response + async for response in a2a_client.send_message( + Message( + message_id='how-are-you', + parts=[Part(root=TextPart(text='How are you?'))], + role=Role.user, + ) + ) + ] + assert len(responses) == 1 + assert isinstance(responses[0], tuple) + assert isinstance(responses[0][0], Task) + task = responses[0][0] + assert task.status.state == TaskState.input_required + + # Verify that no notification has been sent yet. + response = await http_client.get( + f'{notifications_server}/tasks/{task.id}/notifications' + ) + assert response.status_code == 200 + assert len(response.json().get('notifications', [])) == 0 + + # Set the push notification config. + token = uuid.uuid4().hex + await a2a_client.set_task_callback( + TaskPushNotificationConfig( + task_id=task.id, + push_notification_config=PushNotificationConfig( + id='after-config-change', + url=f'{notifications_server}/notifications', + token=token, + ), + ) + ) + + # Send another message that should trigger a push notification. + responses = [ + response + async for response in a2a_client.send_message( + Message( + task_id=task.id, + message_id='good', + parts=[Part(root=TextPart(text='Good'))], + role=Role.user, + ) + ) + ] + assert len(responses) == 1 + + # Verify that the push notification was sent. + notifications = await wait_for_n_notifications( + http_client, + f'{notifications_server}/tasks/{task.id}/notifications', + n=1, + ) + assert notifications[0].task.id == task.id + assert notifications[0].task.status.state == 'completed' + assert notifications[0].token == token + + +async def wait_for_n_notifications( + http_client: httpx.AsyncClient, + url: str, + n: int, + timeout: int = 3, +) -> list[Notification]: + """ + Queries the notification URL until the desired number of notifications + is received or the timeout is reached. + """ + start_time = time.time() + notifications = [] + while True: + response = await http_client.get(url) + assert response.status_code == 200 + notifications = response.json()['notifications'] + if len(notifications) == n: + return [Notification.model_validate(n) for n in notifications] + if time.time() - start_time > timeout: + raise TimeoutError( + f'Notification retrieval timed out. Got {len(notifications)} notification(s), want {n}. Retrieved notifications: {notifications}.' + ) + await asyncio.sleep(0.1) diff --git a/tests/e2e/push_notifications/utils.py b/tests/e2e/push_notifications/utils.py new file mode 100644 index 00000000..01d84a30 --- /dev/null +++ b/tests/e2e/push_notifications/utils.py @@ -0,0 +1,45 @@ +import contextlib +import socket +import time + +from multiprocessing import Process + +import httpx +import uvicorn + + +def find_free_port(): + """Finds and returns an available ephemeral localhost port.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', 0)) + return s.getsockname()[1] + + +def run_server(app, host, port) -> None: + """Runs a uvicorn server.""" + uvicorn.run(app, host=host, port=port, log_level='warning') + + +def wait_for_server_ready(url: str, timeout: int = 10) -> None: + """Polls the provided URL endpoint until the server is up.""" + start_time = time.time() + while True: + with contextlib.suppress(httpx.ConnectError): + with httpx.Client() as client: + response = client.get(url) + if response.status_code == 200: + return + if time.time() - start_time > timeout: + raise TimeoutError( + f'Server at {url} failed to start after {timeout}s' + ) + time.sleep(0.1) + + +def create_app_process(app, host, port) -> Process: + """Creates a separate process for a given application.""" + return Process( + target=run_server, + args=(app, host, port), + daemon=True, + )