Skip to content

Commit 78a58cf

Browse files
author
lkawka
committed
Verify that notification tokens were properly passed around
1 parent 771c1ec commit 78a58cf

File tree

2 files changed

+39
-18
lines changed

2 files changed

+39
-18
lines changed

tests/e2e/push_notifications/notifications_app.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,49 @@
33
from typing import Annotated
44

55
from fastapi import FastAPI, HTTPException, Path, Request
6+
from pydantic import BaseModel, ValidationError
7+
8+
from a2a.types import Task
9+
10+
11+
class Notification(BaseModel):
12+
"""Encapsulates default push notification data."""
13+
14+
task: Task
15+
token: str
616

717

818
def create_notifications_app() -> FastAPI:
919
"""Creates a simple push notification ingesting HTTP+REST application."""
1020
app = FastAPI()
1121
store_lock = asyncio.Lock()
12-
store: dict[str, list] = {}
22+
store: dict[str, list[Notification]] = {}
1323

1424
@app.post('/notifications')
1525
async def add_notification(request: Request):
1626
"""Endpoint for injesting notifications from agents. It receives a JSON
1727
payload and stores it in-memory.
1828
"""
19-
if not request.headers.get('x-a2a-notification-token'):
29+
token = request.headers.get('x-a2a-notification-token')
30+
if not token:
2031
raise HTTPException(
2132
status_code=400,
2233
detail='Missing "x-a2a-notification-token" header.',
2334
)
24-
payload = await request.json()
25-
task_id = payload.get('id')
26-
if not task_id:
27-
raise HTTPException(
28-
status_code=400, detail='Missing "id" in notification payload.'
29-
)
35+
try:
36+
task = Task.model_validate(await request.json())
37+
except ValidationError as e:
38+
raise HTTPException(status_code=400, detail=str(e))
39+
3040
async with store_lock:
31-
if task_id not in store:
32-
store[task_id] = []
33-
store[task_id].append(payload)
41+
if task.id not in store:
42+
store[task.id] = []
43+
store[task.id].append(
44+
Notification(
45+
task=task,
46+
token=token,
47+
)
48+
)
3449
return {
3550
'status': 'received',
3651
}

tests/e2e/push_notifications/test_default_push_notification_support.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest_asyncio
1010

1111
from agent_app import create_agent_app
12-
from notifications_app import create_notifications_app
12+
from notifications_app import Notification, create_notifications_app
1313
from utils import (
1414
create_app_process,
1515
find_free_port,
@@ -115,14 +115,15 @@ async def test_notification_triggering_with_in_message_config_e2e(
115115
Tests push notification triggering for in-message push notification config.
116116
"""
117117
# Create an A2A client with a push notification config.
118+
token = uuid.uuid4().hex
118119
a2a_client = ClientFactory(
119120
ClientConfig(
120121
supported_transports=[TransportProtocol.http_json],
121122
push_notification_configs=[
122123
PushNotificationConfig(
123124
id='in-message-config',
124125
url=f'{notifications_server}/notifications',
125-
token=uuid.uuid4().hex,
126+
token=token,
126127
)
127128
],
128129
)
@@ -150,7 +151,9 @@ async def test_notification_triggering_with_in_message_config_e2e(
150151
f'{notifications_server}/tasks/{task.id}/notifications',
151152
n=1,
152153
)
153-
assert notifications[0].status.state == 'completed'
154+
assert notifications[0].token == token
155+
assert notifications[0].task.id == task.id
156+
assert notifications[0].task.status.state == 'completed'
154157

155158

156159
@pytest.mark.asyncio
@@ -192,13 +195,14 @@ async def test_notification_triggering_after_config_change_e2e(
192195
assert len(response.json().get('notifications', [])) == 0
193196

194197
# Set the push notification config.
198+
token = uuid.uuid4().hex
195199
await a2a_client.set_task_callback(
196200
TaskPushNotificationConfig(
197201
task_id=task.id,
198202
push_notification_config=PushNotificationConfig(
199203
id='after-config-change',
200204
url=f'{notifications_server}/notifications',
201-
token=uuid.uuid4().hex,
205+
token=token,
202206
),
203207
)
204208
)
@@ -223,15 +227,17 @@ async def test_notification_triggering_after_config_change_e2e(
223227
f'{notifications_server}/tasks/{task.id}/notifications',
224228
n=1,
225229
)
226-
assert notifications[0].status.state == 'completed'
230+
assert notifications[0].task.id == task.id
231+
assert notifications[0].task.status.state == 'completed'
232+
assert notifications[0].token == token
227233

228234

229235
async def wait_for_n_notifications(
230236
http_client: httpx.AsyncClient,
231237
url: str,
232238
n: int,
233239
timeout: int = 3,
234-
) -> list[Task]:
240+
) -> list[Notification]:
235241
"""
236242
Queries the notification URL until the desired number of notifications
237243
is received or the timeout is reached.
@@ -243,7 +249,7 @@ async def wait_for_n_notifications(
243249
assert response.status_code == 200
244250
notifications = response.json()['notifications']
245251
if len(notifications) == n:
246-
return [Task.model_validate(n) for n in notifications]
252+
return [Notification.model_validate(n) for n in notifications]
247253
if time.time() - start_time > timeout:
248254
raise TimeoutError(
249255
f'Notification retrieval timed out. Got {len(notifications)} notification(s), want {n}. Retrieved notifications: {notifications}.'

0 commit comments

Comments
 (0)