Skip to content

Commit cbbb7bb

Browse files
pstephengoogleGerrit Code Review
authored andcommitted
Merge "Add QueueManager abstraction and implement in memory version." into main
2 parents a12333e + 0356855 commit cbbb7bb

File tree

4 files changed

+137
-31
lines changed

4 files changed

+137
-31
lines changed

src/a2a/server/events/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
from a2a.server.events.event_consumer import EventConsumer
22
from a2a.server.events.event_queue import Event, EventQueue
3+
from a2a.server.events.queue_manager import QueueManager, TaskQueueExists, NoTaskQueue
4+
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
35

4-
5-
__all__ = ['Event', 'EventConsumer', 'EventQueue']
6+
__all__ = [
7+
'Event',
8+
'EventConsumer',
9+
'EventQueue',
10+
'QueueManager',
11+
'TaskQueueExists',
12+
'NoTaskQueue',
13+
'InMemoryQueueManager',
14+
]
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import asyncio
2+
from a2a.server.events.event_queue import EventQueue
3+
from a2a.server.events.queue_manager import QueueManager, TaskQueueExists, NoTaskQueue
4+
5+
class InMemoryQueueManager(QueueManager):
6+
"""InMemoryQueueManager is used for a single binary management.
7+
8+
This implements the QueueManager but requires all incoming interactions
9+
to hit the same binary that manages the queues.
10+
11+
This works for single binary solution. Needs a distributed approach for
12+
true scalable deployment.
13+
"""
14+
15+
def __init__(self):
16+
self._task_queue: dict[str, EventQueue] = {}
17+
self._lock = asyncio.Lock()
18+
19+
async def add(self, task_id: str, queue: EventQueue):
20+
async with self._lock:
21+
if task_id in self._task_queue:
22+
raise TaskQueueExists()
23+
self._task_queue[task_id] = queue
24+
25+
async def get(self, task_id: str) -> EventQueue | None:
26+
async with self._lock:
27+
if task_id not in self._task_queue:
28+
return None
29+
return self._task_queue[task_id]
30+
31+
async def tap(self, task_id: str) -> EventQueue | None:
32+
async with self._lock:
33+
if task_id not in self._task_queue:
34+
return None
35+
return self._task_queue[task_id].tap()
36+
37+
async def close(self, task_id: str):
38+
async with self._lock:
39+
if task_id not in self._task_queue:
40+
raise NoTaskQueue()
41+
del self._task_queue[task_id]
42+
43+
async def create_or_tap(self, task_id: str) -> EventQueue:
44+
async with self._lock:
45+
if task_id not in self._task_queue:
46+
queue = EventQueue()
47+
self._task_queue[task_id] = queue
48+
return queue
49+
return self._task_queue[task_id].tap()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from abc import ABC, abstractmethod
2+
from a2a.server.events.event_queue import EventQueue
3+
4+
5+
class QueueManager(ABC):
6+
"""Interface for managing the event queue lifecycles."""
7+
8+
@abstractmethod
9+
async def add(self, task_id: str, queue: EventQueue):
10+
pass
11+
12+
@abstractmethod
13+
async def get(self, task_id: str) -> EventQueue | None:
14+
pass
15+
16+
@abstractmethod
17+
async def tap(self, task_id: str) -> EventQueue | None:
18+
pass
19+
20+
@abstractmethod
21+
async def close(self, task_id: str):
22+
pass
23+
24+
@abstractmethod
25+
async def create_or_tap(self, task_id: str) -> EventQueue:
26+
pass
27+
28+
29+
class TaskQueueExists(Exception):
30+
pass
31+
32+
class NoTaskQueue(Exception):
33+
pass

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,15 @@
44
from collections.abc import AsyncGenerator
55

