|
| 1 | +import asyncio |
| 2 | +from asyncio import Task |
| 3 | +from functools import partial |
| 4 | +from typing import Dict |
| 5 | + |
| 6 | +from redis.asyncio import Redis |
| 7 | + |
| 8 | +from a2a.server.events import QueueManager, EventQueue, TaskQueueExists, Event, EventConsumer, NoTaskQueue |
| 9 | + |
| 10 | + |
class RedisQueueManager(QueueManager):
    """Redis-backed implementation of the `QueueManager` interface.

    Primary jobs:
      1. Broadcast events from local queues to other processes via Redis
         pub/sub (one channel per task id).
      2. Subscribe to remote task events from Redis pub/sub and replay them
         into local proxy queues.

    A task id is "global" when it is a member of the Redis registry set.  The
    process owning the real (local) queue relays its events outward; every
    other process creates a proxy `EventQueue` fed by the pub/sub
    subscription.

    NOTE(review): the registry check-then-register in `add`/`create_or_tap`
    is guarded only by a process-local lock; two processes racing on the same
    task id could both register — confirm whether an atomic Redis op (e.g.
    SADD return value) is needed here.
    """

    def __init__(
        self,
        redis_client: Redis,
        relay_channel_key_prefix: str = "a2a.event.relay.",
        task_registry_key: str = "a2a.event.registry",
    ) -> None:
        """Initialize the manager.

        Args:
            redis_client: shared `redis.asyncio.Redis` client.
            relay_channel_key_prefix: prefix for per-task pub/sub channel
                names; the task id is appended to it.
            task_registry_key: name of the Redis set holding all live task
                ids across processes.
        """
        self._redis = redis_client
        self._local_queue: dict[str, EventQueue] = {}
        self._proxy_queue: dict[str, EventQueue] = {}
        # Guards all queue/registry mutations within this process.
        self._lock = asyncio.Lock()
        self._pubsub = redis_client.pubsub()
        self._relay_channel_name = relay_channel_key_prefix
        self._background_tasks: dict[str, Task] = {}
        self._task_registry_name = task_registry_key
        # Reader task that drains the pub/sub connection and dispatches
        # messages to registered handlers; started lazily on the first
        # remote subscription.  Without it, channel handlers never fire.
        self._pubsub_reader_task: Task | None = None

    def _task_channel_name(self, task_id: str) -> str:
        """Return the pub/sub channel name carrying events for `task_id`."""
        return self._relay_channel_name + task_id

    async def _has_task_id(self, task_id: str) -> bool:
        """Return True if `task_id` is registered globally (in any process)."""
        return bool(
            await self._redis.sismember(self._task_registry_name, task_id)
        )

    async def _register_task_id(self, task_id: str) -> None:
        """Register `task_id` globally and start relaying its local events."""
        await self._redis.sadd(self._task_registry_name, task_id)
        self._background_tasks[task_id] = asyncio.create_task(
            self._listen_and_relay(task_id)
        )

    async def _listen_and_relay(self, task_id: str) -> None:
        """Publish every event of the local queue to the task's Redis channel.

        Runs until the local queue is closed or the task is cancelled.
        """
        consumer = EventConsumer(self._local_queue[task_id])
        async for event in consumer.consume_all():
            await self._redis.publish(
                self._task_channel_name(task_id),
                event.model_dump_json(exclude_none=True),
            )

    async def _remove_task_id(self, task_id: str):
        """Stop the relay for `task_id` and drop it from the global registry."""
        # pop() rather than a plain lookup: the cancelled Task must not be
        # leaked in `_background_tasks` after the queue is closed.
        relay_task = self._background_tasks.pop(task_id, None)
        if relay_task is not None:
            relay_task.cancel("task_id is closed: " + task_id)
        return await self._redis.srem(self._task_registry_name, task_id)

    async def _subscribe_remote_task_events(self, task_id: str) -> None:
        """Subscribe to the task's channel, feeding events to its proxy queue."""
        await self._pubsub.subscribe(
            **{
                self._task_channel_name(task_id): partial(
                    self._relay_remote_events, task_id
                )
            }
        )
        # Handlers are only invoked while something drains the pub/sub
        # connection — start the reader loop once, lazily.
        if self._pubsub_reader_task is None:
            self._pubsub_reader_task = asyncio.create_task(self._pubsub.run())

    async def _unsubscribe_remote_task_events(self, task_id: str) -> None:
        """Unsubscribe from the task's channel (caller handles the proxy queue).

        `redis.asyncio` `PubSub.unsubscribe` is a coroutine and must be
        awaited, otherwise the unsubscribe silently never happens.
        """
        await self._pubsub.unsubscribe(self._task_channel_name(task_id))

    async def _relay_remote_events(self, task_id: str, message: dict) -> None:
        """Pub/sub handler: replay one remote event into the local proxy queue.

        redis-py invokes channel handlers with the full message dict, so the
        serialized event payload lives in ``message['data']`` (possibly as
        bytes, depending on the client's decode_responses setting).
        """
        if task_id not in self._proxy_queue:
            return
        payload = message["data"]
        if isinstance(payload, bytes):
            payload = payload.decode("utf-8")
        # NOTE(review): assumes `Event` exposes pydantic's
        # model_validate_json; if `Event` is a union type alias, a
        # `pydantic.TypeAdapter(Event)` is required instead — confirm.
        event = Event.model_validate_json(payload)
        # enqueue_event is async; it must be awaited or the event is dropped.
        await self._proxy_queue[task_id].enqueue_event(event)

    async def add(self, task_id: str, queue: EventQueue) -> None:
        """Attach a local queue for `task_id`.

        Raises:
            TaskQueueExists: if the task id is already registered anywhere.
        """
        async with self._lock:
            if await self._has_task_id(task_id):
                raise TaskQueueExists()
            self._local_queue[task_id] = queue
            await self._register_task_id(task_id)

    async def get(self, task_id: str) -> EventQueue | None:
        """Return the queue for `task_id`, or None if the task is unknown.

        Prefers the local queue; for a task owned by another process a proxy
        queue is created (once) and wired to the Redis channel.
        """
        async with self._lock:
            # Lookup locally first.
            if task_id in self._local_queue:
                return self._local_queue[task_id]
            # Then check the global registry.
            if await self._has_task_id(task_id):
                if task_id not in self._proxy_queue:
                    queue = EventQueue()
                    self._proxy_queue[task_id] = queue
                    await self._subscribe_remote_task_events(task_id)
                return self._proxy_queue[task_id]
            return None

    async def tap(self, task_id: str) -> EventQueue | None:
        """Return a tapped child of the task's queue, or None if unknown."""
        event_queue = await self.get(task_id)
        if event_queue:
            return event_queue.tap()
        return None

    async def close(self, task_id: str) -> None:
        """Close the local or proxy queue for `task_id`.

        Closing a local queue also removes the task from the global
        registry; closing a proxy queue only drops the local subscription.

        Raises:
            NoTaskQueue: if no queue for the task exists in this process.
        """
        async with self._lock:
            if task_id in self._local_queue:
                # Close locally and retire the task globally.
                queue = self._local_queue.pop(task_id)
                await queue.close()
                await self._remove_task_id(task_id)
                return None

            if task_id in self._proxy_queue:
                # Close the proxy and unsubscribe from the remote channel,
                # but keep the task in the global registry (owner is remote).
                queue = self._proxy_queue.pop(task_id)
                await queue.close()
                await self._unsubscribe_remote_task_events(task_id)
                return None

            raise NoTaskQueue()

    async def create_or_tap(self, task_id: str) -> EventQueue:
        """Return a queue for `task_id`, creating local or proxy as needed.

        Existing local/proxy queues are tapped; a brand-new task gets a
        fresh local queue (returned untapped, as its first consumer).
        """
        async with self._lock:
            if await self._has_task_id(task_id):
                # Local queue: tap it directly.
                if task_id in self._local_queue:
                    return self._local_queue[task_id].tap()

                # Existing proxy queue: tap the proxy.
                if task_id in self._proxy_queue:
                    return self._proxy_queue[task_id].tap()

                # No proxy yet: create one, wire it to Redis, return it.
                queue = EventQueue()
                self._proxy_queue[task_id] = queue
                await self._subscribe_remote_task_events(task_id)
                return self._proxy_queue[task_id]
            else:
                # Unknown task: create and register a local queue.
                queue = EventQueue()
                self._local_queue[task_id] = queue
                await self._register_task_id(task_id)
                return queue
0 commit comments