3 changes: 2 additions & 1 deletion .gitignore
@@ -8,4 +8,5 @@ __pycache__
 .venv
 coverage.xml
 .nox
-spec.json
+spec.json
+.idea
1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
     "grpcio-tools>=1.60",
     "grpcio_reflection>=1.7.0",
     "protobuf==5.29.5",
+    "fakeredis>=2.30.1",
 ]

classifiers = [
9 changes: 8 additions & 1 deletion src/a2a/server/events/event_queue.py
@@ -2,6 +2,10 @@
 import logging
 import sys
 
+from typing import Annotated
+
+from pydantic import Field
+
 from a2a.types import (
     Message,
     Task,
@@ -14,7 +18,10 @@
 logger = logging.getLogger(__name__)
 
 
-Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
+Event = Annotated[
+    Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
+    Field(discriminator='kind'),
+]
 """Type alias for events that can be enqueued."""
 
 DEFAULT_MAX_QUEUE_SIZE = 1024
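The discriminated union matters beyond type checking: it is what lets an event relayed as JSON be rebuilt into the right concrete class on the receiving side. A minimal sketch of decoding with pydantic's `TypeAdapter` (the adapter and helper below are illustrative, not part of this diff):

```python
from pydantic import TypeAdapter

from a2a.server.events import Event

# The adapter dispatches on each model's `kind` field, as declared by
# Field(discriminator='kind') above.
event_adapter = TypeAdapter(Event)


def decode_event(payload: str) -> Event:
    # Rebuilds the concrete Message / Task / TaskStatusUpdateEvent /
    # TaskArtifactUpdateEvent instance from its JSON form.
    return event_adapter.validate_json(payload)
```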
152 changes: 152 additions & 0 deletions src/a2a/server/events/redis_queue_manager.py
@@ -0,0 +1,152 @@
import asyncio

from asyncio import Task
from functools import partial

from pydantic import TypeAdapter
from redis.asyncio import Redis

from a2a.server.events import (
    Event,
    EventConsumer,
    EventQueue,
    NoTaskQueue,
    QueueManager,
    TaskQueueExists,
)


class RedisQueueManager(QueueManager):
    """A `QueueManager` implementation that uses Redis to share event
    queues across processes.

    Its primary jobs are to:
    1. Broadcast local events to proxy queues in other processes via Redis pub/sub.
    2. Subscribe to event messages from Redis pub/sub and replay them to local proxy queues.
    """

    def __init__(
        self,
        redis_client: Redis,
        relay_channel_key_prefix: str = 'a2a.event.relay.',
        task_registry_key: str = 'a2a.event.registry',
    ):
        self._redis = redis_client
        self._local_queue: dict[str, EventQueue] = {}
        self._proxy_queue: dict[str, EventQueue] = {}
        self._lock = asyncio.Lock()
        self._pubsub = redis_client.pubsub()
        self._relay_channel_name = relay_channel_key_prefix
        self._background_tasks: dict[str, Task] = {}
        self._pubsub_reader_task: Task | None = None
        self._task_registry_name = task_registry_key

    async def _listen_and_relay(self, task_id: str):
        c = EventConsumer(self._local_queue[task_id])
        async for event in c.consume_all():
            await self._redis.publish(
                self._task_channel_name(task_id),
                event.model_dump_json(exclude_none=True),
            )

    def _task_channel_name(self, task_id: str) -> str:
        return self._relay_channel_name + task_id

    async def _has_task_id(self, task_id: str) -> bool:
        return bool(
            await self._redis.sismember(self._task_registry_name, task_id)
        )

    async def _register_task_id(self, task_id: str):
        await self._redis.sadd(self._task_registry_name, task_id)
        self._background_tasks[task_id] = asyncio.create_task(
            self._listen_and_relay(task_id)
        )

    async def _remove_task_id(self, task_id: str):
        if task_id in self._background_tasks:
            self._background_tasks[task_id].cancel(
                'task_id is closed: ' + task_id
            )
        return await self._redis.srem(self._task_registry_name, task_id)

    async def _subscribe_remote_task_events(self, task_id: str):
        await self._pubsub.subscribe(
            **{
                self._task_channel_name(task_id): partial(
                    self._relay_remote_events, task_id
                )
            }
        )
        if self._pubsub_reader_task is None:
            # Registered handlers only fire while a reader loop is
            # polling the pub/sub connection.
            self._pubsub_reader_task = asyncio.create_task(self._pubsub.run())

    async def _unsubscribe_remote_task_events(self, task_id: str):
        await self._pubsub.unsubscribe(self._task_channel_name(task_id))

    async def _relay_remote_events(self, task_id: str, message: dict):
        # The handler receives the pub/sub message dict; its `data` field
        # holds the JSON payload published by the producing process.
        if task_id in self._proxy_queue:
            event = TypeAdapter(Event).validate_json(message['data'])
            await self._proxy_queue[task_id].enqueue_event(event)

    async def add(self, task_id: str, queue: EventQueue) -> None:
        async with self._lock:
            if await self._has_task_id(task_id):
                raise TaskQueueExists()
            self._local_queue[task_id] = queue
            await self._register_task_id(task_id)

    async def get(self, task_id: str) -> EventQueue | None:
        async with self._lock:
            # look up locally
            if task_id in self._local_queue:
                return self._local_queue[task_id]
            # look up globally
            if await self._has_task_id(task_id):
                if task_id not in self._proxy_queue:
                    queue = EventQueue()
                    self._proxy_queue[task_id] = queue
                    await self._subscribe_remote_task_events(task_id)
                return self._proxy_queue[task_id]
            return None

    async def tap(self, task_id: str) -> EventQueue | None:
        event_queue = await self.get(task_id)
        if event_queue:
            return event_queue.tap()
        return None

    async def close(self, task_id: str) -> None:
        async with self._lock:
            if task_id in self._local_queue:
                # close locally
                queue = self._local_queue.pop(task_id)
                await queue.close()
                # remove from the global registry when a local queue is closed
                await self._remove_task_id(task_id)
                return

            if task_id in self._proxy_queue:
                # close the proxy queue
                queue = self._proxy_queue.pop(task_id)
                await queue.close()
                # unsubscribe from remote events, but leave the global
                # registry untouched: the producing process owns that entry
                await self._unsubscribe_remote_task_events(task_id)
                return

            raise NoTaskQueue()

    async def create_or_tap(self, task_id: str) -> EventQueue:
        async with self._lock:
            if await self._has_task_id(task_id):
                # if there is a local queue, tap it directly
                if task_id in self._local_queue:
                    return self._local_queue[task_id].tap()

                # if there is a proxy queue, tap the proxy
                if task_id in self._proxy_queue:
                    return self._proxy_queue[task_id].tap()

                # otherwise create the proxy queue and return it
                queue = EventQueue()
                self._proxy_queue[task_id] = queue
                await self._subscribe_remote_task_events(task_id)
                return self._proxy_queue[task_id]
            # the task does not exist yet, so create a local queue
            queue = EventQueue()
            self._local_queue[task_id] = queue
            await self._register_task_id(task_id)
            return queue
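For context, a minimal sketch of wiring the manager into a server process. This is not part of the diff: `my_executor` is a placeholder, and the `queue_manager` parameter of `DefaultRequestHandler` is assumed from recent SDK versions.

```python
import redis.asyncio as redis

from a2a.server.events.redis_queue_manager import RedisQueueManager
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore

# Each server process builds its own manager; processes coordinate
# through the shared registry set and relay channels in Redis.
client = redis.Redis(host='localhost', port=6379)

handler = DefaultRequestHandler(
    agent_executor=my_executor,  # placeholder: your AgentExecutor instance
    task_store=InMemoryTaskStore(),
    queue_manager=RedisQueueManager(client),
)
```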
126 changes: 126 additions & 0 deletions tests/server/events/test_redis_queue_manager.py
@@ -0,0 +1,126 @@
import asyncio

from unittest.mock import MagicMock

import pytest

from fakeredis import FakeAsyncRedis

from a2a.server.events import EventQueue, TaskQueueExists
from a2a.server.events.redis_queue_manager import RedisQueueManager


class TestRedisQueueManager:
    @pytest.fixture
    def redis(self):
        return FakeAsyncRedis()

    @pytest.fixture
    def queue_manager(self, redis):
        return RedisQueueManager(redis)

    @pytest.fixture
    def event_queue(self):
        queue = MagicMock(spec=EventQueue)
        # Mock the tap method to return itself
        queue.tap.return_value = queue
        return queue

    @pytest.mark.asyncio
    async def test_init(self, queue_manager):
        assert queue_manager._local_queue == {}
        assert queue_manager._proxy_queue == {}
        assert isinstance(queue_manager._lock, asyncio.Lock)

    @pytest.mark.asyncio
    async def test_add_new_queue(self, queue_manager, event_queue):
        """Test adding a new queue to the manager."""
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)
        assert queue_manager._local_queue[task_id] == event_queue

    @pytest.mark.asyncio
    async def test_add_existing_queue(self, queue_manager, event_queue):
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        with pytest.raises(TaskQueueExists):
            await queue_manager.add(task_id, event_queue)

    @pytest.mark.asyncio
    async def test_get_existing_queue(self, queue_manager, event_queue):
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        result = await queue_manager.get(task_id)
        assert result == event_queue

    @pytest.mark.asyncio
    async def test_get_nonexistent_queue(self, queue_manager):
        result = await queue_manager.get('nonexistent_task_id')
        assert result is None

    @pytest.mark.asyncio
    async def test_tap_existing_queue(self, queue_manager, event_queue):
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        result = await queue_manager.tap(task_id)
        assert result == event_queue
        event_queue.tap.assert_called_once()

    @pytest.mark.asyncio
    async def test_tap_nonexistent_queue(self, queue_manager):
        result = await queue_manager.tap('nonexistent_task_id')
        assert result is None

    @pytest.mark.asyncio
    async def test_close_existing_queue(self, queue_manager, event_queue):
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        await queue_manager.close(task_id)
        assert task_id not in queue_manager._local_queue

    @pytest.mark.asyncio
    async def test_create_or_tap_existing_queue(
        self, queue_manager, event_queue
    ):
        task_id = 'test_task_id'
        await queue_manager.add(task_id, event_queue)

        result = await queue_manager.create_or_tap(task_id)

        assert result == event_queue
        event_queue.tap.assert_called_once()

    @pytest.mark.asyncio
    async def test_concurrency(self, queue_manager):
        async def add_task(task_id):
            queue = EventQueue()
            await queue_manager.add(task_id, queue)
            return task_id

        async def get_task(task_id):
            return await queue_manager.get(task_id)

        # Create 10 different task IDs
        task_ids = [f'task_{i}' for i in range(10)]

        # Add tasks concurrently
        add_tasks = [add_task(task_id) for task_id in task_ids]
        added_task_ids = await asyncio.gather(*add_tasks)

        # Verify all tasks were added
        assert set(added_task_ids) == set(task_ids)

        # Get tasks concurrently
        get_tasks = [get_task(task_id) for task_id in task_ids]
        queues = await asyncio.gather(*get_tasks)

        # Verify all queues are not None
        assert all(queue is not None for queue in queues)

        # Verify all tasks are in the manager
        for task_id in task_ids:
            assert task_id in queue_manager._local_queue
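The suite exercises a single manager; the cross-process path could be sketched with two managers sharing one fakeredis store (a hypothetical test, not part of this diff):

```python
@pytest.mark.asyncio
async def test_cross_manager_proxy_sketch():
    # Two managers over one store, standing in for two server processes.
    backing_store = FakeAsyncRedis()
    producer = RedisQueueManager(backing_store)
    consumer = RedisQueueManager(backing_store)

    queue = EventQueue()
    await producer.add('task_1', queue)

    # The consumer finds the task in the shared registry and returns a
    # proxy queue fed over pub/sub, not the producer's local queue.
    proxy = await consumer.get('task_1')
    assert proxy is not None
    assert proxy is not queue
```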