Skip to content

Commit 574f31c

Browse files
committed
fix: Send push notifications for message/send
1 parent 75aa4ed commit 574f31c

File tree

2 files changed

+96
-22
lines changed

2 files changed

+96
-22
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -202,18 +202,6 @@ async def _setup_message_execution(
202202
)
203203

204204
task = task_manager.update_with_message(params.message, task)
205-
if self.should_add_push_info(params):
206-
assert self._push_config_store is not None
207-
assert isinstance(
208-
params.configuration, MessageSendConfiguration
209-
)
210-
assert isinstance(
211-
params.configuration.pushNotificationConfig,
212-
PushNotificationConfig,
213-
)
214-
await self._push_config_store.set_info(
215-
task.id, params.configuration.pushNotificationConfig
216-
)
217205

218206
# Build request context
219207
request_context = await self._request_context_builder.build(
@@ -228,6 +216,18 @@ async def _setup_message_execution(
228216
# Always assign a task ID. We may not actually upgrade to a task, but
229217
# dictating the task ID at this layer is useful for tracking running
230218
# agents.
219+
220+
if self.should_add_push_info(params):
221+
assert self._push_config_store is not None
222+
assert isinstance(params.configuration, MessageSendConfiguration)
223+
assert isinstance(
224+
params.configuration.pushNotificationConfig,
225+
PushNotificationConfig,
226+
)
227+
await self._push_config_store.set_info(
228+
task_id, params.configuration.pushNotificationConfig
229+
)
230+
231231
queue = await self._queue_manager.create_or_tap(task_id)
232232
result_aggregator = ResultAggregator(task_manager)
233233
# TODO: to manage the non-blocking flows.
@@ -333,16 +333,6 @@ async def on_message_send_stream(
333333
if isinstance(event, Task):
334334
self._validate_task_id_match(task_id, event.id)
335335

336-
if (
337-
self._push_config_store
338-
and params.configuration
339-
and params.configuration.pushNotificationConfig
340-
):
341-
await self._push_config_store.set_info(
342-
task_id,
343-
params.configuration.pushNotificationConfig,
344-
)
345-
346336
await self._send_push_notification_if_needed(
347337
task_id, result_aggregator
348338
)

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,90 @@ async def get_current_result():
401401
mock_agent_executor.execute.assert_awaited_once()
402402

403403

404+
@pytest.mark.asyncio
405+
async def test_on_message_send_with_push_notification_no_existing_Task():
406+
"""Test on_message_send for new task sets push notification info if provided."""
407+
mock_task_store = AsyncMock(spec=TaskStore)
408+
mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore)
409+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
410+
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
411+
412+
task_id = 'push_task_1'
413+
context_id = 'push_ctx_1'
414+
415+
mock_task_store.get.return_value = (
416+
None # Simulate new task scenario for TaskManager
417+
)
418+
419+
# Mock _request_context_builder.build to return a context with the generated/confirmed IDs
420+
mock_request_context = MagicMock(spec=RequestContext)
421+
mock_request_context.task_id = task_id
422+
mock_request_context.context_id = context_id
423+
mock_request_context_builder.build.return_value = mock_request_context
424+
425+
request_handler = DefaultRequestHandler(
426+
agent_executor=mock_agent_executor,
427+
task_store=mock_task_store,
428+
push_config_store=mock_push_notification_store,
429+
request_context_builder=mock_request_context_builder,
430+
)
431+
432+
push_config = PushNotificationConfig(url='http://callback.com/push')
433+
message_config = MessageSendConfiguration(
434+
pushNotificationConfig=push_config,
435+
acceptedOutputModes=['text/plain'], # Added required field
436+
)
437+
params = MessageSendParams(
438+
message=Message(
439+
role=Role.user,
440+
messageId='msg_push',
441+
parts=[],
442+
taskId=task_id,
443+
contextId=context_id,
444+
),
445+
configuration=message_config,
446+
)
447+
448+
# Mock ResultAggregator and its consume_and_break_on_interrupt
449+
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
450+
final_task_result = create_sample_task(
451+
task_id=task_id, context_id=context_id, status_state=TaskState.completed
452+
)
453+
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
454+
final_task_result,
455+
False,
456+
)
457+
458+
# Mock the current_result property to return the final task result
459+
async def get_current_result():
460+
return final_task_result
461+
462+
# Configure the 'current_result' property on the type of the mock instance
463+
type(mock_result_aggregator_instance).current_result = PropertyMock(
464+
return_value=get_current_result()
465+
)
466+
467+
with (
468+
patch(
469+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
470+
return_value=mock_result_aggregator_instance,
471+
),
472+
patch(
473+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
474+
return_value=None,
475+
),
476+
):
477+
await request_handler.on_message_send(
478+
params, create_server_call_context()
479+
)
480+
481+
mock_push_notification_store.set_info.assert_awaited_once_with(
482+
task_id, push_config
483+
)
484+
# Other assertions for full flow if needed (e.g., agent execution)
485+
mock_agent_executor.execute.assert_awaited_once()
486+
487+
404488
@pytest.mark.asyncio
405489
async def test_on_message_send_no_result_from_aggregator():
406490
"""Test on_message_send when aggregator returns (None, False)."""

0 commit comments

Comments
 (0)