Skip to content

Commit e91af9c

Browse files
authored
Merge branch 'main' into improve-error-handling
2 parents d4c93af + 2126828 commit e91af9c

File tree

7 files changed

+188
-38
lines changed

7 files changed

+188
-38
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@
3333
InvalidParamsError,
3434
ListTaskPushNotificationConfigParams,
3535
Message,
36-
MessageSendConfiguration,
3736
MessageSendParams,
38-
PushNotificationConfig,
3937
Task,
4038
TaskIdParams,
4139
TaskNotFoundError,
@@ -202,18 +200,6 @@ async def _setup_message_execution(
202200
)
203201

204202
task = task_manager.update_with_message(params.message, task)
205-
if self.should_add_push_info(params):
206-
assert self._push_config_store is not None
207-
assert isinstance(
208-
params.configuration, MessageSendConfiguration
209-
)
210-
assert isinstance(
211-
params.configuration.pushNotificationConfig,
212-
PushNotificationConfig,
213-
)
214-
await self._push_config_store.set_info(
215-
task.id, params.configuration.pushNotificationConfig
216-
)
217203

218204
# Build request context
219205
request_context = await self._request_context_builder.build(
@@ -228,6 +214,16 @@ async def _setup_message_execution(
228214
# Always assign a task ID. We may not actually upgrade to a task, but
229215
# dictating the task ID at this layer is useful for tracking running
230216
# agents.
217+
218+
if (
219+
self._push_config_store
220+
and params.configuration
221+
and params.configuration.pushNotificationConfig
222+
):
223+
await self._push_config_store.set_info(
224+
task_id, params.configuration.pushNotificationConfig
225+
)
226+
231227
queue = await self._queue_manager.create_or_tap(task_id)
232228
result_aggregator = ResultAggregator(task_manager)
233229
# TODO: to manage the non-blocking flows.
@@ -333,16 +329,6 @@ async def on_message_send_stream(
333329
if isinstance(event, Task):
334330
self._validate_task_id_match(task_id, event.id)
335331

336-
if (
337-
self._push_config_store
338-
and params.configuration
339-
and params.configuration.pushNotificationConfig
340-
):
341-
await self._push_config_store.set_info(
342-
task_id,
343-
params.configuration.pushNotificationConfig,
344-
)
345-
346332
await self._send_push_notification_if_needed(
347333
task_id, result_aggregator
348334
)
@@ -509,11 +495,3 @@ async def on_delete_task_push_notification_config(
509495
await self._push_config_store.delete_info(
510496
params.id, params.pushNotificationConfigId
511497
)
512-
513-
def should_add_push_info(self, params: MessageSendParams) -> bool:
514-
"""Determines if push notification info should be set for a task."""
515-
return bool(
516-
self._push_config_store
517-
and params.configuration
518-
and params.configuration.pushNotificationConfig
519-
)

src/a2a/server/tasks/__init__.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Components for managing tasks within the A2A server."""
22

3+
import logging
4+
35
from a2a.server.tasks.base_push_notification_sender import (
46
BasePushNotificationSender,
57
)
6-
from a2a.server.tasks.database_task_store import DatabaseTaskStore
78
from a2a.server.tasks.inmemory_push_notification_config_store import (
89
InMemoryPushNotificationConfigStore,
910
)
@@ -18,6 +19,30 @@
1819
from a2a.server.tasks.task_updater import TaskUpdater
1920

2021

22+
logger = logging.getLogger(__name__)
23+
24+
try:
25+
from a2a.server.tasks.database_task_store import (
26+
DatabaseTaskStore, # type: ignore
27+
)
28+
except ImportError as e:
29+
_original_error = e
30+
# If the database task store is not available, we can still use in-memory stores.
31+
logger.debug(
32+
'DatabaseTaskStore not loaded. This is expected if database dependencies are not installed. Error: %s',
33+
e,
34+
)
35+
36+
class DatabaseTaskStore: # type: ignore
37+
"""Placeholder for DatabaseTaskStore when dependencies are not installed."""
38+
39+
def __init__(self, *args, **kwargs):
40+
raise ImportError(
41+
'To use DatabaseTaskStore, its dependencies must be installed. '
42+
'You can install them with \'pip install "a2a-sdk[sql]"\''
43+
) from _original_error
44+
45+
2146
__all__ = [
2247
'BasePushNotificationSender',
2348
'DatabaseTaskStore',

src/a2a/server/tasks/base_push_notification_sender.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,13 @@ async def _dispatch_notification(
5252
) -> bool:
5353
url = push_info.url
5454
try:
55+
headers = None
56+
if push_info.token:
57+
headers = {'X-A2A-Notification-Token': push_info.token}
5558
response = await self._client.post(
56-
url, json=task.model_dump(mode='json', exclude_none=True)
59+
url,
60+
json=task.model_dump(mode='json', exclude_none=True),
61+
headers=headers
5762
)
5863
response.raise_for_status()
5964
logger.info(

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,90 @@ async def get_current_result():
401401
mock_agent_executor.execute.assert_awaited_once()
402402

403403

404+
@pytest.mark.asyncio
405+
async def test_on_message_send_with_push_notification_no_existing_Task():
406+
"""Test on_message_send for new task sets push notification info if provided."""
407+
mock_task_store = AsyncMock(spec=TaskStore)
408+
mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore)
409+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
410+
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
411+
412+
task_id = 'push_task_1'
413+
context_id = 'push_ctx_1'
414+
415+
mock_task_store.get.return_value = (
416+
None # Simulate new task scenario for TaskManager
417+
)
418+
419+
# Mock _request_context_builder.build to return a context with the generated/confirmed IDs
420+
mock_request_context = MagicMock(spec=RequestContext)
421+
mock_request_context.task_id = task_id
422+
mock_request_context.context_id = context_id
423+
mock_request_context_builder.build.return_value = mock_request_context
424+
425+
request_handler = DefaultRequestHandler(
426+
agent_executor=mock_agent_executor,
427+
task_store=mock_task_store,
428+
push_config_store=mock_push_notification_store,
429+
request_context_builder=mock_request_context_builder,
430+
)
431+
432+
push_config = PushNotificationConfig(url='http://callback.com/push')
433+
message_config = MessageSendConfiguration(
434+
pushNotificationConfig=push_config,
435+
acceptedOutputModes=['text/plain'], # Added required field
436+
)
437+
params = MessageSendParams(
438+
message=Message(
439+
role=Role.user,
440+
messageId='msg_push',
441+
parts=[],
442+
taskId=task_id,
443+
contextId=context_id,
444+
),
445+
configuration=message_config,
446+
)
447+
448+
# Mock ResultAggregator and its consume_and_break_on_interrupt
449+
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
450+
final_task_result = create_sample_task(
451+
task_id=task_id, context_id=context_id, status_state=TaskState.completed
452+
)
453+
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
454+
final_task_result,
455+
False,
456+
)
457+
458+
# Mock the current_result property to return the final task result
459+
async def get_current_result():
460+
return final_task_result
461+
462+
# Configure the 'current_result' property on the type of the mock instance
463+
type(mock_result_aggregator_instance).current_result = PropertyMock(
464+
return_value=get_current_result()
465+
)
466+
467+
with (
468+
patch(
469+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
470+
return_value=mock_result_aggregator_instance,
471+
),
472+
patch(
473+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
474+
return_value=None,
475+
),
476+
):
477+
await request_handler.on_message_send(
478+
params, create_server_call_context()
479+
)
480+
481+
mock_push_notification_store.set_info.assert_awaited_once_with(
482+
task_id, push_config
483+
)
484+
# Other assertions for full flow if needed (e.g., agent execution)
485+
mock_agent_executor.execute.assert_awaited_once()
486+
487+
404488
@pytest.mark.asyncio
405489
async def test_on_message_send_no_result_from_aggregator():
406490
"""Test on_message_send when aggregator returns (None, False)."""

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,7 @@ async def streaming_coro():
585585
'kind': 'task',
586586
'status': {'state': 'submitted'},
587587
},
588+
headers=None
588589
),
589590
call(
590591
'http://example.com',
@@ -605,6 +606,7 @@ async def streaming_coro():
605606
'kind': 'task',
606607
'status': {'state': 'submitted'},
607608
},
609+
headers=None
608610
),
609611
call(
610612
'http://example.com',
@@ -625,6 +627,7 @@ async def streaming_coro():
625627
'kind': 'task',
626628
'status': {'state': 'completed'},
627629
},
630+
headers=None
628631
),
629632
]
630633
mock_httpx_client.post.assert_has_calls(calls)

tests/server/tasks/test_inmemory_push_notifications.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ def create_sample_task(task_id='task123', status_state=TaskState.completed):
2626

2727

2828
def create_sample_push_config(
29-
url='http://example.com/callback', config_id='cfg1'
29+
url='http://example.com/callback', config_id='cfg1', token=None
3030
):
31-
return PushNotificationConfig(id=config_id, url=url)
31+
return PushNotificationConfig(id=config_id, url=url, token=token)
3232

3333

3434
class TestInMemoryPushNotifier(unittest.IsolatedAsyncioTestCase):
@@ -158,6 +158,35 @@ async def test_send_notification_success(self):
158158
) # auth is not passed by current implementation
159159
mock_response.raise_for_status.assert_called_once()
160160

161+
async def test_send_notification_with_token_success(self):
162+
task_id = 'task_send_success'
163+
task_data = create_sample_task(task_id=task_id)
164+
config = create_sample_push_config(url='http://notify.me/here', token='unique_token')
165+
await self.config_store.set_info(task_id, config)
166+
167+
# Mock the post call to simulate success
168+
mock_response = AsyncMock(spec=httpx.Response)
169+
mock_response.status_code = 200
170+
self.mock_httpx_client.post.return_value = mock_response
171+
172+
await self.notifier.send_notification(task_data) # Pass only task_data
173+
174+
self.mock_httpx_client.post.assert_awaited_once()
175+
called_args, called_kwargs = self.mock_httpx_client.post.call_args
176+
self.assertEqual(called_args[0], config.url)
177+
self.assertEqual(
178+
called_kwargs['json'],
179+
task_data.model_dump(mode='json', exclude_none=True),
180+
)
181+
self.assertEqual(
182+
called_kwargs['headers'],
183+
{"X-A2A-Notification-Token": "unique_token"},
184+
)
185+
self.assertNotIn(
186+
'auth', called_kwargs
187+
) # auth is not passed by current implementation
188+
mock_response.raise_for_status.assert_called_once()
189+
161190
async def test_send_notification_no_config(self):
162191
task_id = 'task_send_no_config'
163192
task_data = create_sample_task(task_id=task_id)

tests/server/tasks/test_push_notification_sender.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ def create_sample_task(task_id='task123', status_state=TaskState.completed):
2525

2626

2727
def create_sample_push_config(
28-
url='http://example.com/callback', config_id='cfg1'
28+
url='http://example.com/callback', config_id='cfg1', token=None
2929
):
30-
return PushNotificationConfig(id=config_id, url=url)
30+
return PushNotificationConfig(id=config_id, url=url, token=token)
3131

3232

3333
class TestBasePushNotificationSender(unittest.IsolatedAsyncioTestCase):
@@ -61,6 +61,29 @@ async def test_send_notification_success(self):
6161
self.mock_httpx_client.post.assert_awaited_once_with(
6262
config.url,
6363
json=task_data.model_dump(mode='json', exclude_none=True),
64+
headers=None
65+
)
66+
mock_response.raise_for_status.assert_called_once()
67+
68+
async def test_send_notification_with_token_success(self):
69+
task_id = 'task_send_success'
70+
task_data = create_sample_task(task_id=task_id)
71+
config = create_sample_push_config(url='http://notify.me/here', token='unique_token')
72+
self.mock_config_store.get_info.return_value = [config]
73+
74+
mock_response = AsyncMock(spec=httpx.Response)
75+
mock_response.status_code = 200
76+
self.mock_httpx_client.post.return_value = mock_response
77+
78+
await self.sender.send_notification(task_data)
79+
80+
self.mock_config_store.get_info.assert_awaited_once_with
81+
82+
# assert httpx_client post method got invoked with right parameters
83+
self.mock_httpx_client.post.assert_awaited_once_with(
84+
config.url,
85+
json=task_data.model_dump(mode='json', exclude_none=True),
86+
headers={'X-A2A-Notification-Token': 'unique_token'}
6487
)
6588
mock_response.raise_for_status.assert_called_once()
6689

@@ -97,6 +120,7 @@ async def test_send_notification_http_status_error(
97120
self.mock_httpx_client.post.assert_awaited_once_with(
98121
config.url,
99122
json=task_data.model_dump(mode='json', exclude_none=True),
123+
headers=None
100124
)
101125
mock_logger.error.assert_called_once()
102126

@@ -124,10 +148,12 @@ async def test_send_notification_multiple_configs(self):
124148
self.mock_httpx_client.post.assert_any_call(
125149
config1.url,
126150
json=task_data.model_dump(mode='json', exclude_none=True),
151+
headers=None
127152
)
128153
# Check calls for config2
129154
self.mock_httpx_client.post.assert_any_call(
130155
config2.url,
131156
json=task_data.model_dump(mode='json', exclude_none=True),
157+
headers=None
132158
)
133159
mock_response.raise_for_status.call_count = 2

0 commit comments

Comments
 (0)