Skip to content

Commit 3cf66ec

Browse files
committed
use gather for async dispatch
1 parent 65e5909 commit 3cf66ec

File tree

3 files changed

+165
-7
lines changed

3 files changed

+165
-7
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,11 @@ async def on_get_task_push_notification_config(
411411
params.id
412412
)
413413
if not push_notification_config or not push_notification_config[0]:
414-
raise ServerError(error=InternalError())
414+
raise ServerError(
415+
error=InternalError(
416+
message='Push notification config not found'
417+
)
418+
)
415419

416420
return TaskPushNotificationConfig(
417421
taskId=params.id, pushNotificationConfig=push_notification_config[0]

src/a2a/server/tasks/base_push_notification_sender.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23

34
import httpx
@@ -15,7 +16,11 @@
1516
class BasePushNotificationSender(PushNotificationSender):
1617
"""Base implementation of PushNotificationSender interface."""
1718

18-
def __init__(self, httpx_client: httpx.AsyncClient, config_store: PushNotificationConfigStore) -> None:
19+
def __init__(
20+
self,
21+
httpx_client: httpx.AsyncClient,
22+
config_store: PushNotificationConfigStore,
23+
) -> None:
1924
"""Initializes the BasePushNotificationSender.
2025
2126
Args:
@@ -31,16 +36,32 @@ async def send_notification(self, task: Task) -> None:
3136
if not push_configs:
3237
return
3338

34-
for push_info in push_configs:
35-
await self._dispatch_notification(task, push_info)
39+
awaitables = [
40+
self._dispatch_notification(task, push_info)
41+
for push_info in push_configs
42+
]
43+
results = await asyncio.gather(*awaitables)
3644

37-
async def _dispatch_notification(self, task: Task, push_info: PushNotificationConfig) -> None:
45+
if not all(results):
46+
logger.warning(
47+
f'Some push notifications failed to send for task_id={task.id}'
48+
)
49+
50+
async def _dispatch_notification(
51+
self, task: Task, push_info: PushNotificationConfig
52+
) -> bool:
3853
url = push_info.url
3954
try:
4055
response = await self._client.post(
4156
url, json=task.model_dump(mode='json', exclude_none=True)
4257
)
4358
response.raise_for_status()
44-
logger.info(f'Push-notification sent for URL: {url}')
59+
logger.info(
60+
f'Push-notification sent for task_id={task.id} to URL: {url}'
61+
)
62+
return True
4563
except Exception as e:
46-
logger.error(f'Error sending push-notification: {e}')
64+
logger.error(
65+
f'Error sending push-notification for task_id={task.id} to URL: {url}. Error: {e}'
66+
)
67+
return False
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from a2a.server.tasks.base_push_notification_sender import (
2+
BasePushNotificationSender,
3+
)
4+
5+
import unittest
6+
7+
from unittest.mock import AsyncMock, MagicMock, patch
8+
9+
import httpx
10+
11+
from a2a.types import (
12+
PushNotificationConfig,
13+
Task,
14+
TaskState,
15+
TaskStatus,
16+
)
17+
18+
19+
def create_sample_task(task_id='task123', status_state=TaskState.completed):
20+
return Task(
21+
id=task_id,
22+
contextId='ctx456',
23+
status=TaskStatus(state=status_state),
24+
)
25+
26+
27+
def create_sample_push_config(
28+
url='http://example.com/callback', config_id='cfg1'
29+
):
30+
return PushNotificationConfig(id=config_id, url=url)
31+
32+
33+
class TestBasePushNotificationSender(unittest.IsolatedAsyncioTestCase):
34+
def setUp(self):
35+
self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient)
36+
self.mock_config_store = AsyncMock()
37+
self.sender = BasePushNotificationSender(
38+
httpx_client=self.mock_httpx_client,
39+
config_store=self.mock_config_store,
40+
)
41+
42+
def test_constructor_stores_client_and_config_store(self):
43+
self.assertEqual(self.sender._client, self.mock_httpx_client)
44+
self.assertEqual(self.sender._config_store, self.mock_config_store)
45+
46+
async def test_send_notification_success(self):
47+
task_id = 'task_send_success'
48+
task_data = create_sample_task(task_id=task_id)
49+
config = create_sample_push_config(url='http://notify.me/here')
50+
self.mock_config_store.get_info.return_value = [config]
51+
52+
mock_response = AsyncMock(spec=httpx.Response)
53+
mock_response.status_code = 200
54+
self.mock_httpx_client.post.return_value = mock_response
55+
56+
await self.sender.send_notification(task_data)
57+
58+
self.mock_config_store.get_info.assert_awaited_once_with
59+
60+
# assert httpx_client post method got invoked with right parameters
61+
self.mock_httpx_client.post.assert_awaited_once_with(
62+
config.url,
63+
json=task_data.model_dump(mode='json', exclude_none=True),
64+
)
65+
mock_response.raise_for_status.assert_called_once()
66+
67+
async def test_send_notification_no_config(self):
68+
task_id = 'task_send_no_config'
69+
task_data = create_sample_task(task_id=task_id)
70+
self.mock_config_store.get_info.return_value = []
71+
72+
await self.sender.send_notification(task_data)
73+
74+
self.mock_config_store.get_info.assert_awaited_once_with(task_id)
75+
self.mock_httpx_client.post.assert_not_called()
76+
77+
@patch('a2a.server.tasks.base_push_notification_sender.logger')
78+
async def test_send_notification_http_status_error(
79+
self, mock_logger: MagicMock
80+
):
81+
task_id = 'task_send_http_err'
82+
task_data = create_sample_task(task_id=task_id)
83+
config = create_sample_push_config(url='http://notify.me/http_error')
84+
self.mock_config_store.get_info.return_value = [config]
85+
86+
mock_response = MagicMock(spec=httpx.Response)
87+
mock_response.status_code = 404
88+
mock_response.text = 'Not Found'
89+
http_error = httpx.HTTPStatusError(
90+
'Not Found', request=MagicMock(), response=mock_response
91+
)
92+
self.mock_httpx_client.post.side_effect = http_error
93+
94+
await self.sender.send_notification(task_data)
95+
96+
self.mock_config_store.get_info.assert_awaited_once_with(task_id)
97+
self.mock_httpx_client.post.assert_awaited_once_with(
98+
config.url,
99+
json=task_data.model_dump(mode='json', exclude_none=True),
100+
)
101+
mock_logger.error.assert_called_once()
102+
103+
async def test_send_notification_multiple_configs(self):
104+
task_id = 'task_multiple_configs'
105+
task_data = create_sample_task(task_id=task_id)
106+
config1 = create_sample_push_config(
107+
url='http://notify.me/cfg1', config_id='cfg1'
108+
)
109+
config2 = create_sample_push_config(
110+
url='http://notify.me/cfg2', config_id='cfg2'
111+
)
112+
self.mock_config_store.get_info.return_value = [config1, config2]
113+
114+
mock_response = AsyncMock(spec=httpx.Response)
115+
mock_response.status_code = 200
116+
self.mock_httpx_client.post.return_value = mock_response
117+
118+
await self.sender.send_notification(task_data)
119+
120+
self.mock_config_store.get_info.assert_awaited_once_with(task_id)
121+
self.assertEqual(self.mock_httpx_client.post.call_count, 2)
122+
123+
# Check calls for config1
124+
self.mock_httpx_client.post.assert_any_call(
125+
config1.url,
126+
json=task_data.model_dump(mode='json', exclude_none=True),
127+
)
128+
# Check calls for config2
129+
self.mock_httpx_client.post.assert_any_call(
130+
config2.url,
131+
json=task_data.model_dump(mode='json', exclude_none=True),
132+
)
133+
mock_response.raise_for_status.call_count = 2

0 commit comments

Comments
 (0)