Skip to content

Commit ce19f18

Browse files
author
long.qul
committed
feat(server): implement Redis-based event queue manager
- Add RedisQueueManager class to handle event queues using Redis
- Implement core functionality: add, get, tap, close, create_or_tap
- Add unit tests for RedisQueueManager
- Update .gitignore to exclude spec.json and .idea
- Refactor event type definition in event_queue.py
- Add fakeredis dependency in pyproject.toml
1 parent 7c46e70 commit ce19f18

File tree

5 files changed

+268
-2
lines changed

5 files changed

+268
-2
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ __pycache__
88
.venv
99
coverage.xml
1010
.nox
11-
spec.json
11+
spec.json
12+
.idea

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies = [
2121
"grpcio-tools>=1.60",
2222
"grpcio_reflection>=1.7.0",
2323
"protobuf==5.29.5",
24+
"fakeredis>=2.30.1",
2425
]
2526

2627
classifiers = [

src/a2a/server/events/event_queue.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import asyncio
22
import logging
33
import sys
4+
from typing import Union, Annotated
5+
6+
from pydantic import Field
47

58
from a2a.types import (
69
Message,
@@ -14,7 +17,7 @@
1417
logger = logging.getLogger(__name__)
1518

1619

17-
Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
20+
Event = Annotated[Union[Message, Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent], Field(discriminator="kind")]
1821
"""Type alias for events that can be enqueued."""
1922

2023
DEFAULT_MAX_QUEUE_SIZE = 1024
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
import asyncio
from asyncio import Task
from functools import partial
from typing import Dict

from pydantic import TypeAdapter
from redis.asyncio import Redis

from a2a.server.events import (
    Event,
    EventConsumer,
    EventQueue,
    NoTaskQueue,
    QueueManager,
    TaskQueueExists,
)

class RedisQueueManager(QueueManager):
    """Redis-backed implementation of the `QueueManager` interface.

    Queues for tasks created in this process ("local" queues) have their
    events broadcast to other processes over Redis pub/sub; tasks owned by
    other processes are mirrored here through "proxy" queues fed from the
    corresponding pub/sub channel. Primary jobs:

    1. Broadcast local events to proxy queues in other processes using
       Redis pub/sub.
    2. Subscribe to event messages from Redis pub/sub and replay them to
       local proxy queues.
    """

    def __init__(
        self,
        redis_client: Redis,
        relay_channel_key_prefix: str = 'a2a.event.relay.',
        task_registry_key: str = 'a2a.event.registry',
    ):
        """Create a queue manager backed by *redis_client*.

        Args:
            redis_client: Shared asyncio Redis client.
            relay_channel_key_prefix: Prefix of the per-task pub/sub
                channel names used to relay events.
            task_registry_key: Name of the Redis set registering every
                live task id across all processes.
        """
        self._redis = redis_client
        self._local_queue: dict[str, EventQueue] = {}
        self._proxy_queue: dict[str, EventQueue] = {}
        self._lock = asyncio.Lock()
        self._pubsub = redis_client.pubsub()
        self._relay_channel_name = relay_channel_key_prefix
        self._background_tasks: dict[str, Task] = {}
        self._task_registry_name = task_registry_key
        # `Event` is a typing alias (an Annotated union with a "kind"
        # discriminator), not a BaseModel, so it has no
        # model_validate_json(); a TypeAdapter performs the discriminated
        # deserialization. Built once and reused per message.
        self._event_adapter = TypeAdapter(Event)
        # Task driving `PubSub.run()`; started lazily on the first remote
        # subscription, because handlers registered via subscribe() are
        # only dispatched while run() is polling the connection.
        self._pubsub_reader_task: Task | None = None

    def _task_channel_name(self, task_id: str) -> str:
        """Return the pub/sub channel name carrying events of one task."""
        return self._relay_channel_name + task_id

    async def _listen_and_relay(self, task_id: str) -> None:
        """Drain the local queue of *task_id*, publishing every event."""
        consumer = EventConsumer(self._local_queue[task_id])
        async for event in consumer.consume_all():
            await self._redis.publish(
                self._task_channel_name(task_id),
                event.model_dump_json(exclude_none=True),
            )

    async def _has_task_id(self, task_id: str) -> bool:
        """Check the global registry for *task_id* (any process)."""
        return await self._redis.sismember(self._task_registry_name, task_id)

    async def _register_task_id(self, task_id: str) -> None:
        """Add *task_id* to the global registry and start its relay task."""
        await self._redis.sadd(self._task_registry_name, task_id)
        self._background_tasks[task_id] = asyncio.create_task(
            self._listen_and_relay(task_id)
        )

    async def _remove_task_id(self, task_id: str):
        """Stop relaying *task_id* and drop it from the global registry."""
        if task_id in self._background_tasks:
            # pop() as well, so the cancelled relay task is not leaked in
            # the dict for the lifetime of the manager.
            self._background_tasks.pop(task_id).cancel(
                'task_id is closed: ' + task_id
            )
        return await self._redis.srem(self._task_registry_name, task_id)

    async def _subscribe_remote_task_events(self, task_id: str) -> None:
        """Feed the proxy queue of *task_id* from its pub/sub channel."""
        await self._pubsub.subscribe(
            **{
                self._task_channel_name(task_id): partial(
                    self._relay_remote_events, task_id
                )
            }
        )
        # Channel handlers are only invoked while `PubSub.run()` polls the
        # connection; start that reader once, lazily.
        if self._pubsub_reader_task is None or self._pubsub_reader_task.done():
            self._pubsub_reader_task = asyncio.create_task(self._pubsub.run())

    async def _unsubscribe_remote_task_events(self, task_id: str) -> None:
        """Stop feeding the proxy queue of *task_id*."""
        # `PubSub.unsubscribe` is a coroutine on redis.asyncio; the
        # original sync-style call produced an un-awaited coroutine and
        # never actually unsubscribed.
        await self._pubsub.unsubscribe(self._task_channel_name(task_id))

    def _relay_remote_events(self, task_id: str, message) -> None:
        """Pub/sub handler: replay one relayed event into a proxy queue.

        redis.asyncio hands channel handlers the whole message dict
        (``{'type', 'channel', 'data', ...}``), not the raw payload, so
        the JSON body is extracted from ``data`` first.
        """
        if task_id not in self._proxy_queue:
            return
        payload = message['data'] if isinstance(message, dict) else message
        if isinstance(payload, bytes):
            payload = payload.decode('utf-8')
        event = self._event_adapter.validate_json(payload)
        # NOTE(review): if EventQueue.enqueue_event is a coroutine in this
        # a2a version, it must be awaited/scheduled here -- confirm.
        self._proxy_queue[task_id].enqueue_event(event)

    async def _ensure_proxy_queue(self, task_id: str) -> EventQueue:
        """Create (once) and return the proxy queue for a remote task."""
        if task_id not in self._proxy_queue:
            self._proxy_queue[task_id] = EventQueue()
            await self._subscribe_remote_task_events(task_id)
        return self._proxy_queue[task_id]

    async def add(self, task_id: str, queue: EventQueue) -> None:
        """Register a local queue for *task_id*.

        Raises:
            TaskQueueExists: If any process already registered the id.
        """
        async with self._lock:
            if await self._has_task_id(task_id):
                raise TaskQueueExists()
            self._local_queue[task_id] = queue
            await self._register_task_id(task_id)

    async def get(self, task_id: str) -> EventQueue | None:
        """Return the queue for *task_id*, or None if it exists nowhere.

        A locally registered queue is returned directly; for a task owned
        by another process a proxy queue (created on demand) is returned.
        """
        async with self._lock:
            # Lookup locally first.
            if task_id in self._local_queue:
                return self._local_queue[task_id]
            # Then globally: mirror remote tasks through a proxy queue.
            if await self._has_task_id(task_id):
                return await self._ensure_proxy_queue(task_id)
            return None

    async def tap(self, task_id: str) -> EventQueue | None:
        """Return a child tap of the queue for *task_id*, if it exists."""
        event_queue = await self.get(task_id)
        if event_queue:
            return event_queue.tap()
        return None

    async def close(self, task_id: str) -> None:
        """Close the local or proxy queue of *task_id*.

        Raises:
            NoTaskQueue: If this process holds no queue for the id.
        """
        async with self._lock:
            if task_id in self._local_queue:
                # Close locally and remove from the global registry: the
                # owning process is giving the task up for everyone.
                queue = self._local_queue.pop(task_id)
                await queue.close()
                await self._remove_task_id(task_id)
                return None

            if task_id in self._proxy_queue:
                # Close the proxy and unsubscribe, but keep the global
                # registration -- the owning process still has the task.
                queue = self._proxy_queue.pop(task_id)
                await queue.close()
                await self._unsubscribe_remote_task_events(task_id)
                return None

            raise NoTaskQueue()

    async def create_or_tap(self, task_id: str) -> EventQueue:
        """Return a queue for *task_id*, creating whatever is missing.

        Local tasks and existing proxies are tapped; a brand-new proxy
        (or a brand-new local queue) is returned directly, since no other
        consumer can hold it yet.
        """
        async with self._lock:
            if await self._has_task_id(task_id):
                # Known task: tap the local queue or an existing proxy.
                if task_id in self._local_queue:
                    return self._local_queue[task_id].tap()
                if task_id in self._proxy_queue:
                    return self._proxy_queue[task_id].tap()
                # First sighting of this remote task in this process.
                return await self._ensure_proxy_queue(task_id)
            # Unknown everywhere: own it locally.
            queue = EventQueue()
            self._local_queue[task_id] = queue
            await self._register_task_id(task_id)
            return queue
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import asyncio
2+
from unittest.mock import MagicMock
3+
4+
import pytest
5+
from fakeredis import FakeAsyncRedis
6+
7+
from a2a.server.events import EventQueue, TaskQueueExists
8+
from a2a.server.events.redis_queue_manager import RedisQueueManager
9+
10+
11+
class TestRedisQueueManager:
    """Unit tests for `RedisQueueManager` backed by `fakeredis`."""

    @pytest.fixture
    def redis(self):
        """An in-memory asyncio Redis double."""
        return FakeAsyncRedis()

    @pytest.fixture
    def queue_manager(self, redis):
        """A fresh manager wired to the fake Redis."""
        return RedisQueueManager(redis)

    @pytest.fixture
    def event_queue(self):
        """A spec'd EventQueue mock whose tap() returns itself."""
        queue = MagicMock(spec=EventQueue)
        # Mock the tap method to return itself
        queue.tap.return_value = queue
        return queue

    @pytest.mark.asyncio
    async def test_init(self, queue_manager):
        """A fresh manager starts with empty local/proxy queue maps."""
        assert queue_manager._local_queue == {}
        assert queue_manager._proxy_queue == {}
        assert isinstance(queue_manager._lock, asyncio.Lock)

    @pytest.mark.asyncio
    async def test_add_new_queue(self, queue_manager, event_queue):
        """Test adding a new queue to the manager."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)
        assert queue_manager._local_queue[task_id] == event_queue

    @pytest.mark.asyncio
    async def test_add_existing_queue(self, queue_manager, event_queue):
        """Adding the same task id twice raises TaskQueueExists."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        with pytest.raises(TaskQueueExists):
            await queue_manager.add(task_id, event_queue)

    @pytest.mark.asyncio
    async def test_get_existing_queue(self, queue_manager, event_queue):
        """get() returns the locally registered queue."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        result = await queue_manager.get(task_id)
        assert result == event_queue

    @pytest.mark.asyncio
    async def test_get_nonexistent_queue(self, queue_manager):
        """get() returns None for an unknown task id."""
        result = await queue_manager.get('nonexistent_task_id')
        assert result is None

    @pytest.mark.asyncio
    async def test_tap_existing_queue(self, queue_manager, event_queue):
        """tap() returns a child tap of an existing queue."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        result = await queue_manager.tap(task_id)
        assert result == event_queue
        event_queue.tap.assert_called_once()

    @pytest.mark.asyncio
    async def test_tap_nonexistent_queue(self, queue_manager):
        """tap() returns None for an unknown task id."""
        result = await queue_manager.tap('nonexistent_task_id')
        assert result is None

    @pytest.mark.asyncio
    async def test_close_existing_queue(self, queue_manager, event_queue):
        """close() removes a local queue from the manager."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        await queue_manager.close(task_id)
        assert task_id not in queue_manager._local_queue

    @pytest.mark.asyncio
    async def test_close_nonexistent_queue(self, queue_manager):
        """close() raises NoTaskQueue when this process holds no queue."""
        from a2a.server.events import NoTaskQueue

        with pytest.raises(NoTaskQueue):
            await queue_manager.close('nonexistent_task_id')

    @pytest.mark.asyncio
    async def test_create_or_tap_existing_queue(
        self, queue_manager, event_queue
    ):
        """create_or_tap() taps an already-registered local queue."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        result = await queue_manager.create_or_tap(task_id)

        assert result == event_queue
        event_queue.tap.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_or_tap_new_queue(self, queue_manager):
        """create_or_tap() creates a fresh local queue for a new id."""
        task_id = 'brand_new_task_id'
        queue = await queue_manager.create_or_tap(task_id)

        assert queue is not None
        assert queue_manager._local_queue[task_id] is queue

    @pytest.mark.asyncio
    async def test_concurrency(self, queue_manager):
        """Concurrent add/get of many tasks leaves a consistent state."""

        async def add_task(task_id):
            queue = EventQueue()
            await queue_manager.add(task_id, queue)
            return task_id

        async def get_task(task_id):
            return await queue_manager.get(task_id)

        # Create 10 different task IDs
        task_ids = [f'task_{i}' for i in range(10)]

        # Add tasks concurrently
        add_tasks = [add_task(task_id) for task_id in task_ids]
        added_task_ids = await asyncio.gather(*add_tasks)

        # Verify all tasks were added
        assert set(added_task_ids) == set(task_ids)

        # Get tasks concurrently
        get_tasks = [get_task(task_id) for task_id in task_ids]
        queues = await asyncio.gather(*get_tasks)

        # Verify all queues are not None
        assert all(queue is not None for queue in queues)

        # Verify all tasks are in the manager
        for task_id in task_ids:
            assert task_id in queue_manager._local_queue

0 commit comments

Comments
 (0)