66
from a2a.server.agent_execution import AgentExecutor, RequestContext
7-
from a2a.server.events import Event, EventConsumer, EventQueue
7+
from a2a.server.events import (
8+
EventConsumer,
9+
EventQueue,
10+
Event,
11+
QueueManager,
12+
TaskQueueExists,
13+
NoTaskQueue,
14+
InMemoryQueueManager,
15+
)
816
from a2a.server.request_handlers.request_handler import RequestHandler
917
from a2a.server.tasks import ResultAggregator, TaskManager, TaskStore
1018
from a2a.types import (
@@ -28,13 +36,14 @@ class DefaultRequestHandler(RequestHandler):
2836
"""Default request handler for all incoming requests."""
2937

3038
def __init__(
31-
self, agent_executor: AgentExecutor, task_store: TaskStore
39+
self,
40+
agent_executor: AgentExecutor,
41+
task_store: TaskStore,
42+
queue_manager: QueueManager = InMemoryQueueManager(),
3243
) -> None:
3344
self.agent_executor = agent_executor
3445
self.task_store = task_store
35-
# This works for single binary solution. Needs a distributed approach for
36-
# true scalable deployment.
37-
self._task_queue: dict[str, EventQueue] = {}
46+
self._queue_manager = queue_manager
3847

3948
async def on_get_task(self, params: TaskQueryParams) -> Task | None:
4049
"""Default handler for 'tasks/get'."""
@@ -56,8 +65,10 @@ async def on_cancel_task(self, params: TaskIdParams) -> Task | None:
5665
initial_message=None,
5766
)
5867
result_aggregator = ResultAggregator(task_manager)
59-
60-
queue = EventQueue()
68+
try:
69+
queue = await self._queue_manager.tap(task.id)
70+
except:
71+
queue = EventQueue()
6172
await self.agent_executor.cancel(
6273
RequestContext(
6374
None,
@@ -95,15 +106,11 @@ async def on_message_send(
95106
task: Task | None = await task_manager.get_task()
96107
if task:
97108
task = task_manager.update_with_message(params.message, task)
109+
queue = await self._queue_manager.create_or_tap(task.id)
110+
else:
111+
queue = EventQueue()
98112
result_aggregator = ResultAggregator(task_manager)
99113
# TODO to manage the non-blocking flows.
100-
101-
queue = EventQueue()
102-
# If this is a follow up on an existing task, register the queue now
103-
task_id: str | None = task.id if task else None
104-
if task_id:
105-
self._task_queue[task_id] = queue
106-
107114
producer_task = asyncio.create_task(
108115
self._run_event_stream(
109116
RequestContext(
@@ -127,6 +134,11 @@ async def on_message_send(
127134
return result
128135
finally:
129136
await producer_task
137+
if task:
138+
try:
139+
await self._queue_manager.close(task.id)
140+
except NoTaskQueue:
141+
pass
130142

131143
async def on_message_send_stream(
132144
self, params: MessageSendParams
@@ -142,14 +154,11 @@ async def on_message_send_stream(
142154

143155
if task:
144156
task = task_manager.update_with_message(params.message, task)
145-
157+
queue = await self._queue_manager.create_or_tap(task.id)
158+
else:
159+
queue = EventQueue()
146160
result_aggregator = ResultAggregator(task_manager)
147-
queue = EventQueue()
148-
149161
task_id: str | None = task.id if task else None
150-
if task_id:
151-
self._task_queue[task_id] = queue
152-
153162
producer_task = asyncio.create_task(
154163
self._run_event_stream(
155164
RequestContext(
@@ -165,13 +174,20 @@ async def on_message_send_stream(
165174
consumer = EventConsumer(queue)
166175
async for event in result_aggregator.consume_and_emit(consumer):
167176
# Now we know we have a Task, register the queue
168-
if isinstance(event, Task) and event.id not in self._task_queue:
169-
self._task_queue[event.id] = queue
170-
task_id = event.id
171-
yield event
172-
177+
if isinstance(event, Task):
178+
try:
179+
await self._queue_manager.add(event.id, queue)
180+
task_id = event.id
181+
except TaskQueueExists:
182+
logging.info(
183+
'Multiple Task objects created in event stream.')
184+
yield event
173185
finally:
174186
await producer_task
187+
try:
188+
await self._queue_manager.close(task_id)
189+
except NoTaskQueue:
190+
pass
175191

176192
async def on_set_task_push_notification_config(
177193
self, request: TaskPushNotificationConfig
@@ -203,12 +219,11 @@ async def on_resubscribe_to_task(
203219

204220
result_aggregator = ResultAggregator(task_manager)
205221

206-
# Need to tap the existing queue.
207-
if not task.id in self._task_queue:
222+
try:
223+
queue = await self._queue_manager.tap(task.id)
224+
except NoTaskQueue:
208225
raise ServerError(error=TaskNotFoundError())
209-
return
210226

211-
queue = self._task_queue[task.id].tap()
212227
consumer = EventConsumer(queue)
213228
async for event in result_aggregator.consume_and_emit(consumer):
214229
yield event

0 commit comments

Comments
 (0)