Skip to content

Commit ce1fc39

Browse files
committed
Add tests for inmemory push notifications
1 parent 4934b48 commit ce1fc39

File tree

2 files changed

+271
-3
lines changed

2 files changed

+271
-3
lines changed

src/a2a/server/tasks/inmemory_push_notification_config_store.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@ class InMemoryPushNotificationConfigStore(PushNotificationConfigStore):
1515
1616
Stores push notification configurations in memory
1717
"""
18+
1819
def __init__(self) -> None:
1920
"""Initializes the InMemoryPushNotificationConfigStore."""
2021
self.lock = asyncio.Lock()
21-
self._push_notification_infos: dict[str, list[PushNotificationConfig]] = {}
22+
self._push_notification_infos: dict[
23+
str, list[PushNotificationConfig]
24+
] = {}
2225

2326
async def set_info(
2427
self, task_id: str, notification_config: PushNotificationConfig
@@ -38,13 +41,14 @@ async def set_info(
3841

3942
self._push_notification_infos[task_id].append(notification_config)
4043

41-
4244
async def get_info(self, task_id: str) -> list[PushNotificationConfig]:
4345
"""Retrieves the push notification configuration for a task from memory."""
4446
async with self.lock:
4547
return self._push_notification_infos.get(task_id) or []
4648

47-
async def delete_info(self, task_id: str, config_id: str | None = None) -> None:
49+
async def delete_info(
50+
self, task_id: str, config_id: str | None = None
51+
) -> None:
4852
"""Deletes the push notification configuration for a task from memory."""
4953
async with self.lock:
5054
if config_id is None:
@@ -59,3 +63,6 @@ async def delete_info(self, task_id: str, config_id: str | None = None) -> None:
5963
if config.id == config_id:
6064
configurations.remove(config)
6165
break
66+
67+
if len(configurations) == 0:
68+
del self._push_notification_infos[task_id]
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
import unittest
2+
3+
from unittest.mock import AsyncMock, MagicMock, patch
4+
5+
import httpx
6+
7+
from a2a.server.tasks.inmemory_push_notification_config_store import (
8+
InMemoryPushNotificationConfigStore,
9+
)
10+
from a2a.server.tasks.base_push_notification_sender import (
11+
BasePushNotificationSender,
12+
)
13+
from a2a.types import PushNotificationConfig, Task, TaskState, TaskStatus
14+
15+
16+
# Suppress logging for cleaner test output, can be enabled for debugging
17+
# logging.disable(logging.CRITICAL)
18+
19+
20+
def create_sample_task(task_id='task123', status_state=TaskState.completed):
21+
return Task(
22+
id=task_id,
23+
contextId='ctx456',
24+
status=TaskStatus(state=status_state),
25+
)
26+
27+
28+
def create_sample_push_config(
29+
url='http://example.com/callback', config_id='cfg1'
30+
):
31+
return PushNotificationConfig(id=config_id, url=url)
32+
33+
34+
class TestInMemoryPushNotifier(unittest.IsolatedAsyncioTestCase):
35+
def setUp(self):
36+
self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient)
37+
self.config_store = InMemoryPushNotificationConfigStore()
38+
self.notifier = BasePushNotificationSender(
39+
httpx_client=self.mock_httpx_client, config_store=self.config_store
40+
) # Corrected argument name
41+
42+
def test_constructor_stores_client(self):
43+
self.assertEqual(self.notifier._client, self.mock_httpx_client)
44+
45+
async def test_set_info_adds_new_config(self):
46+
task_id = 'task_new'
47+
config = create_sample_push_config(url='http://new.url/callback')
48+
49+
await self.config_store.set_info(task_id, config)
50+
51+
self.assertIn(task_id, self.config_store._push_notification_infos)
52+
self.assertEqual(
53+
self.config_store._push_notification_infos[task_id], [config]
54+
)
55+
56+
async def test_set_info_appends_to_existing_config(self):
57+
task_id = 'task_update'
58+
initial_config = create_sample_push_config(
59+
url='http://initial.url/callback', config_id='cfg_initial'
60+
)
61+
await self.config_store.set_info(task_id, initial_config)
62+
63+
updated_config = create_sample_push_config(
64+
url='http://updated.url/callback', config_id='cfg_updated'
65+
)
66+
await self.config_store.set_info(task_id, updated_config)
67+
68+
self.assertIn(task_id, self.config_store._push_notification_infos)
69+
self.assertEqual(
70+
self.config_store._push_notification_infos[task_id][0],
71+
initial_config,
72+
)
73+
self.assertEqual(
74+
self.config_store._push_notification_infos[task_id][1],
75+
updated_config,
76+
)
77+
78+
async def test_set_info_without_config_id(self):
79+
task_id = 'task1'
80+
initial_config = PushNotificationConfig(
81+
url='http://initial.url/callback'
82+
)
83+
await self.config_store.set_info(task_id, initial_config)
84+
85+
assert (
86+
self.config_store._push_notification_infos[task_id][0].id == task_id
87+
)
88+
89+
updated_config = PushNotificationConfig(
90+
url='http://initial.url/callback_new'
91+
)
92+
await self.config_store.set_info(task_id, updated_config)
93+
94+
self.assertIn(task_id, self.config_store._push_notification_infos)
95+
assert len(self.config_store._push_notification_infos[task_id]) == 1
96+
self.assertEqual(
97+
self.config_store._push_notification_infos[task_id][0].url,
98+
updated_config.url,
99+
)
100+
101+
async def test_get_info_existing_config(self):
102+
task_id = 'task_get_exist'
103+
config = create_sample_push_config(url='http://get.this/callback')
104+
await self.config_store.set_info(task_id, config)
105+
106+
retrieved_config = await self.config_store.get_info(task_id)
107+
self.assertEqual(retrieved_config, [config])
108+
109+
async def test_get_info_non_existent_config(self):
110+
task_id = 'task_get_non_exist'
111+
retrieved_config = await self.config_store.get_info(task_id)
112+
assert retrieved_config == []
113+
114+
async def test_delete_info_existing_config(self):
115+
task_id = 'task_delete_exist'
116+
config = create_sample_push_config(url='http://delete.this/callback')
117+
await self.config_store.set_info(task_id, config)
118+
119+
self.assertIn(task_id, self.config_store._push_notification_infos)
120+
await self.config_store.delete_info(task_id, config_id=config.id)
121+
self.assertNotIn(task_id, self.config_store._push_notification_infos)
122+
123+
async def test_delete_info_non_existent_config(self):
124+
task_id = 'task_delete_non_exist'
125+
# Ensure it doesn't raise an error
126+
try:
127+
await self.config_store.delete_info(task_id)
128+
except Exception as e:
129+
self.fail(
130+
f'delete_info raised {e} unexpectedly for nonexistent task_id'
131+
)
132+
self.assertNotIn(
133+
task_id, self.config_store._push_notification_infos
134+
) # Should still not be there
135+
136+
async def test_send_notification_success(self):
137+
task_id = 'task_send_success'
138+
task_data = create_sample_task(task_id=task_id)
139+
config = create_sample_push_config(url='http://notify.me/here')
140+
await self.config_store.set_info(task_id, config)
141+
142+
# Mock the post call to simulate success
143+
mock_response = AsyncMock(spec=httpx.Response)
144+
mock_response.status_code = 200
145+
self.mock_httpx_client.post.return_value = mock_response
146+
147+
await self.notifier.send_notification(task_data) # Pass only task_data
148+
149+
self.mock_httpx_client.post.assert_awaited_once()
150+
called_args, called_kwargs = self.mock_httpx_client.post.call_args
151+
self.assertEqual(called_args[0], config.url)
152+
self.assertEqual(
153+
called_kwargs['json'],
154+
task_data.model_dump(mode='json', exclude_none=True),
155+
)
156+
self.assertNotIn(
157+
'auth', called_kwargs
158+
) # auth is not passed by current implementation
159+
mock_response.raise_for_status.assert_called_once()
160+
161+
async def test_send_notification_no_config(self):
162+
task_id = 'task_send_no_config'
163+
task_data = create_sample_task(task_id=task_id)
164+
165+
await self.notifier.send_notification(task_data) # Pass only task_data
166+
167+
self.mock_httpx_client.post.assert_not_called()
168+
169+
@patch('a2a.server.tasks.base_push_notification_sender.logger')
170+
async def test_send_notification_http_status_error(
171+
self, mock_logger: MagicMock
172+
):
173+
task_id = 'task_send_http_err'
174+
task_data = create_sample_task(task_id=task_id)
175+
config = create_sample_push_config(url='http://notify.me/http_error')
176+
await self.config_store.set_info(task_id, config)
177+
178+
mock_response = MagicMock(
179+
spec=httpx.Response
180+
) # Use MagicMock for status_code attribute
181+
mock_response.status_code = 404
182+
mock_response.text = 'Not Found'
183+
http_error = httpx.HTTPStatusError(
184+
'Not Found', request=MagicMock(), response=mock_response
185+
)
186+
self.mock_httpx_client.post.side_effect = http_error
187+
188+
# The method should catch the error and log it, not re-raise
189+
await self.notifier.send_notification(task_data) # Pass only task_data
190+
191+
self.mock_httpx_client.post.assert_awaited_once()
192+
mock_logger.error.assert_called_once()
193+
# Check that the error message contains the generic part and the specific exception string
194+
self.assertIn(
195+
'Error sending push-notification', mock_logger.error.call_args[0][0]
196+
)
197+
self.assertIn(str(http_error), mock_logger.error.call_args[0][0])
198+
199+
@patch('a2a.server.tasks.base_push_notification_sender.logger')
200+
async def test_send_notification_request_error(
201+
self, mock_logger: MagicMock
202+
):
203+
task_id = 'task_send_req_err'
204+
task_data = create_sample_task(task_id=task_id)
205+
config = create_sample_push_config(url='http://notify.me/req_error')
206+
await self.config_store.set_info(task_id, config)
207+
208+
request_error = httpx.RequestError('Network issue', request=MagicMock())
209+
self.mock_httpx_client.post.side_effect = request_error
210+
211+
await self.notifier.send_notification(task_data) # Pass only task_data
212+
213+
self.mock_httpx_client.post.assert_awaited_once()
214+
mock_logger.error.assert_called_once()
215+
self.assertIn(
216+
'Error sending push-notification', mock_logger.error.call_args[0][0]
217+
)
218+
self.assertIn(str(request_error), mock_logger.error.call_args[0][0])
219+
220+
@patch('a2a.server.tasks.base_push_notification_sender.logger')
221+
async def test_send_notification_with_auth(self, mock_logger: MagicMock):
222+
task_id = 'task_send_auth'
223+
task_data = create_sample_task(task_id=task_id)
224+
auth_info = ('user', 'pass')
225+
config = create_sample_push_config(url='http://notify.me/auth')
226+
config.authentication = MagicMock() # Mocking the structure for auth
227+
config.authentication.schemes = ['basic'] # Assume basic for simplicity
228+
config.authentication.credentials = (
229+
auth_info # This might need to be a specific model
230+
)
231+
# For now, let's assume it's a tuple for basic auth
232+
# The actual PushNotificationAuthenticationInfo is more complex
233+
# For this test, we'll simplify and assume InMemoryPushNotifier
234+
# directly uses tuple for httpx's `auth` param if basic.
235+
# A more accurate test would construct the real auth model.
236+
# Given the current implementation of InMemoryPushNotifier,
237+
# it only supports basic auth via tuple.
238+
239+
await self.config_store.set_info(task_id, config)
240+
241+
mock_response = AsyncMock(spec=httpx.Response)
242+
mock_response.status_code = 200
243+
self.mock_httpx_client.post.return_value = mock_response
244+
245+
await self.notifier.send_notification(task_data) # Pass only task_data
246+
247+
self.mock_httpx_client.post.assert_awaited_once()
248+
called_args, called_kwargs = self.mock_httpx_client.post.call_args
249+
self.assertEqual(called_args[0], config.url)
250+
self.assertEqual(
251+
called_kwargs['json'],
252+
task_data.model_dump(mode='json', exclude_none=True),
253+
)
254+
self.assertNotIn(
255+
'auth', called_kwargs
256+
) # auth is not passed by current implementation
257+
mock_response.raise_for_status.assert_called_once()
258+
259+
260+
if __name__ == '__main__':
261+
unittest.main()

0 commit comments

Comments
 (0)