diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 6a38933f..724fe61e 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -109,7 +109,7 @@ async def on_get_task( context: ServerCallContext | None = None, ) -> Task | None: """Default handler for 'tasks/get'.""" - task: Task | None = await self.task_store.get(params.id) + task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) @@ -141,7 +141,7 @@ async def on_cancel_task( Attempts to cancel the task managed by the `AgentExecutor`. """ - task: Task | None = await self.task_store.get(params.id) + task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) @@ -158,6 +158,7 @@ async def on_cancel_task( context_id=task.context_id, task_store=self.task_store, initial_message=None, + context=context, ) result_aggregator = ResultAggregator(task_manager) @@ -224,6 +225,7 @@ async def _setup_message_execution( context_id=params.message.context_id, task_store=self.task_store, initial_message=params.message, + context=context, ) task: Task | None = await task_manager.get_task() @@ -424,7 +426,7 @@ async def on_set_task_push_notification_config( if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.task_id) + task: Task | None = await self.task_store.get(params.task_id, context) if not task: raise ServerError(error=TaskNotFoundError()) @@ -447,7 +449,7 @@ async def on_get_task_push_notification_config( if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.id) + task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) @@ -476,7 +478,7 @@ async def on_resubscribe_to_task( Allows a client to re-attach to a running streaming task's event stream. Requires the task and its queue to still be active. """ - task: Task | None = await self.task_store.get(params.id) + task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) @@ -492,6 +494,7 @@ async def on_resubscribe_to_task( context_id=task.context_id, task_store=self.task_store, initial_message=None, + context=context, ) result_aggregator = ResultAggregator(task_manager) @@ -516,7 +519,7 @@ async def on_list_task_push_notification_config( if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.id) + task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) @@ -543,7 +546,7 @@ async def on_delete_task_push_notification_config( if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.id) + task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index b46d7193..07ba7e97 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -19,6 +19,7 @@ "or 'pip install a2a-sdk[sql]'" ) from e +from a2a.server.context import ServerCallContext from a2a.server.models import Base, TaskModel, create_task_model from a2a.server.tasks.task_store import TaskStore from a2a.types import Task # Task is the Pydantic model @@ -119,7 +120,9 @@ def _from_orm(self, task_model: TaskModel) -> Task: # Pydantic's model_validate will parse the nested dicts/lists from JSON return Task.model_validate(task_data_from_db) - async def save(self, task: Task) -> None: + async def save( + self, task: Task, context: ServerCallContext | None = None + ) -> None: """Saves or updates a task in the database.""" await self._ensure_initialized() db_task = self._to_orm(task) @@ -127,7 +130,9 @@ async def save(self, task: Task) -> None: await session.merge(db_task) logger.debug('Task %s saved/updated successfully.', task.id) - async def get(self, task_id: str) -> Task | None: + async def get( + self, task_id: str, context: ServerCallContext | None = None + ) -> Task | None: """Retrieves a task from the database by ID.""" await self._ensure_initialized() async with self.async_session_maker() as session: @@ -142,7 +147,9 @@ async def get(self, task_id: str) -> Task | None: logger.debug('Task %s not found in store.', task_id) return None - async def delete(self, task_id: str) -> None: + async def delete( + self, task_id: str, context: ServerCallContext | None = None + ) -> None: """Deletes a task from the database by ID.""" await self._ensure_initialized() diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index 26c09823..4e192af0 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -1,6 +1,7 @@ import asyncio import logging +from a2a.server.context import ServerCallContext from a2a.server.tasks.task_store import TaskStore from a2a.types import Task @@ -21,13 +22,17 @@ def __init__(self) -> None: self.tasks: dict[str, Task] = {} self.lock = asyncio.Lock() - async def save(self, task: Task) -> None: + async def save( + self, task: Task, context: ServerCallContext | None = None + ) -> None: """Saves or updates a task in the in-memory store.""" async with self.lock: self.tasks[task.id] = task logger.debug('Task %s saved successfully.', task.id) - async def get(self, task_id: str) -> Task | None: + async def get( + self, task_id: str, context: ServerCallContext | None = None + ) -> Task | None: """Retrieves a task from the in-memory store by ID.""" async with self.lock: logger.debug('Attempting to get task with id: %s', task_id) @@ -38,7 +43,9 @@ async def get(self, task_id: str) -> Task | None: logger.debug('Task %s not found in store.', task_id) return task - async def delete(self, task_id: str) -> None: + async def delete( + self, task_id: str, context: ServerCallContext | None = None + ) -> None: """Deletes a task from the in-memory store by ID.""" async with self.lock: logger.debug('Attempting to delete task with id: %s', task_id) diff --git a/src/a2a/server/tasks/task_manager.py b/src/a2a/server/tasks/task_manager.py index 334d9992..5c363703 100644 --- a/src/a2a/server/tasks/task_manager.py +++ b/src/a2a/server/tasks/task_manager.py @@ -1,5 +1,6 @@ import logging +from a2a.server.context import ServerCallContext from a2a.server.events.event_queue import Event from a2a.server.tasks.task_store import TaskStore from a2a.types import ( @@ -31,6 +32,7 @@ def __init__( context_id: str | None, task_store: TaskStore, initial_message: Message | None, + context: ServerCallContext | None = None, ): """Initializes the TaskManager. @@ -40,6 +42,7 @@ def __init__( task_store: The `TaskStore` instance for persistence. initial_message: The `Message` that initiated the task, if any. Used when creating a new task object. + context: The `ServerCallContext` that this task is produced under. """ if task_id is not None and not (isinstance(task_id, str) and task_id): raise ValueError('Task ID must be a non-empty string') @@ -49,6 +52,7 @@ def __init__( self.task_store = task_store self._initial_message = initial_message self._current_task: Task | None = None + self._call_context: ServerCallContext | None = context logger.debug( 'TaskManager initialized with task_id: %s, context_id: %s', task_id, @@ -74,7 +78,9 @@ async def get_task(self) -> Task | None: logger.debug( 'Attempting to get task from store with id: %s', self.task_id ) - self._current_task = await self.task_store.get(self.task_id) + self._current_task = await self.task_store.get( + self.task_id, self._call_context + ) if self._current_task: logger.debug('Task %s retrieved successfully.', self.task_id) else: @@ -167,7 +173,7 @@ async def ensure_task( logger.debug( 'Attempting to retrieve existing task with id: %s', self.task_id ) - task = await self.task_store.get(self.task_id) + task = await self.task_store.get(self.task_id, self._call_context) if not task: logger.info( @@ -231,7 +237,7 @@ async def _save_task(self, task: Task) -> None: task: The `Task` object to save. """ logger.debug('Saving task with id: %s', task.id) - await self.task_store.save(task) + await self.task_store.save(task, self._call_context) self._current_task = task if not self.task_id: logger.info('New task created with id: %s', task.id) diff --git a/src/a2a/server/tasks/task_store.py b/src/a2a/server/tasks/task_store.py index 1ed974a9..16b36edb 100644 --- a/src/a2a/server/tasks/task_store.py +++ b/src/a2a/server/tasks/task_store.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from a2a.server.context import ServerCallContext from a2a.types import Task @@ -10,13 +11,19 @@ class TaskStore(ABC): """ @abstractmethod - async def save(self, task: Task) -> None: + async def save( + self, task: Task, context: ServerCallContext | None = None + ) -> None: """Saves or updates a task in the store.""" @abstractmethod - async def get(self, task_id: str) -> Task | None: + async def get( + self, task_id: str, context: ServerCallContext | None = None + ) -> Task | None: """Retrieves a task from the store by ID.""" @abstractmethod - async def delete(self, task_id: str) -> None: + async def delete( + self, task_id: str, context: ServerCallContext | None = None + ) -> None: """Deletes a task from the store by ID.""" diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 88fb7d3e..f1408e36 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -135,11 +135,12 @@ async def test_on_get_task_not_found(): from a2a.utils.errors import ServerError # Local import for ServerError + context = create_server_call_context() with pytest.raises(ServerError) as exc_info: - await request_handler.on_get_task(params, create_server_call_context()) + await request_handler.on_get_task(params, context) assert isinstance(exc_info.value.error, TaskNotFoundError) - mock_task_store.get.assert_awaited_once_with('non_existent_task') + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) @pytest.mark.asyncio @@ -155,13 +156,14 @@ async def test_on_cancel_task_task_not_found(): from a2a.utils.errors import ServerError # Local import + context = create_server_call_context() with pytest.raises(ServerError) as exc_info: - await request_handler.on_cancel_task( - params, create_server_call_context() - ) + await request_handler.on_cancel_task(params, context) assert isinstance(exc_info.value.error, TaskNotFoundError) - mock_task_store.get.assert_awaited_once_with('task_not_found_for_cancel') + mock_task_store.get.assert_awaited_once_with( + 'task_not_found_for_cancel', context + ) @pytest.mark.asyncio @@ -195,16 +197,15 @@ async def test_on_cancel_task_queue_tap_returns_none(): queue_manager=mock_queue_manager, ) + context = create_server_call_context() with patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', return_value=mock_result_aggregator_instance, ): params = TaskIdParams(id='tap_none_task') - result_task = await request_handler.on_cancel_task( - params, create_server_call_context() - ) + result_task = await request_handler.on_cancel_task(params, context) - mock_task_store.get.assert_awaited_once_with('tap_none_task') + mock_task_store.get.assert_awaited_once_with('tap_none_task', context) mock_queue_manager.tap.assert_awaited_once_with('tap_none_task') # agent_executor.cancel should be called with a new EventQueue if tap returned None mock_agent_executor.cancel.assert_awaited_once() @@ -250,14 +251,13 @@ async def test_on_cancel_task_cancels_running_agent(): mock_producer_task = AsyncMock(spec=asyncio.Task) request_handler._running_agents[task_id] = mock_producer_task + context = create_server_call_context() with patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', return_value=mock_result_aggregator_instance, ): params = TaskIdParams(id=task_id) - await request_handler.on_cancel_task( - params, create_server_call_context() - ) + await request_handler.on_cancel_task(params, context) mock_producer_task.cancel.assert_called_once() mock_agent_executor.cancel.assert_awaited_once() @@ -1322,13 +1322,14 @@ async def test_set_task_push_notification_config_task_not_found(): ) from a2a.utils.errors import ServerError # Local import + context = create_server_call_context() with pytest.raises(ServerError) as exc_info: await request_handler.on_set_task_push_notification_config( - params, create_server_call_context() + params, context ) assert isinstance(exc_info.value.error, TaskNotFoundError) - mock_task_store.get.assert_awaited_once_with('non_existent_task') + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) mock_push_store.set_info.assert_not_awaited() @@ -1365,13 +1366,14 @@ async def test_get_task_push_notification_config_task_not_found(): params = GetTaskPushNotificationConfigParams(id='non_existent_task') from a2a.utils.errors import ServerError # Local import + context = create_server_call_context() with pytest.raises(ServerError) as exc_info: await request_handler.on_get_task_push_notification_config( - params, create_server_call_context() + params, context ) assert isinstance(exc_info.value.error, TaskNotFoundError) - mock_task_store.get.assert_awaited_once_with('non_existent_task') + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) mock_push_store.get_info.assert_not_awaited() @@ -1394,15 +1396,16 @@ async def test_get_task_push_notification_config_info_not_found(): params = GetTaskPushNotificationConfigParams(id='non_existent_task') from a2a.utils.errors import ServerError # Local import + context = create_server_call_context() with pytest.raises(ServerError) as exc_info: await request_handler.on_get_task_push_notification_config( - params, create_server_call_context() + params, context ) assert isinstance( exc_info.value.error, InternalError ) # Current code raises InternalError - mock_task_store.get.assert_awaited_once_with('non_existent_task') + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) mock_push_store.get_info.assert_awaited_once_with('non_existent_task') @@ -1425,8 +1428,9 @@ async def test_get_task_push_notification_config_info_with_config(): id='config_id', url='http://1.example.com' ), ) + context = create_server_call_context() await request_handler.on_set_task_push_notification_config( - set_config_params, create_server_call_context() + set_config_params, context ) params = GetTaskPushNotificationConfigParams( @@ -1435,7 +1439,7 @@ async def test_get_task_push_notification_config_info_with_config(): result: TaskPushNotificationConfig = ( await request_handler.on_get_task_push_notification_config( - params, create_server_call_context() + params, context ) ) @@ -1501,15 +1505,16 @@ async def test_on_resubscribe_to_task_task_not_found(): from a2a.utils.errors import ServerError # Local import + context = create_server_call_context() with pytest.raises(ServerError) as exc_info: # Need to consume the async generator to trigger the error - async for _ in request_handler.on_resubscribe_to_task( - params, create_server_call_context() - ): + async for _ in request_handler.on_resubscribe_to_task(params, context): pass assert isinstance(exc_info.value.error, TaskNotFoundError) - mock_task_store.get.assert_awaited_once_with('resub_task_not_found') + mock_task_store.get.assert_awaited_once_with( + 'resub_task_not_found', context + ) @pytest.mark.asyncio @@ -1531,16 +1536,17 @@ async def test_on_resubscribe_to_task_queue_not_found(): from a2a.utils.errors import ServerError # Local import + context = create_server_call_context() with pytest.raises(ServerError) as exc_info: - async for _ in request_handler.on_resubscribe_to_task( - params, create_server_call_context() - ): + async for _ in request_handler.on_resubscribe_to_task(params, context): pass assert isinstance( exc_info.value.error, TaskNotFoundError ) # Should be TaskNotFoundError as per spec - mock_task_store.get.assert_awaited_once_with('resub_queue_not_found') + mock_task_store.get.assert_awaited_once_with( + 'resub_queue_not_found', context + ) mock_queue_manager.tap.assert_awaited_once_with('resub_queue_not_found') @@ -1614,13 +1620,14 @@ async def test_list_task_push_notification_config_task_not_found(): params = ListTaskPushNotificationConfigParams(id='non_existent_task') from a2a.utils.errors import ServerError # Local import + context = create_server_call_context() with pytest.raises(ServerError) as exc_info: await request_handler.on_list_task_push_notification_config( - params, create_server_call_context() + params, context ) assert isinstance(exc_info.value.error, TaskNotFoundError) - mock_task_store.get.assert_awaited_once_with('non_existent_task') + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) mock_push_store.get_info.assert_not_awaited() @@ -1774,13 +1781,14 @@ async def test_delete_task_push_notification_config_task_not_found(): ) from a2a.utils.errors import ServerError # Local import + context = create_server_call_context() with pytest.raises(ServerError) as exc_info: await request_handler.on_delete_task_push_notification_config( - params, create_server_call_context() + params, context ) assert isinstance(exc_info.value.error, TaskNotFoundError) - mock_task_store.get.assert_awaited_once_with('non_existent_task') + mock_task_store.get.assert_awaited_once_with('non_existent_task', context) mock_push_store.get_info.assert_not_awaited() @@ -2025,10 +2033,9 @@ async def test_on_resubscribe_to_task_in_terminal_state(terminal_state): from a2a.utils.errors import ServerError + context = create_server_call_context() with pytest.raises(ServerError) as exc_info: - async for _ in request_handler.on_resubscribe_to_task( - params, create_server_call_context() - ): + async for _ in request_handler.on_resubscribe_to_task(params, context): pass # pragma: no cover assert isinstance(exc_info.value.error, InvalidParamsError) @@ -2037,7 +2044,7 @@ async def test_on_resubscribe_to_task_in_terminal_state(terminal_state): f'Task {task_id} is in terminal state: {terminal_state.value}' in exc_info.value.error.message ) - mock_task_store.get.assert_awaited_once_with(task_id) + mock_task_store.get.assert_awaited_once_with(task_id, context) @pytest.mark.asyncio diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 1d1b3c5d..616cf131 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -1,5 +1,6 @@ import unittest import unittest.async_case + from collections.abc import AsyncGenerator from typing import Any, NoReturn from unittest.mock import AsyncMock, MagicMock, call, patch @@ -26,7 +27,6 @@ AgentCapabilities, AgentCard, Artifact, - AuthenticatedExtendedCardNotConfiguredError, CancelTaskRequest, CancelTaskSuccessResponse, DeleteTaskPushNotificationConfigParams, @@ -74,6 +74,7 @@ ) from a2a.utils.errors import ServerError + MINIMAL_TASK: dict[str, Any] = { 'id': 'task_123', 'contextId': 'session-xyz', @@ -113,7 +114,7 @@ async def test_on_get_task_success(self) -> None: ) self.assertIsInstance(response.root, GetTaskSuccessResponse) assert response.root.result == mock_task # type: ignore - mock_task_store.get.assert_called_once_with(task_id) + mock_task_store.get.assert_called_once_with(task_id, unittest.mock.ANY) async def test_on_get_task_not_found(self) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor) @@ -208,7 +209,9 @@ async def test_on_cancel_task_not_found(self) -> None: response = await handler.on_cancel_task(request) self.assertIsInstance(response.root, JSONRPCErrorResponse) assert response.root.error == TaskNotFoundError() # type: ignore - mock_task_store.get.assert_called_once_with('nonexistent_id') + mock_task_store.get.assert_called_once_with( + 'nonexistent_id', unittest.mock.ANY + ) mock_agent_executor.cancel.assert_not_called() @patch( diff --git a/tests/server/tasks/test_task_manager.py b/tests/server/tasks/test_task_manager.py index 4f243157..8208ca78 100644 --- a/tests/server/tasks/test_task_manager.py +++ b/tests/server/tasks/test_task_manager.py @@ -68,7 +68,7 @@ async def test_get_task_existing( mock_task_store.get.return_value = expected_task retrieved_task = await task_manager.get_task() assert retrieved_task == expected_task - mock_task_store.get.assert_called_once_with(MINIMAL_TASK['id']) + mock_task_store.get.assert_called_once_with(MINIMAL_TASK['id'], None) @pytest.mark.asyncio @@ -79,7 +79,7 @@ async def test_get_task_nonexistent( mock_task_store.get.return_value = None retrieved_task = await task_manager.get_task() assert retrieved_task is None - mock_task_store.get.assert_called_once_with(MINIMAL_TASK['id']) + mock_task_store.get.assert_called_once_with(MINIMAL_TASK['id'], None) @pytest.mark.asyncio @@ -89,7 +89,7 @@ async def test_save_task_event_new_task( """Test saving a new task.""" task = Task(**MINIMAL_TASK) await task_manager.save_task_event(task) - mock_task_store.save.assert_called_once_with(task) + mock_task_store.save.assert_called_once_with(task, None) @pytest.mark.asyncio @@ -116,7 +116,7 @@ async def test_save_task_event_status_update( await task_manager.save_task_event(event) updated_task = initial_task updated_task.status = new_status - mock_task_store.save.assert_called_once_with(updated_task) + mock_task_store.save.assert_called_once_with(updated_task, None) @pytest.mark.asyncio @@ -139,7 +139,7 @@ async def test_save_task_event_artifact_update( await task_manager.save_task_event(event) updated_task = initial_task updated_task.artifacts = [new_artifact] - mock_task_store.save.assert_called_once_with(updated_task) + mock_task_store.save.assert_called_once_with(updated_task, None) @pytest.mark.asyncio @@ -179,7 +179,7 @@ async def test_ensure_task_existing( ) retrieved_task = await task_manager.ensure_task(event) assert retrieved_task == expected_task - mock_task_store.get.assert_called_once_with(MINIMAL_TASK['id']) + mock_task_store.get.assert_called_once_with(MINIMAL_TASK['id'], None) @pytest.mark.asyncio @@ -204,7 +204,7 @@ async def test_ensure_task_nonexistent( assert new_task.id == 'new-task' assert new_task.context_id == 'some-context' assert new_task.status.state == TaskState.submitted - mock_task_store.save.assert_called_once_with(new_task) + mock_task_store.save.assert_called_once_with(new_task, None) assert task_manager_without_id.task_id == 'new-task' assert task_manager_without_id.context_id == 'some-context' @@ -225,7 +225,7 @@ async def test_save_task( """Test saving a task.""" task = Task(**MINIMAL_TASK) await task_manager._save_task(task) # type: ignore - mock_task_store.save.assert_called_once_with(task) + mock_task_store.save.assert_called_once_with(task, None) @pytest.mark.asyncio @@ -264,7 +264,7 @@ async def test_save_task_event_new_task_no_task_id( } task = Task(**task_data) await task_manager_without_id.save_task_event(task) - mock_task_store.save.assert_called_once_with(task) + mock_task_store.save.assert_called_once_with(task, None) assert task_manager_without_id.task_id == 'new-task-id' assert task_manager_without_id.context_id == 'some-context' # initial submit should be updated to working