Skip to content

Commit 4e42d31

Browse files
author
lkawka
committed
feat: custom ID generators
1 parent 74d31e3 commit 4e42d31

File tree

5 files changed

+149
-9
lines changed

5 files changed

+149
-9
lines changed

src/a2a/server/agent_execution/context.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
import uuid
2-
31
from typing import Any
42

53
from a2a.server.context import ServerCallContext
4+
from a2a.server.id_generator import (
5+
IDGenerator,
6+
IDGeneratorContext,
7+
UUIDGenerator,
8+
)
69
from a2a.types import (
710
InvalidParamsError,
811
Message,
@@ -30,6 +33,8 @@ def __init__( # noqa: PLR0913
3033
task: Task | None = None,
3134
related_tasks: list[Task] | None = None,
3235
call_context: ServerCallContext | None = None,
36+
task_id_generator: IDGenerator | None = None,
37+
context_id_generator: IDGenerator | None = None,
3338
):
3439
"""Initializes the RequestContext.
3540
@@ -40,6 +45,8 @@ def __init__( # noqa: PLR0913
4045
task: The existing `Task` object retrieved from the store, if any.
4146
related_tasks: A list of other tasks related to the current request (e.g., for tool use).
4247
call_context: The server call context associated with this request.
48+
task_id_generator: ID generator for new task IDs. Defaults to UUID generator.
49+
context_id_generator: ID generator for new context IDs. Defaults to UUID generator.
4350
"""
4451
if related_tasks is None:
4552
related_tasks = []
@@ -49,6 +56,12 @@ def __init__( # noqa: PLR0913
4956
self._current_task = task
5057
self._related_tasks = related_tasks
5158
self._call_context = call_context
59+
self._task_id_generator = (
60+
task_id_generator if task_id_generator else UUIDGenerator()
61+
)
62+
self.context_id_generator = (
63+
context_id_generator if context_id_generator else UUIDGenerator()
64+
)
5265
# If the task id and context id were provided, make sure they
5366
# match the request. Otherwise, create them
5467
if self._params:
@@ -163,7 +176,9 @@ def _check_or_generate_task_id(self) -> None:
163176
return
164177

165178
if not self._task_id and not self._params.message.task_id:
166-
self._params.message.task_id = str(uuid.uuid4())
179+
self._params.message.task_id = self._task_id_generator.generate(
180+
IDGeneratorContext(context_id=self._context_id)
181+
)
167182
if self._params.message.task_id:
168183
self._task_id = self._params.message.task_id
169184

@@ -173,6 +188,10 @@ def _check_or_generate_context_id(self) -> None:
173188
return
174189

175190
if not self._context_id and not self._params.message.context_id:
176-
self._params.message.context_id = str(uuid.uuid4())
191+
self._params.message.context_id = (
192+
self.context_id_generator.generate(
193+
IDGeneratorContext(task_id=self._task_id)
194+
)
195+
)
177196
if self._params.message.context_id:
178197
self._context_id = self._params.message.context_id

src/a2a/server/id_generator.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import uuid
2+
3+
from abc import ABC, abstractmethod
4+
5+
from pydantic import BaseModel
6+
7+
8+
class IDGeneratorContext(BaseModel):
9+
"""Context for providing additional information to ID generators."""
10+
11+
task_id: str | None = None
12+
context_id: str | None = None
13+
14+
15+
class IDGenerator(ABC):
16+
"""Interface for generating unique identifiers."""
17+
18+
@abstractmethod
19+
def generate(self, context: IDGeneratorContext) -> str:
20+
pass
21+
22+
23+
class UUIDGenerator(IDGenerator):
24+
"""UUID implementation of the IDGenerator interface."""
25+
26+
def generate(self, context: IDGeneratorContext) -> str:
27+
"""Generates a random UUID."""
28+
return str(uuid.uuid4())

src/a2a/server/tasks/task_updater.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import asyncio
2-
import uuid
32

43
from datetime import datetime, timezone
54
from typing import Any
65

76
from a2a.server.events import EventQueue
7+
from a2a.server.id_generator import (
8+
IDGenerator,
9+
IDGeneratorContext,
10+
UUIDGenerator,
11+
)
812
from a2a.types import (
913
Artifact,
1014
Message,
@@ -23,13 +27,22 @@ class TaskUpdater:
2327
Simplifies the process of creating and enqueueing standard task events.
2428
"""
2529

26-
def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
30+
def __init__(
31+
self,
32+
event_queue: EventQueue,
33+
task_id: str,
34+
context_id: str,
35+
artifact_id_generator: IDGenerator | None = None,
36+
message_id_generator: IDGenerator | None = None,
37+
):
2738
"""Initializes the TaskUpdater.
2839
2940
Args:
3041
event_queue: The `EventQueue` associated with the task.
3142
task_id: The ID of the task.
3243
context_id: The context ID of the task.
44+
artifact_id_generator: ID generator for new artifact IDs. Defaults to UUID generator.
45+
message_id_generator: ID generator for new message IDs. Defaults to UUID generator.
3346
"""
3447
self.event_queue = event_queue
3548
self.task_id = task_id
@@ -42,6 +55,12 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
4255
TaskState.failed,
4356
TaskState.rejected,
4457
}
58+
self._artifact_id_generator = (
59+
artifact_id_generator if artifact_id_generator else UUIDGenerator()
60+
)
61+
self._message_id_generator = (
62+
message_id_generator if message_id_generator else UUIDGenerator()
63+
)
4564

4665
async def update_status(
4766
self,
@@ -110,7 +129,11 @@ async def add_artifact( # noqa: PLR0913
110129
extensions: Optional list of extensions for the artifact.
111130
"""
112131
if not artifact_id:
113-
artifact_id = str(uuid.uuid4())
132+
artifact_id = self._artifact_id_generator.generate(
133+
IDGeneratorContext(
134+
task_id=self.task_id, context_id=self.context_id
135+
)
136+
)
114137

115138
await self.event_queue.enqueue_event(
116139
TaskArtifactUpdateEvent(
@@ -205,7 +228,11 @@ def new_agent_message(
205228
role=Role.agent,
206229
task_id=self.task_id,
207230
context_id=self.context_id,
208-
message_id=str(uuid.uuid4()),
231+
message_id=self._message_id_generator.generate(
232+
IDGeneratorContext(
233+
task_id=self.task_id, context_id=self.context_id
234+
)
235+
),
209236
metadata=metadata,
210237
parts=parts,
211238
)

tests/server/agent_execution/test_context.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from a2a.server.agent_execution import RequestContext
88
from a2a.server.context import ServerCallContext
9+
from a2a.server.id_generator import IDGenerator
910
from a2a.types import (
1011
Message,
1112
MessageSendParams,
@@ -149,6 +150,20 @@ def test_check_or_generate_task_id_with_existing_task_id(self, mock_params):
149150
assert context.task_id == existing_id
150151
assert mock_params.message.task_id == existing_id
151152

153+
def test_check_or_generate_task_id_with_custom_id_generator(
154+
self, mock_params
155+
):
156+
"""Test _check_or_generate_task_id uses custom ID generator when provided."""
157+
id_generator = Mock(spec=IDGenerator)
158+
id_generator.generate.return_value = 'custom-task-id'
159+
160+
context = RequestContext(
161+
request=mock_params, task_id_generator=id_generator
162+
)
163+
# The method is called during initialization
164+
165+
assert context.task_id == 'custom-task-id'
166+
152167
def test_check_or_generate_context_id_no_params(self):
153168
"""Test _check_or_generate_context_id with no params does nothing."""
154169
context = RequestContext()
@@ -168,6 +183,20 @@ def test_check_or_generate_context_id_with_existing_context_id(
168183
assert context.context_id == existing_id
169184
assert mock_params.message.context_id == existing_id
170185

186+
def test_check_or_generate_context_id_with_custom_id_generator(
187+
self, mock_params
188+
):
189+
"""Test _check_or_generate_context_id uses custom ID generator when provided."""
190+
id_generator = Mock(spec=IDGenerator)
191+
id_generator.generate.return_value = 'custom-context-id'
192+
193+
context = RequestContext(
194+
request=mock_params, context_id_generator=id_generator
195+
)
196+
# The method is called during initialization
197+
198+
assert context.context_id == 'custom-context-id'
199+
171200
def test_init_raises_error_on_task_id_mismatch(
172201
self, mock_params, mock_task
173202
):

tests/server/tasks/test_task_updater.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import asyncio
22
import uuid
33

4-
from unittest.mock import AsyncMock, patch
4+
from unittest.mock import AsyncMock, Mock, patch
55

66
import pytest
77

88
from a2a.server.events import EventQueue
9+
from a2a.server.id_generator import IDGenerator
910
from a2a.server.tasks import TaskUpdater
1011
from a2a.types import (
1112
Message,
@@ -151,6 +152,26 @@ async def test_add_artifact_generates_id(
151152
assert event.last_chunk is None
152153

153154

155+
@pytest.mark.asyncio
156+
async def test_add_artifact_generates_custom_id(event_queue, sample_parts):
157+
"""Test add_artifact uses a custom ID generator when provided."""
158+
artifact_id_generator = Mock(spec=IDGenerator)
159+
artifact_id_generator.generate.return_value = 'custom-artifact-id'
160+
task_updater = TaskUpdater(
161+
event_queue=event_queue,
162+
task_id='test-task-id',
163+
context_id='test-context-id',
164+
artifact_id_generator=artifact_id_generator,
165+
)
166+
167+
await task_updater.add_artifact(parts=sample_parts, artifact_id=None)
168+
169+
event_queue.enqueue_event.assert_called_once()
170+
event = event_queue.enqueue_event.call_args[0][0]
171+
assert isinstance(event, TaskArtifactUpdateEvent)
172+
assert event.artifact.artifact_id == 'custom-artifact-id'
173+
174+
154175
@pytest.mark.asyncio
155176
@pytest.mark.parametrize(
156177
'append_val, last_chunk_val',
@@ -304,6 +325,22 @@ def test_new_agent_message_with_metadata(task_updater, sample_parts):
304325
assert message.metadata == metadata
305326

306327

328+
def test_new_agent_message_with_custom_id_generator(event_queue, sample_parts):
329+
"""Test creating a new agent message with a custom message ID generator."""
330+
message_id_generator = Mock(spec=IDGenerator)
331+
message_id_generator.generate.return_value = 'custom-message-id'
332+
task_updater = TaskUpdater(
333+
event_queue=event_queue,
334+
task_id='test-task-id',
335+
context_id='test-context-id',
336+
message_id_generator=message_id_generator,
337+
)
338+
339+
message = task_updater.new_agent_message(parts=sample_parts)
340+
341+
assert message.message_id == 'custom-message-id'
342+
343+
307344
@pytest.mark.asyncio
308345
async def test_failed_without_message(task_updater, event_queue):
309346
"""Test marking a task as failed without a message."""

0 commit comments

Comments
 (0)