Skip to content

Commit 84c6bef

Browse files
author
long.qul
committed
feat(server): implement TTL for task IDs in Redis
- Replace Redis set with sorted set to store task IDs with timestamp - Add TTL update mechanism for active task IDs - Implement periodic cleanup of expired task IDs - Update task registration and removal to use new sorted set structure
1 parent f75bf8b commit 84c6bef

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

.github/actions/spelling/allow.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,8 @@ typeerror
7474
vulnz
7575
sadd
7676
sismember
77-
srem
77+
srem
78+
zadd
79+
zscore
80+
zrem
81+
zremrangebyscore

src/a2a/server/events/redis_queue_manager.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import asyncio
22
import logging
3+
import random
4+
import time
35

46
from asyncio import Task
5-
from functools import partial
6-
from typing import Any, Dict, Optional
7+
from typing import Any
78

8-
from pydantic import ValidationError, TypeAdapter
9+
from pydantic import TypeAdapter
910
from redis.asyncio import Redis
1011

1112
from a2a.server.events import (
@@ -17,9 +18,12 @@
1718
TaskQueueExists,
1819
)
1920

21+
2022
logger = logging.getLogger(__name__)
2123

2224

25+
CLEAN_EXPIRED_PROBABILITY = 0.5
26+
2327
class RedisQueueManager(QueueManager):
2428
"""This implements the `QueueManager` interface using Redis for event.
2529
@@ -36,6 +40,7 @@ def __init__(
3640
redis_client: Redis,
3741
relay_channel_key_prefix: str = 'a2a.event.relay.',
3842
task_registry_key: str = 'a2a.event.registry',
43+
task_id_ttl_in_second: int = 60 * 60 * 24,
3944
):
4045
self._redis = redis_client
4146
self._local_queue: dict[str, EventQueue] = {}
@@ -45,16 +50,18 @@ def __init__(
4550
self._relay_channel_name = relay_channel_key_prefix
4651
self._background_tasks: dict[str, Task] = {}
4752
self._task_registry_name = task_registry_key
48-
self._pubsub_listener_task: Optional[Task] = None
53+
self._pubsub_listener_task: Task | None = None
54+
self._task_id_ttl_in_second = task_id_ttl_in_second
4955

5056
def _task_channel_name(self, task_id: str) -> str:
5157
return self._relay_channel_name + task_id
5258

5359
async def _has_task_id(self, task_id: str) -> bool:
54-
ret = await self._redis.sismember(self._task_registry_name, task_id)
55-
return ret == 1
60+
ret = await self._redis.zscore(self._task_registry_name, task_id)
61+
return ret is not None
5662

5763
async def _register_task_id(self, task_id: str) -> None:
64+
assert await self._redis.zadd(self._task_registry_name, {task_id: time.time()}, nx=True), 'task_id should not exist in global registry: ' + task_id
5865
task_started_event = asyncio.Event()
5966
async def _wrapped_listen_and_relay() -> None:
6067
task_started_event.set()
@@ -65,20 +72,36 @@ async def _wrapped_listen_and_relay() -> None:
6572
self._task_channel_name(task_id),
6673
event.model_dump_json(exclude_none=True),
6774
)
75+
# update TTL for task_id
76+
await self._update_task_id_ttl(task_id)
77+
# clean expired task_ids with certain possibility
78+
if random.random() < CLEAN_EXPIRED_PROBABILITY:
79+
await self._clean_expired_task_ids()
6880

6981
self._background_tasks[task_id] = asyncio.create_task(
7082
_wrapped_listen_and_relay()
7183
)
7284
await task_started_event.wait()
73-
await self._redis.sadd(self._task_registry_name, task_id)
7485
logger.debug(f'Started to listen and relay events for task {task_id}')
7586

7687
async def _remove_task_id(self, task_id: str) -> bool:
7788
if task_id in self._background_tasks:
7889
self._background_tasks[task_id].cancel(
7990
'task_id is closed: ' + task_id
8091
)
81-
return await self._redis.srem(self._task_registry_name, task_id) == 1
92+
return await self._redis.zrem(self._task_registry_name, task_id) == 1
93+
94+
async def _update_task_id_ttl(self, task_id: str) -> bool:
95+
ret = await self._redis.zadd(
96+
self._task_registry_name,
97+
{task_id: time.time()},
98+
xx=True
99+
)
100+
return ret is not None
101+
102+
async def _clean_expired_task_ids(self) -> None:
103+
count = await self._redis.zremrangebyscore(self._task_registry_name, 0, time.time() - self._task_id_ttl_in_second)
104+
logger.debug(f'Removed {count} expired task ids')
82105

83106
async def _subscribe_remote_task_events(self, task_id: str) -> None:
84107
channel_id = self._task_channel_name(task_id)
@@ -91,11 +114,11 @@ async def _subscribe_remote_task_events(self, task_id: str) -> None:
91114

92115
logger.debug(f"Subscribed for remote events for task {task_id}")
93116

94-
async def _consume_pubsub_messages(self):
117+
async def _consume_pubsub_messages(self) -> None:
95118
async for _ in self._pubsub.listen():
96119
pass
97120

98-
async def _relay_remote_events(self, subscription_event) -> None:
121+
async def _relay_remote_events(self, subscription_event: dict[str, Any]) -> None:
99122
if 'channel' not in subscription_event or 'data' not in subscription_event:
100123
logger.warning(f"channel or data is absent in subscription event: {subscription_event}")
101124
return

0 commit comments

Comments
 (0)