Skip to content

Commit b0909a3

Browse files
committed
test: add a unit test for push_notification at non-blocking request
- location: tests/server/request_handlers/test_default_request_handler.py:test_on_message_send_with_push_notification_in_non_blocking_request
1 parent 6a91e23 commit b0909a3

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,126 @@ async def get_current_result():
405405
mock_agent_executor.execute.assert_awaited_once()
406406

407407

408+
@pytest.mark.asyncio
409+
async def test_on_message_send_with_push_notification_in_non_blocking_request():
410+
"""Test that push notification callback is called during background event processing for non-blocking requests."""
411+
mock_task_store = AsyncMock(spec=TaskStore)
412+
mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore)
413+
mock_agent_executor = AsyncMock(spec=AgentExecutor)
414+
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
415+
mock_push_sender = AsyncMock()
416+
417+
task_id = 'non_blocking_task_1'
418+
context_id = 'non_blocking_ctx_1'
419+
420+
# Create a task that will be returned after the first event
421+
initial_task = create_sample_task(
422+
task_id=task_id, context_id=context_id, status_state=TaskState.working
423+
)
424+
425+
# Create a final task that will be available during background processing
426+
final_task = create_sample_task(
427+
task_id=task_id, context_id=context_id, status_state=TaskState.completed
428+
)
429+
430+
mock_task_store.get.return_value = None
431+
432+
# Mock request context
433+
mock_request_context = MagicMock(spec=RequestContext)
434+
mock_request_context.task_id = task_id
435+
mock_request_context.context_id = context_id
436+
mock_request_context_builder.build.return_value = mock_request_context
437+
438+
request_handler = DefaultRequestHandler(
439+
agent_executor=mock_agent_executor,
440+
task_store=mock_task_store,
441+
push_config_store=mock_push_notification_store,
442+
request_context_builder=mock_request_context_builder,
443+
push_sender=mock_push_sender,
444+
)
445+
446+
# Configure push notification
447+
push_config = PushNotificationConfig(url='http://callback.com/push')
448+
message_config = MessageSendConfiguration(
449+
push_notification_config=push_config,
450+
accepted_output_modes=['text/plain'],
451+
blocking=False, # Non-blocking request
452+
)
453+
params = MessageSendParams(
454+
message=Message(
455+
role=Role.user,
456+
message_id='msg_non_blocking',
457+
parts=[],
458+
task_id=task_id,
459+
context_id=context_id,
460+
),
461+
configuration=message_config,
462+
)
463+
464+
# Mock ResultAggregator with custom behavior
465+
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
466+
467+
# First call returns the initial task and indicates interruption (non-blocking)
468+
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
469+
initial_task,
470+
True, # interrupted = True for non-blocking
471+
)
472+
473+
# Mock the current_result property to return the final task
474+
async def get_current_result():
475+
return final_task
476+
477+
type(mock_result_aggregator_instance).current_result = PropertyMock(
478+
return_value=get_current_result()
479+
)
480+
481+
# Track if the event_callback was passed to consume_and_break_on_interrupt
482+
event_callback_passed = False
483+
event_callback_received = None
484+
485+
async def mock_consume_and_break_on_interrupt(consumer, blocking=True, event_callback=None):
486+
nonlocal event_callback_passed, event_callback_received
487+
event_callback_passed = event_callback is not None
488+
event_callback_received = event_callback
489+
return initial_task, True # interrupted = True for non-blocking
490+
491+
mock_result_aggregator_instance.consume_and_break_on_interrupt = mock_consume_and_break_on_interrupt
492+
493+
with (
494+
patch(
495+
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
496+
return_value=mock_result_aggregator_instance,
497+
),
498+
patch(
499+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
500+
return_value=initial_task,
501+
),
502+
patch(
503+
'a2a.server.request_handlers.default_request_handler.TaskManager.update_with_message',
504+
return_value=initial_task,
505+
),
506+
):
507+
# Execute the non-blocking request
508+
result = await request_handler.on_message_send(
509+
params, create_server_call_context()
510+
)
511+
512+
# Verify the result is the initial task (non-blocking behavior)
513+
assert result == initial_task
514+
515+
# Verify that the event_callback was passed to consume_and_break_on_interrupt
516+
assert event_callback_passed, "event_callback should have been passed to consume_and_break_on_interrupt"
517+
assert event_callback_received is not None, "event_callback should not be None"
518+
519+
# Verify that the push notification was sent with the final task
520+
mock_push_sender.send_notification.assert_called_with(final_task)
521+
522+
# Verify that the push notification config was stored
523+
mock_push_notification_store.set_info.assert_awaited_once_with(
524+
task_id, push_config
525+
)
526+
527+
408528
@pytest.mark.asyncio
409529
async def test_on_message_send_with_push_notification_no_existing_Task():
410530
"""Test on_message_send for new task sets push notification info if provided."""

0 commit comments

Comments
 (0)