Skip to content

Commit d7f3ffd

Browse files
Add ServerCallContext into task store operations
1 parent 9193208 commit d7f3ffd

File tree

5 files changed

+54
-20
lines changed

5 files changed

+54
-20
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ async def on_get_task(
109109
context: ServerCallContext | None = None,
110110
) -> Task | None:
111111
"""Default handler for 'tasks/get'."""
112-
task: Task | None = await self.task_store.get(params.id)
112+
task: Task | None = await self.task_store.get(params.id, context)
113113
if not task:
114114
raise ServerError(error=TaskNotFoundError())
115115

@@ -141,7 +141,7 @@ async def on_cancel_task(
141141
142142
Attempts to cancel the task managed by the `AgentExecutor`.
143143
"""
144-
task: Task | None = await self.task_store.get(params.id)
144+
task: Task | None = await self.task_store.get(params.id, context)
145145
if not task:
146146
raise ServerError(error=TaskNotFoundError())
147147

@@ -158,6 +158,7 @@ async def on_cancel_task(
158158
context_id=task.context_id,
159159
task_store=self.task_store,
160160
initial_message=None,
161+
context=context,
161162
)
162163
result_aggregator = ResultAggregator(task_manager)
163164

@@ -217,6 +218,7 @@ async def _setup_message_execution(
217218
context_id=params.message.context_id,
218219
task_store=self.task_store,
219220
initial_message=params.message,
221+
context=context,
220222
)
221223
task: Task | None = await task_manager.get_task()
222224

@@ -417,7 +419,7 @@ async def on_set_task_push_notification_config(
417419
if not self._push_config_store:
418420
raise ServerError(error=UnsupportedOperationError())
419421

420-
task: Task | None = await self.task_store.get(params.task_id)
422+
task: Task | None = await self.task_store.get(params.task_id, context)
421423
if not task:
422424
raise ServerError(error=TaskNotFoundError())
423425

@@ -440,7 +442,7 @@ async def on_get_task_push_notification_config(
440442
if not self._push_config_store:
441443
raise ServerError(error=UnsupportedOperationError())
442444

443-
task: Task | None = await self.task_store.get(params.id)
445+
task: Task | None = await self.task_store.get(params.id, context)
444446
if not task:
445447
raise ServerError(error=TaskNotFoundError())
446448

@@ -469,7 +471,7 @@ async def on_resubscribe_to_task(
469471
Allows a client to re-attach to a running streaming task's event stream.
470472
Requires the task and its queue to still be active.
471473
"""
472-
task: Task | None = await self.task_store.get(params.id)
474+
task: Task | None = await self.task_store.get(params.id, context)
473475
if not task:
474476
raise ServerError(error=TaskNotFoundError())
475477

@@ -485,6 +487,7 @@ async def on_resubscribe_to_task(
485487
context_id=task.context_id,
486488
task_store=self.task_store,
487489
initial_message=None,
490+
context=context,
488491
)
489492

490493
result_aggregator = ResultAggregator(task_manager)
@@ -509,7 +512,7 @@ async def on_list_task_push_notification_config(
509512
if not self._push_config_store:
510513
raise ServerError(error=UnsupportedOperationError())
511514

512-
task: Task | None = await self.task_store.get(params.id)
515+
task: Task | None = await self.task_store.get(params.id, context)
513516
if not task:
514517
raise ServerError(error=TaskNotFoundError())
515518

@@ -536,7 +539,7 @@ async def on_delete_task_push_notification_config(
536539
if not self._push_config_store:
537540
raise ServerError(error=UnsupportedOperationError())
538541

539-
task: Task | None = await self.task_store.get(params.id)
542+
task: Task | None = await self.task_store.get(params.id, context)
540543
if not task:
541544
raise ServerError(error=TaskNotFoundError())
542545

src/a2a/server/tasks/database_task_store.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"or 'pip install a2a-sdk[sql]'"
2020
) from e
2121

22+
from a2a.server.context import ServerCallContext
2223
from a2a.server.models import Base, TaskModel, create_task_model
2324
from a2a.server.tasks.task_store import TaskStore
2425
from a2a.types import Task # Task is the Pydantic model
@@ -119,15 +120,19 @@ def _from_orm(self, task_model: TaskModel) -> Task:
119120
# Pydantic's model_validate will parse the nested dicts/lists from JSON
120121
return Task.model_validate(task_data_from_db)
121122

122-
async def save(self, task: Task) -> None:
123+
async def save(
124+
self, task: Task, context: ServerCallContext | None = None
125+
) -> None:
123126
"""Saves or updates a task in the database."""
124127
await self._ensure_initialized()
125128
db_task = self._to_orm(task)
126129
async with self.async_session_maker.begin() as session:
127130
await session.merge(db_task)
128131
logger.debug('Task %s saved/updated successfully.', task.id)
129132

130-
async def get(self, task_id: str) -> Task | None:
133+
async def get(
134+
self, task_id: str, context: ServerCallContext | None = None
135+
) -> Task | None:
131136
"""Retrieves a task from the database by ID."""
132137
await self._ensure_initialized()
133138
async with self.async_session_maker() as session:
@@ -142,7 +147,9 @@ async def get(self, task_id: str) -> Task | None:
142147
logger.debug('Task %s not found in store.', task_id)
143148
return None
144149

145-
async def delete(self, task_id: str) -> None:
150+
async def delete(
151+
self, task_id: str, context: ServerCallContext | None = None
152+
) -> None:
146153
"""Deletes a task from the database by ID."""
147154
await self._ensure_initialized()
148155

src/a2a/server/tasks/inmemory_task_store.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import logging
33

4+
from a2a.server.context import ServerCallContext
45
from a2a.server.tasks.task_store import TaskStore
56
from a2a.types import Task
67

@@ -21,13 +22,17 @@ def __init__(self) -> None:
2122
self.tasks: dict[str, Task] = {}
2223
self.lock = asyncio.Lock()
2324

24-
async def save(self, task: Task) -> None:
25+
async def save(
26+
self, task: Task, context: ServerCallContext | None = None
27+
) -> None:
2528
"""Saves or updates a task in the in-memory store."""
2629
async with self.lock:
2730
self.tasks[task.id] = task
2831
logger.debug('Task %s saved successfully.', task.id)
2932

30-
async def get(self, task_id: str) -> Task | None:
33+
async def get(
34+
self, task_id: str, context: ServerCallContext | None = None
35+
) -> Task | None:
3136
"""Retrieves a task from the in-memory store by ID."""
3237
async with self.lock:
3338
logger.debug('Attempting to get task with id: %s', task_id)
@@ -38,7 +43,9 @@ async def get(self, task_id: str) -> Task | None:
3843
logger.debug('Task %s not found in store.', task_id)
3944
return task
4045

41-
async def delete(self, task_id: str) -> None:
46+
async def delete(
47+
self, task_id: str, context: ServerCallContext | None = None
48+
) -> None:
4249
"""Deletes a task from the in-memory store by ID."""
4350
async with self.lock:
4451
logger.debug('Attempting to delete task with id: %s', task_id)

src/a2a/server/tasks/task_manager.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22

3+
from a2a.server.context import ServerCallContext
34
from a2a.server.events.event_queue import Event
45
from a2a.server.tasks.task_store import TaskStore
56
from a2a.types import (
@@ -31,6 +32,7 @@ def __init__(
3132
context_id: str | None,
3233
task_store: TaskStore,
3334
initial_message: Message | None,
35+
context: ServerCallContext | None = None,
3436
):
3537
"""Initializes the TaskManager.
3638
@@ -49,6 +51,7 @@ def __init__(
4951
self.task_store = task_store
5052
self._initial_message = initial_message
5153
self._current_task: Task | None = None
54+
self._call_context: ServerCallContext | None = context
5255
logger.debug(
5356
'TaskManager initialized with task_id: %s, context_id: %s',
5457
task_id,
@@ -74,7 +77,9 @@ async def get_task(self) -> Task | None:
7477
logger.debug(
7578
'Attempting to get task from store with id: %s', self.task_id
7679
)
77-
self._current_task = await self.task_store.get(self.task_id)
80+
self._current_task = await self.task_store.get(
81+
self.task_id, self._context
82+
)
7883
if self._current_task:
7984
logger.debug('Task %s retrieved successfully.', self.task_id)
8085
else:
@@ -167,7 +172,7 @@ async def ensure_task(
167172
logger.debug(
168173
'Attempting to retrieve existing task with id: %s', self.task_id
169174
)
170-
task = await self.task_store.get(self.task_id)
175+
task = await self.task_store.get(self.task_id, self._context)
171176

172177
if not task:
173178
logger.info(
@@ -231,7 +236,7 @@ async def _save_task(self, task: Task) -> None:
231236
task: The `Task` object to save.
232237
"""
233238
logger.debug('Saving task with id: %s', task.id)
234-
await self.task_store.save(task)
239+
await self.task_store.save(task, self._context)
235240
self._current_task = task
236241
if not self.task_id:
237242
logger.info('New task created with id: %s', task.id)

src/a2a/server/tasks/task_store.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,34 @@
11
from abc import ABC, abstractmethod
22

3+
from a2a.server.context import ServerCallContext
34
from a2a.types import Task
45

5-
66
class TaskStore(ABC):
77
"""Agent Task Store interface.
88
99
Defines the methods for persisting and retrieving `Task` objects.
1010
"""
1111

1212
@abstractmethod
13-
async def save(self, task: Task) -> None:
13+
async def save(
14+
self,
15+
task: Task,
16+
context: ServerCallContext | None = None
17+
) -> None:
1418
"""Saves or updates a task in the store."""
1519

1620
@abstractmethod
17-
async def get(self, task_id: str) -> Task | None:
21+
async def get(
22+
self,
23+
task_id: str,
24+
context: ServerCallContext | None = None
25+
) -> Task | None:
1826
"""Retrieves a task from the store by ID."""
1927

2028
@abstractmethod
21-
async def delete(self, task_id: str) -> None:
29+
async def delete(
30+
self,
31+
task_id: str,
32+
context: ServerCallContext | None = None
33+
) -> None:
2234
"""Deletes a task from the store by ID."""

0 commit comments

Comments
 (0)