Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 10 additions & 32 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@
InvalidParamsError,
ListTaskPushNotificationConfigParams,
Message,
MessageSendConfiguration,
MessageSendParams,
PushNotificationConfig,
Task,
TaskIdParams,
TaskNotFoundError,
Expand Down Expand Up @@ -202,18 +200,6 @@ async def _setup_message_execution(
)

task = task_manager.update_with_message(params.message, task)
if self.should_add_push_info(params):
assert self._push_config_store is not None
assert isinstance(
params.configuration, MessageSendConfiguration
)
assert isinstance(
params.configuration.pushNotificationConfig,
PushNotificationConfig,
)
await self._push_config_store.set_info(
task.id, params.configuration.pushNotificationConfig
)

# Build request context
request_context = await self._request_context_builder.build(
Expand All @@ -228,6 +214,16 @@ async def _setup_message_execution(
# Always assign a task ID. We may not actually upgrade to a task, but
# dictating the task ID at this layer is useful for tracking running
# agents.

if (
self._push_config_store
and params.configuration
and params.configuration.pushNotificationConfig
):
await self._push_config_store.set_info(
task_id, params.configuration.pushNotificationConfig
)

queue = await self._queue_manager.create_or_tap(task_id)
result_aggregator = ResultAggregator(task_manager)
# TODO: to manage the non-blocking flows.
Expand Down Expand Up @@ -333,16 +329,6 @@ async def on_message_send_stream(
if isinstance(event, Task):
self._validate_task_id_match(task_id, event.id)

if (
self._push_config_store
and params.configuration
and params.configuration.pushNotificationConfig
):
await self._push_config_store.set_info(
task_id,
params.configuration.pushNotificationConfig,
)

await self._send_push_notification_if_needed(
task_id, result_aggregator
)
Expand Down Expand Up @@ -509,11 +495,3 @@ async def on_delete_task_push_notification_config(
await self._push_config_store.delete_info(
params.id, params.pushNotificationConfigId
)

def should_add_push_info(self, params: MessageSendParams) -> bool:
"""Determines if push notification info should be set for a task."""
return bool(
self._push_config_store
and params.configuration
and params.configuration.pushNotificationConfig
)
84 changes: 84 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,90 @@ async def get_current_result():
mock_agent_executor.execute.assert_awaited_once()


@pytest.mark.asyncio
async def test_on_message_send_with_push_notification_no_existing_Task():
"""Test on_message_send for new task sets push notification info if provided."""
mock_task_store = AsyncMock(spec=TaskStore)
mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore)
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)

task_id = 'push_task_1'
context_id = 'push_ctx_1'

mock_task_store.get.return_value = (
None # Simulate new task scenario for TaskManager
)

# Mock _request_context_builder.build to return a context with the generated/confirmed IDs
mock_request_context = MagicMock(spec=RequestContext)
mock_request_context.task_id = task_id
mock_request_context.context_id = context_id
mock_request_context_builder.build.return_value = mock_request_context

request_handler = DefaultRequestHandler(
agent_executor=mock_agent_executor,
task_store=mock_task_store,
push_config_store=mock_push_notification_store,
request_context_builder=mock_request_context_builder,
)

push_config = PushNotificationConfig(url='http://callback.com/push')
message_config = MessageSendConfiguration(
pushNotificationConfig=push_config,
acceptedOutputModes=['text/plain'], # Added required field
)
params = MessageSendParams(
message=Message(
role=Role.user,
messageId='msg_push',
parts=[],
taskId=task_id,
contextId=context_id,
),
configuration=message_config,
)

# Mock ResultAggregator and its consume_and_break_on_interrupt
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)
final_task_result = create_sample_task(
task_id=task_id, context_id=context_id, status_state=TaskState.completed
)
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
final_task_result,
False,
)

# Mock the current_result property to return the final task result
async def get_current_result():
return final_task_result

# Configure the 'current_result' property on the type of the mock instance
type(mock_result_aggregator_instance).current_result = PropertyMock(
return_value=get_current_result()
)

with (
patch(
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
return_value=mock_result_aggregator_instance,
),
patch(
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
return_value=None,
),
):
await request_handler.on_message_send(
params, create_server_call_context()
)

mock_push_notification_store.set_info.assert_awaited_once_with(
task_id, push_config
)
# Other assertions for full flow if needed (e.g., agent execution)
mock_agent_executor.execute.assert_awaited_once()


@pytest.mark.asyncio
async def test_on_message_send_no_result_from_aggregator():
"""Test on_message_send when aggregator returns (None, False)."""
Expand Down
Loading