Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ dev = [
"autoflake",
"no_implicit_optional",
"trio",
"uvicorn>=0.35.0",
]

[[tool.uv.index]]
Expand Down
145 changes: 145 additions & 0 deletions tests/e2e/push_notifications/agent_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import httpx

from fastapi import FastAPI

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.apps import A2ARESTFastAPIApplication
from a2a.server.events import EventQueue
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import (
BasePushNotificationSender,
InMemoryPushNotificationConfigStore,
InMemoryTaskStore,
TaskUpdater,
)
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentSkill,
InvalidParamsError,
Message,
Task,
)
from a2a.utils import (
new_agent_text_message,
new_task,
)
from a2a.utils.errors import ServerError


def test_agent_card(url: str) -> AgentCard:
"""Returns an agent card for the test agent."""
return AgentCard(
name='Test Agent',
description='Just a test agent',
url=url,
version='1.0.0',
default_input_modes=['text'],
default_output_modes=['text'],
capabilities=AgentCapabilities(streaming=True, push_notifications=True),
skills=[
AgentSkill(
id='greeting',
name='Greeting Agent',
description='just greets the user',
tags=['greeting'],
examples=['Hello Agent!', 'How are you?'],
)
],
supports_authenticated_extended_card=True,
)


class TestAgent:
"""Agent for push notification testing."""

async def invoke(
self, updater: TaskUpdater, msg: Message, task: Task
) -> None:
# Fail for unsupported messages.
if (
not msg.parts
or len(msg.parts) != 1
or msg.parts[0].root.kind != 'text'
):
await updater.failed(
new_agent_text_message(
'Unsupported message.', task.context_id, task.id
)
)
return
text_message = msg.parts[0].root.text

# Simple request-response flow.
if text_message == 'Hello Agent!':
await updater.complete(
new_agent_text_message('Hello User!', task.context_id, task.id)
)

# Flow with user input required: "How are you?" -> "Good! How are you?" -> "Good" -> "Amazing".
elif text_message == 'How are you?':
await updater.requires_input(
new_agent_text_message(
'Good! How are you?', task.context_id, task.id
)
)
elif text_message == 'Good':
await updater.complete(
new_agent_text_message('Amazing', task.context_id, task.id)
)

# Fail for unsupported messages.
else:
await updater.failed(
new_agent_text_message(
'Unsupported message.', task.context_id, task.id
)
)


class TestAgentExecutor(AgentExecutor):
"""Test AgentExecutor implementation."""

def __init__(self) -> None:
self.agent = TestAgent()

async def execute(
self,
context: RequestContext,
event_queue: EventQueue,
) -> None:
if not context.message:
raise ServerError(error=InvalidParamsError(message='No message'))

task = context.current_task
if not task:
task = new_task(context.message)
await event_queue.enqueue_event(task)
updater = TaskUpdater(event_queue, task.id, task.context_id)

await self.agent.invoke(updater, context.message, task)

async def cancel(
self, context: RequestContext, event_queue: EventQueue
) -> None:
raise NotImplementedError('cancel not supported')


def create_agent_app(
url: str, notification_client: httpx.AsyncClient
) -> FastAPI:
"""Creates a new HTTP+REST FastAPI application for the test agent."""
push_config_store = InMemoryPushNotificationConfigStore()
app = A2ARESTFastAPIApplication(
agent_card=test_agent_card(url),
http_handler=DefaultRequestHandler(
agent_executor=TestAgentExecutor(),
task_store=InMemoryTaskStore(),
push_config_store=push_config_store,
push_sender=BasePushNotificationSender(
httpx_client=notification_client,
config_store=push_config_store,
),
),
)
return app.build()
69 changes: 69 additions & 0 deletions tests/e2e/push_notifications/notifications_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import asyncio

from typing import Annotated

from fastapi import FastAPI, HTTPException, Path, Request
from pydantic import BaseModel, ValidationError

from a2a.types import Task


class Notification(BaseModel):
"""Encapsulates default push notification data."""

task: Task
token: str


def create_notifications_app() -> FastAPI:
"""Creates a simple push notification ingesting HTTP+REST application."""
app = FastAPI()
store_lock = asyncio.Lock()
store: dict[str, list[Notification]] = {}

@app.post('/notifications')
async def add_notification(request: Request):
"""Endpoint for injesting notifications from agents. It receives a JSON
payload and stores it in-memory.
"""
token = request.headers.get('x-a2a-notification-token')
if not token:
raise HTTPException(
status_code=400,
detail='Missing "x-a2a-notification-token" header.',
)
try:
task = Task.model_validate(await request.json())
except ValidationError as e:
raise HTTPException(status_code=400, detail=str(e))

async with store_lock:
if task.id not in store:
store[task.id] = []
store[task.id].append(
Notification(
task=task,
token=token,
)
)
return {
'status': 'received',
}

@app.get('/tasks/{task_id}/notifications')
async def list_notifications_by_task(
task_id: Annotated[
str, Path(title='The ID of the task to list the notifications for.')
],
):
"""Helper endpoint for retrieving injested notifications for a given task."""
async with store_lock:
notifications = store.get(task_id, [])
return {'notifications': notifications}

@app.get('/health')
def health_check():
"""Helper endpoint for checking if the server is up."""
return {'status': 'ok'}

return app
Loading
Loading