@@ -401,6 +401,90 @@ async def get_current_result():
401401 mock_agent_executor .execute .assert_awaited_once ()
402402
403403
404+ @pytest .mark .asyncio
405+ async def test_on_message_send_with_push_notification_no_existing_Task ():
406+ """Test on_message_send for new task sets push notification info if provided."""
407+ mock_task_store = AsyncMock (spec = TaskStore )
408+ mock_push_notification_store = AsyncMock (spec = PushNotificationConfigStore )
409+ mock_agent_executor = AsyncMock (spec = AgentExecutor )
410+ mock_request_context_builder = AsyncMock (spec = RequestContextBuilder )
411+
412+ task_id = 'push_task_1'
413+ context_id = 'push_ctx_1'
414+
415+ mock_task_store .get .return_value = (
416+ None # Simulate new task scenario for TaskManager
417+ )
418+
419+ # Mock _request_context_builder.build to return a context with the generated/confirmed IDs
420+ mock_request_context = MagicMock (spec = RequestContext )
421+ mock_request_context .task_id = task_id
422+ mock_request_context .context_id = context_id
423+ mock_request_context_builder .build .return_value = mock_request_context
424+
425+ request_handler = DefaultRequestHandler (
426+ agent_executor = mock_agent_executor ,
427+ task_store = mock_task_store ,
428+ push_config_store = mock_push_notification_store ,
429+ request_context_builder = mock_request_context_builder ,
430+ )
431+
432+ push_config = PushNotificationConfig (url = 'http://callback.com/push' )
433+ message_config = MessageSendConfiguration (
434+ pushNotificationConfig = push_config ,
435+ acceptedOutputModes = ['text/plain' ], # Added required field
436+ )
437+ params = MessageSendParams (
438+ message = Message (
439+ role = Role .user ,
440+ messageId = 'msg_push' ,
441+ parts = [],
442+ taskId = task_id ,
443+ contextId = context_id ,
444+ ),
445+ configuration = message_config ,
446+ )
447+
448+ # Mock ResultAggregator and its consume_and_break_on_interrupt
449+ mock_result_aggregator_instance = AsyncMock (spec = ResultAggregator )
450+ final_task_result = create_sample_task (
451+ task_id = task_id , context_id = context_id , status_state = TaskState .completed
452+ )
453+ mock_result_aggregator_instance .consume_and_break_on_interrupt .return_value = (
454+ final_task_result ,
455+ False ,
456+ )
457+
458+ # Mock the current_result property to return the final task result
459+ async def get_current_result ():
460+ return final_task_result
461+
462+ # Configure the 'current_result' property on the type of the mock instance
463+ type(mock_result_aggregator_instance ).current_result = PropertyMock (
464+ return_value = get_current_result ()
465+ )
466+
467+ with (
468+ patch (
469+ 'a2a.server.request_handlers.default_request_handler.ResultAggregator' ,
470+ return_value = mock_result_aggregator_instance ,
471+ ),
472+ patch (
473+ 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task' ,
474+ return_value = None ,
475+ ),
476+ ):
477+ await request_handler .on_message_send (
478+ params , create_server_call_context ()
479+ )
480+
481+ mock_push_notification_store .set_info .assert_awaited_once_with (
482+ task_id , push_config
483+ )
484+ # Other assertions for full flow if needed (e.g., agent execution)
485+ mock_agent_executor .execute .assert_awaited_once ()
486+
487+
404488@pytest .mark .asyncio
405489async def test_on_message_send_no_result_from_aggregator ():
406490 """Test on_message_send when aggregator returns (None, False)."""
0 commit comments