diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 4ac5fb0e..cd9f8f97 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -1,8 +1,11 @@ -import uuid - from typing import Any from a2a.server.context import ServerCallContext +from a2a.server.id_generator import ( + IDGenerator, + IDGeneratorContext, + UUIDGenerator, +) from a2a.types import ( InvalidParamsError, Message, @@ -30,6 +33,8 @@ def __init__( # noqa: PLR0913 task: Task | None = None, related_tasks: list[Task] | None = None, call_context: ServerCallContext | None = None, + task_id_generator: IDGenerator | None = None, + context_id_generator: IDGenerator | None = None, ): """Initializes the RequestContext. @@ -40,6 +45,8 @@ def __init__( # noqa: PLR0913 task: The existing `Task` object retrieved from the store, if any. related_tasks: A list of other tasks related to the current request (e.g., for tool use). call_context: The server call context associated with this request. + task_id_generator: ID generator for new task IDs. Defaults to UUID generator. + context_id_generator: ID generator for new context IDs. Defaults to UUID generator. """ if related_tasks is None: related_tasks = [] @@ -49,6 +56,12 @@ def __init__( # noqa: PLR0913 self._current_task = task self._related_tasks = related_tasks self._call_context = call_context + self._task_id_generator = ( + task_id_generator if task_id_generator else UUIDGenerator() + ) + self._context_id_generator = ( + context_id_generator if context_id_generator else UUIDGenerator() + ) # If the task id and context id were provided, make sure they # match the request. Otherwise, create them if self._params: @@ -163,7 +176,9 @@ def _check_or_generate_task_id(self) -> None: return if not self._task_id and not self._params.message.task_id: - self._params.message.task_id = str(uuid.uuid4()) + self._params.message.task_id = self._task_id_generator.generate( + IDGeneratorContext(context_id=self._context_id) + ) if self._params.message.task_id: self._task_id = self._params.message.task_id @@ -173,6 +188,10 @@ def _check_or_generate_context_id(self) -> None: return if not self._context_id and not self._params.message.context_id: - self._params.message.context_id = str(uuid.uuid4()) + self._params.message.context_id = ( + self._context_id_generator.generate( + IDGeneratorContext(task_id=self._task_id) + ) + ) if self._params.message.context_id: self._context_id = self._params.message.context_id diff --git a/src/a2a/server/id_generator.py b/src/a2a/server/id_generator.py new file mode 100644 index 00000000..c523adc9 --- /dev/null +++ b/src/a2a/server/id_generator.py @@ -0,0 +1,28 @@ +import uuid + +from abc import ABC, abstractmethod + +from pydantic import BaseModel + + +class IDGeneratorContext(BaseModel): + """Context for providing additional information to ID generators.""" + + task_id: str | None = None + context_id: str | None = None + + +class IDGenerator(ABC): + """Interface for generating unique identifiers.""" + + @abstractmethod + def generate(self, context: IDGeneratorContext) -> str: + pass + + +class UUIDGenerator(IDGenerator): + """UUID implementation of the IDGenerator interface.""" + + def generate(self, context: IDGeneratorContext) -> str: + """Generates a random UUID, ignoring the context.""" + return str(uuid.uuid4()) diff --git a/src/a2a/server/tasks/task_updater.py b/src/a2a/server/tasks/task_updater.py index 4afc8b35..b61ab700 100644 --- a/src/a2a/server/tasks/task_updater.py +++ b/src/a2a/server/tasks/task_updater.py @@ -1,10 +1,14 @@ import asyncio -import uuid from datetime import datetime, timezone from typing import Any from a2a.server.events import EventQueue +from a2a.server.id_generator import ( + IDGenerator, + IDGeneratorContext, + UUIDGenerator, +) from a2a.types import ( Artifact, Message, @@ -23,13 +27,22 @@ class TaskUpdater: Simplifies the process of creating and enqueueing standard task events. """ - def __init__(self, event_queue: EventQueue, task_id: str, context_id: str): + def __init__( + self, + event_queue: EventQueue, + task_id: str, + context_id: str, + artifact_id_generator: IDGenerator | None = None, + message_id_generator: IDGenerator | None = None, + ): """Initializes the TaskUpdater. Args: event_queue: The `EventQueue` associated with the task. task_id: The ID of the task. context_id: The context ID of the task. + artifact_id_generator: ID generator for new artifact IDs. Defaults to UUID generator. + message_id_generator: ID generator for new message IDs. Defaults to UUID generator. """ self.event_queue = event_queue self.task_id = task_id @@ -42,6 +55,12 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str): TaskState.failed, TaskState.rejected, } + self._artifact_id_generator = ( + artifact_id_generator if artifact_id_generator else UUIDGenerator() + ) + self._message_id_generator = ( + message_id_generator if message_id_generator else UUIDGenerator() + ) async def update_status( self, @@ -110,7 +129,11 @@ async def add_artifact( # noqa: PLR0913 extensions: Optional list of extensions for the artifact. """ if not artifact_id: - artifact_id = str(uuid.uuid4()) + artifact_id = self._artifact_id_generator.generate( + IDGeneratorContext( + task_id=self.task_id, context_id=self.context_id + ) + ) await self.event_queue.enqueue_event( TaskArtifactUpdateEvent( @@ -205,7 +228,11 @@ def new_agent_message( role=Role.agent, task_id=self.task_id, context_id=self.context_id, - message_id=str(uuid.uuid4()), + message_id=self._message_id_generator.generate( + IDGeneratorContext( + task_id=self.task_id, context_id=self.context_id + ) + ), metadata=metadata, parts=parts, ) diff --git a/tests/server/agent_execution/test_context.py b/tests/server/agent_execution/test_context.py index 5cecd892..684aecb2 100644 --- a/tests/server/agent_execution/test_context.py +++ b/tests/server/agent_execution/test_context.py @@ -6,6 +6,7 @@ from a2a.server.agent_execution import RequestContext from a2a.server.context import ServerCallContext +from a2a.server.id_generator import IDGenerator from a2a.types import ( Message, MessageSendParams, @@ -149,6 +150,20 @@ def test_check_or_generate_task_id_with_existing_task_id(self, mock_params): assert context.task_id == existing_id assert mock_params.message.task_id == existing_id + def test_check_or_generate_task_id_with_custom_id_generator( + self, mock_params + ): + """Test _check_or_generate_task_id uses custom ID generator when provided.""" + id_generator = Mock(spec=IDGenerator) + id_generator.generate.return_value = 'custom-task-id' + + context = RequestContext( + request=mock_params, task_id_generator=id_generator + ) + # The method is called during initialization + + assert context.task_id == 'custom-task-id' + def test_check_or_generate_context_id_no_params(self): """Test _check_or_generate_context_id with no params does nothing.""" context = RequestContext() @@ -168,6 +183,20 @@ def test_check_or_generate_context_id_with_existing_context_id( assert context.context_id == existing_id assert mock_params.message.context_id == existing_id + def test_check_or_generate_context_id_with_custom_id_generator( + self, mock_params + ): + """Test _check_or_generate_context_id uses custom ID generator when provided.""" + id_generator = Mock(spec=IDGenerator) + id_generator.generate.return_value = 'custom-context-id' + + context = RequestContext( + request=mock_params, context_id_generator=id_generator + ) + # The method is called during initialization + + assert context.context_id == 'custom-context-id' + def test_init_raises_error_on_task_id_mismatch( self, mock_params, mock_task ): diff --git a/tests/server/tasks/test_task_updater.py b/tests/server/tasks/test_task_updater.py index 844470cb..a8de65e3 100644 --- a/tests/server/tasks/test_task_updater.py +++ b/tests/server/tasks/test_task_updater.py @@ -1,11 +1,12 @@ import asyncio import uuid -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from a2a.server.events import EventQueue +from a2a.server.id_generator import IDGenerator from a2a.server.tasks import TaskUpdater from a2a.types import ( Message, @@ -151,6 +152,26 @@ async def test_add_artifact_generates_id( assert event.last_chunk is None +@pytest.mark.asyncio +async def test_add_artifact_generates_custom_id(event_queue, sample_parts): + """Test add_artifact uses a custom ID generator when provided.""" + artifact_id_generator = Mock(spec=IDGenerator) + artifact_id_generator.generate.return_value = 'custom-artifact-id' + task_updater = TaskUpdater( + event_queue=event_queue, + task_id='test-task-id', + context_id='test-context-id', + artifact_id_generator=artifact_id_generator, + ) + + await task_updater.add_artifact(parts=sample_parts, artifact_id=None) + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + assert isinstance(event, TaskArtifactUpdateEvent) + assert event.artifact.artifact_id == 'custom-artifact-id' + + @pytest.mark.asyncio @pytest.mark.parametrize( 'append_val, last_chunk_val', @@ -304,6 +325,22 @@ def test_new_agent_message_with_metadata(task_updater, sample_parts): assert message.metadata == metadata +def test_new_agent_message_with_custom_id_generator(event_queue, sample_parts): + """Test creating a new agent message with a custom message ID generator.""" + message_id_generator = Mock(spec=IDGenerator) + message_id_generator.generate.return_value = 'custom-message-id' + task_updater = TaskUpdater( + event_queue=event_queue, + task_id='test-task-id', + context_id='test-context-id', + message_id_generator=message_id_generator, + ) + + message = task_updater.new_agent_message(parts=sample_parts) + + assert message.message_id == 'custom-message-id' + + @pytest.mark.asyncio async def test_failed_without_message(task_updater, event_queue): """Test marking a task as failed without a message."""