
Commit f75bf8b

long.qul committed
feat(server): implement RedisQueueManager for distributed event handling
- Add RedisQueueManager class to manage event queues across distributed services
- Implement local and proxy queue management using Redis pub/sub
- Add logging for better visibility and debugging
- Update tests to cover new functionality
1 parent 2ad5ff3 commit f75bf8b
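
For orientation, here is a minimal usage sketch assembled from the API added in this commit and the tests below. The Redis URL and task id are illustrative assumptions, and the surrounding server wiring is omitted.

```python
# Minimal sketch based on this commit's API and tests; the Redis URL and
# task id are illustrative, not prescribed by the SDK.
import asyncio

from redis.asyncio import Redis

from a2a.server.events import EventQueue
from a2a.server.events.redis_queue_manager import RedisQueueManager
from a2a.utils import new_agent_text_message


async def main() -> None:
    redis = Redis.from_url('redis://localhost:6379')
    qm = RedisQueueManager(redis)

    # Producer side: register a local queue; events enqueued here are
    # relayed to a Redis channel by a background task.
    queue = EventQueue()
    await qm.add('task_1', queue)

    # Consumer side: tap() returns a child queue fed the same events.
    tapped = await qm.tap('task_1')

    await queue.enqueue_event(new_agent_text_message('hello'))
    print(await tapped.dequeue_event())

    await qm.close('task_1')


asyncio.run(main())
```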


2 files changed: +177 -26 lines changed


src/a2a/server/events/redis_queue_manager.py

Lines changed: 85 additions & 26 deletions
@@ -1,8 +1,11 @@
 import asyncio
+import logging

 from asyncio import Task
 from functools import partial
+from typing import Any, Dict, Optional

+from pydantic import ValidationError, TypeAdapter
 from redis.asyncio import Redis

 from a2a.server.events import (
@@ -14,6 +17,8 @@
     TaskQueueExists,
 )

+logger = logging.getLogger(__name__)
+

 class RedisQueueManager(QueueManager):
     """This implements the `QueueManager` interface using Redis for event.
@@ -40,14 +45,7 @@ def __init__(
         self._relay_channel_name = relay_channel_key_prefix
         self._background_tasks: dict[str, Task] = {}
         self._task_registry_name = task_registry_key
-
-    async def _listen_and_relay(self, task_id: str) -> None:
-        c = EventConsumer(self._local_queue[task_id])
-        async for event in c.consume_all():
-            await self._redis.publish(
-                self._task_channel_name(task_id),
-                event.model_dump_json(exclude_none=True),
-            )
+        self._pubsub_listener_task: Optional[Task] = None

     def _task_channel_name(self, task_id: str) -> str:
         return self._relay_channel_name + task_id
@@ -57,10 +55,23 @@ async def _has_task_id(self, task_id: str) -> bool:
         return ret == 1

     async def _register_task_id(self, task_id: str) -> None:
-        await self._redis.sadd(self._task_registry_name, task_id)
+        task_started_event = asyncio.Event()
+        async def _wrapped_listen_and_relay() -> None:
+            task_started_event.set()
+            c = EventConsumer(self._local_queue[task_id].tap())
+            async for event in c.consume_all():
+                logger.debug(f'Publishing event for task {task_id} in QM {self}: {event}')
+                await self._redis.publish(
+                    self._task_channel_name(task_id),
+                    event.model_dump_json(exclude_none=True),
+                )
+
         self._background_tasks[task_id] = asyncio.create_task(
-            self._listen_and_relay(task_id)
+            _wrapped_listen_and_relay()
         )
+        await task_started_event.wait()
+        await self._redis.sadd(self._task_registry_name, task_id)
+        logger.debug(f'Started to listen and relay events for task {task_id}')

     async def _remove_task_id(self, task_id: str) -> bool:
         if task_id in self._background_tasks:
@@ -70,21 +81,51 @@ async def _remove_task_id(self, task_id: str) -> bool:
         return await self._redis.srem(self._task_registry_name, task_id) == 1

     async def _subscribe_remote_task_events(self, task_id: str) -> None:
-        await self._pubsub.subscribe(
-            **{
-                self._task_channel_name(task_id): partial(
-                    self._relay_remote_events, task_id
-                )
-            }
-        )
-
-    def _unsubscribe_remote_task_events(self, task_id: str) -> None:
-        self._pubsub.unsubscribe(self._task_channel_name(task_id))
-
-    def _relay_remote_events(self, task_id: str, event_json: str) -> None:
-        if task_id in self._proxy_queue:
-            event = Event.model_validate_json(event_json)
-            self._proxy_queue[task_id].enqueue_event(event)
+        channel_id = self._task_channel_name(task_id)
+        await self._pubsub.subscribe(**{channel_id: self._relay_remote_events})
+
+        # this is a global listener to handle incoming pubsub events
+        if not self._pubsub_listener_task:
+            logger.debug('Creating pubsub listener task.')
+            self._pubsub_listener_task = asyncio.create_task(self._consume_pubsub_messages())
+
+        logger.debug(f"Subscribed for remote events for task {task_id}")
+
+    async def _consume_pubsub_messages(self):
+        async for _ in self._pubsub.listen():
+            pass
+
+    async def _relay_remote_events(self, subscription_event) -> None:
+        if 'channel' not in subscription_event or 'data' not in subscription_event:
+            logger.warning(f"channel or data is absent in subscription event: {subscription_event}")
+            return
+
+        channel_id: str = subscription_event['channel'].decode('utf-8')
+        data_string: str = subscription_event['data'].decode('utf-8')
+        task_id = channel_id.split('.')[-1]
+        if task_id not in self._proxy_queue:
+            logger.warning(f"task_id {task_id} not found in proxy queue")
+            return
+
+        try:
+            logger.debug(f"Received event for task_id {task_id} in QM {self}: {data_string}")
+            event = TypeAdapter(Event).validate_json(data_string)
+        except Exception as e:
+            logger.warning(f"Failed to parse event from subscription event: {subscription_event}: {e}")
+            return
+
+        logger.debug(f"Enqueuing event for task_id {task_id} in QM {self}: {event}")
+        await self._proxy_queue[task_id].enqueue_event(event)
+
+
+    async def _unsubscribe_remote_task_events(self, task_id: str) -> None:
+        # unsubscribe channel for given task_id
+        await self._pubsub.unsubscribe(self._task_channel_name(task_id))
+        # release global listener if no channel is subscribed
+        async with self._lock:
+            if not self._pubsub.subscribed and self._pubsub_listener_task:
+                self._pubsub_listener_task.cancel()
+                self._pubsub_listener_task = None

     async def add(self, task_id: str, queue: EventQueue) -> None:
         """Add a new local event queue for the specified task.
@@ -96,11 +137,13 @@ async def add(self, task_id: str, queue: EventQueue) -> None:
         Raises:
             TaskQueueExists: If a queue for the task already exists.
         """
+        logger.debug(f"add {task_id}")
         async with self._lock:
             if await self._has_task_id(task_id):
                 raise TaskQueueExists()
             self._local_queue[task_id] = queue
             await self._register_task_id(task_id)
+            logger.debug(f"Local queue is created for task {task_id}")

     async def get(self, task_id: str) -> EventQueue | None:
         """Get the event queue associated with the given task ID.
@@ -115,17 +158,22 @@
         Returns:
             EventQueue | None: The event queue if found, otherwise None.
         """
+        logger.debug(f"get {task_id}")
         async with self._lock:
             # lookup locally
             if task_id in self._local_queue:
+                logger.debug(f"Got local queue for task_id {task_id}")
                 return self._local_queue[task_id]
             # lookup globally
             if await self._has_task_id(task_id):
                 if task_id not in self._proxy_queue:
+                    logger.debug(f"Creating proxy queue for {task_id}")
                     queue = EventQueue()
                     self._proxy_queue[task_id] = queue
                     await self._subscribe_remote_task_events(task_id)
+                logger.debug(f"Got proxy queue for task_id {task_id}")
                 return self._proxy_queue[task_id]
+            logger.warning(f"Attempted to get non-existing queue for task {task_id}")
             return None

     async def tap(self, task_id: str) -> EventQueue | None:
@@ -137,8 +185,10 @@ async def tap(self, task_id: str) -> EventQueue | None:
         Returns:
             EventQueue | None: A new reference to the event queue if it exists, otherwise None.
         """
+        logger.debug(f"tap {task_id}")
         event_queue = await self.get(task_id)
         if event_queue:
+            logger.debug(f'Tapping event queue for task: {task_id}')
             return event_queue.tap()
         return None

@@ -155,23 +205,27 @@ async def close(self, task_id: str) -> None:
         Raises:
             NoTaskQueue: If no queue exists for the given task ID.
         """
+        logger.debug(f"close {task_id}")
         async with self._lock:
             if task_id in self._local_queue:
                 # close locally
                 queue = self._local_queue.pop(task_id)
                 await queue.close()
                 # remove from global registry if a local queue is closed
                 await self._remove_task_id(task_id)
+                logger.debug(f"Closing local queue for task {task_id}")
                 return

             if task_id in self._proxy_queue:
                 # close proxy queue
                 queue = self._proxy_queue.pop(task_id)
                 await queue.close()
                 # unsubscribe from remote, but don't remove from global registry
-                self._unsubscribe_remote_task_events(task_id)
+                await self._unsubscribe_remote_task_events(task_id)
+                logger.debug(f"Closing proxy queue for task {task_id}")
                 return

+            logger.warning(f"Attempted to close non-existing queue found for task {task_id}")
             raise NoTaskQueue()

     async def create_or_tap(self, task_id: str) -> EventQueue:
@@ -186,23 +240,28 @@ async def create_or_tap(self, task_id: str) -> EventQueue:
         Returns:
             EventQueue: An event queue associated with the given task ID.
         """
+        logger.debug(f"create_or_tap {task_id}")
         async with self._lock:
             if await self._has_task_id(task_id):
                 # if it's a local queue, tap directly
                 if task_id in self._local_queue:
+                    logger.debug(f"Tapping a local queue for task {task_id}")
                     return self._local_queue[task_id].tap()

                 # if it's a proxy queue, tap the proxy
                 if task_id in self._proxy_queue:
+                    logger.debug(f"Tapping a proxy queue for task {task_id}")
                     return self._proxy_queue[task_id].tap()

                 # if the proxy is not created, create the proxy and return
                 queue = EventQueue()
                 self._proxy_queue[task_id] = queue
                 await self._subscribe_remote_task_events(task_id)
+                logger.debug(f"Creating a proxy queue for task {task_id}")
                 return self._proxy_queue[task_id]
             # the task doesn't exist before, create a local queue
             queue = EventQueue()
             self._local_queue[task_id] = queue
             await self._register_task_id(task_id)
+            logger.debug(f"Creating a local queue for task {task_id}")
             return queue
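
Taken together, `get` and `create_or_tap` give every manager the same view of a task: if the queue lives in another manager, a proxy queue is created locally and fed from the relay channel. A hedged sketch of that flow, mirroring the mixed-queue test added below (fakeredis stands in for a real server, so the sketch runs standalone):

```python
# Hypothetical two-manager flow on a single event loop, mirroring the
# mixed-queue test; qm1 owns the real queue, qm2 only sees Redis.
import asyncio

from fakeredis import FakeAsyncRedis

from a2a.server.events import EventQueue
from a2a.server.events.redis_queue_manager import RedisQueueManager
from a2a.utils import new_agent_text_message


async def main() -> None:
    redis = FakeAsyncRedis()
    qm1 = RedisQueueManager(redis)  # "producer" service
    qm2 = RedisQueueManager(redis)  # "consumer" service

    local = EventQueue()
    await qm1.add('task_1', local)

    # qm2 finds task_1 in the shared registry, builds a proxy queue, and
    # subscribes to the task's relay channel behind the scenes.
    proxy = await qm2.get('task_1')

    await local.enqueue_event(new_agent_text_message('hello'))
    print(await proxy.dequeue_event())


asyncio.run(main())
```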

tests/server/test_integration.py

Lines changed: 92 additions & 0 deletions
@@ -1,9 +1,13 @@
 import asyncio
+import logging
+import os

 from typing import Any
 from unittest import mock

 import pytest
+import pytest_asyncio
+from redis.asyncio import Redis

 from starlette.authentication import (
     AuthCredentials,
@@ -22,6 +26,8 @@
     A2AFastAPIApplication,
     A2AStarletteApplication,
 )
+from a2a.server.events import EventQueue
+from a2a.server.events.redis_queue_manager import RedisQueueManager
 from a2a.types import (
     AgentCapabilities,
     AgentCard,
@@ -44,6 +50,7 @@
     TextPart,
     UnsupportedOperationError,
 )
+from a2a.utils import new_agent_text_message
 from a2a.utils.errors import MethodNotImplementedError


@@ -884,3 +891,88 @@ def test_non_dict_json(client: TestClient):
     data = response.json()
     assert 'error' in data
     assert data['error']['code'] == InvalidRequestError().code
+
+
+# === RedisQueueManager ===
+@pytest.mark.asyncio
+@pytest_asyncio.fixture(scope="function")
+async def asyncio_redis():
+    redis_server_url = os.getenv("REDIS_SERVER_URL")
+    if redis_server_url:
+        redis = Redis.from_url(redis_server_url)
+    else:
+        # use fake redis instead if no redis server url is given
+        from fakeredis import FakeAsyncRedis
+        redis = FakeAsyncRedis()
+    logging.info("flush redis for next test case")
+    await redis.flushall(asynchronous=False)
+    yield redis
+    await redis.close()
+
+
+@pytest.mark.asyncio
+async def test_redis_queue_local_only_queue(asyncio_redis):
+    queue_manager = RedisQueueManager(asyncio_redis)
+
+    # setup local queues
+    q1 = EventQueue()
+    await queue_manager.add('task_1', q1)
+    q2 = EventQueue()
+    await queue_manager.add('task_2', q2)
+    q3 = await queue_manager.tap("task_1")
+    assert await queue_manager.get('task_1') == q1
+    assert await queue_manager.get('task_2') == q2
+
+    # send and receive locally
+    msg1 = new_agent_text_message('hello')
+    await q1.enqueue_event(msg1)
+    assert await q1.dequeue_event(no_wait=True) == msg1
+    assert await q3.dequeue_event(no_wait=True) == msg1
+    # raise error if queue is empty
+    with pytest.raises(asyncio.QueueEmpty):
+        await q1.dequeue_event(no_wait=True)
+    # q2 is empty
+    with pytest.raises(asyncio.QueueEmpty):
+        await q2.dequeue_event(no_wait=True)
+
+
+@pytest.mark.asyncio
+async def test_redis_queue_mixed_queue(asyncio_redis):
+    qm1 = RedisQueueManager(asyncio_redis)
+    qm2 = RedisQueueManager(asyncio_redis)
+    qm3 = RedisQueueManager(asyncio_redis)
+
+    # create local queue in qm1
+    q1 = EventQueue()
+    await qm1.add('task_1', q1)
+    assert 'task_1' in qm1._local_queue
+    assert await qm1.get('task_1') == q1
+
+    # create proxy queue in qm2 through `get` method
+    q1_1 = await qm2.get('task_1')
+    assert 'task_1' in qm2._proxy_queue and 'task_1' not in qm2._local_queue
+    assert q1_1 != q1
+
+    # create proxy queue in qm3 through `tap` method
+    q1_2 = await qm3.tap("task_1")
+    assert 'task_1' in qm3._proxy_queue and 'task_1' not in qm3._local_queue
+
+    # enqueue and dequeue in q1
+    msg1 = new_agent_text_message('hello')
+    await q1.enqueue_event(msg1)
+    assert await q1.dequeue_event() == msg1
+    with pytest.raises(asyncio.QueueEmpty):
+        await q1.dequeue_event(no_wait=True)
+
+    # dequeue in q1_1
+    msg1_1: Message = await q1_1.dequeue_event()
+    assert msg1_1.parts[0].root.text == msg1.parts[0].root.text
+    with pytest.raises(asyncio.QueueEmpty):
+        await q1_1.dequeue_event(no_wait=True)
+
+    # dequeue in q1_2
+    msg1_2: Message = await q1_2.dequeue_event()
+    assert msg1_2.parts[0].root.text == msg1.parts[0].root.text
+    with pytest.raises(asyncio.QueueEmpty):
+        await q1_2.dequeue_event(no_wait=True)
