11import asyncio
22import logging
3+ import random
4+ import time
35
46from asyncio import Task
5- from functools import partial
6- from typing import Any , Dict , Optional
7+ from typing import Any
78
8- from pydantic import ValidationError , TypeAdapter
9+ from pydantic import TypeAdapter
910from redis .asyncio import Redis
1011
1112from a2a .server .events import (
1718 TaskQueueExists ,
1819)
1920
21+
2022logger = logging .getLogger (__name__ )
2123
2224
25+ CLEAN_EXPIRED_PROBABILITY = 0.5
26+
2327class RedisQueueManager (QueueManager ):
2428 """This implements the `QueueManager` interface using Redis for event.
2529
@@ -36,6 +40,7 @@ def __init__(
3640 redis_client : Redis ,
3741 relay_channel_key_prefix : str = 'a2a.event.relay.' ,
3842 task_registry_key : str = 'a2a.event.registry' ,
43+ task_id_ttl_in_second : int = 60 * 60 * 24 ,
3944 ):
4045 self ._redis = redis_client
4146 self ._local_queue : dict [str , EventQueue ] = {}
@@ -45,16 +50,18 @@ def __init__(
4550 self ._relay_channel_name = relay_channel_key_prefix
4651 self ._background_tasks : dict [str , Task ] = {}
4752 self ._task_registry_name = task_registry_key
48- self ._pubsub_listener_task : Optional [Task ] = None
53+ self ._pubsub_listener_task : Task | None = None
54+ self ._task_id_ttl_in_second = task_id_ttl_in_second
4955
5056 def _task_channel_name (self , task_id : str ) -> str :
5157 return self ._relay_channel_name + task_id
5258
5359 async def _has_task_id (self , task_id : str ) -> bool :
54- ret = await self ._redis .sismember (self ._task_registry_name , task_id )
55- return ret == 1
60+ ret = await self ._redis .zscore (self ._task_registry_name , task_id )
61+ return ret is not None
5662
5763 async def _register_task_id (self , task_id : str ) -> None :
64+ assert await self ._redis .zadd (self ._task_registry_name , {task_id : time .time ()}, nx = True ), 'task_id should not exist in global registry: ' + task_id
5865 task_started_event = asyncio .Event ()
5966 async def _wrapped_listen_and_relay () -> None :
6067 task_started_event .set ()
@@ -65,20 +72,36 @@ async def _wrapped_listen_and_relay() -> None:
6572 self ._task_channel_name (task_id ),
6673 event .model_dump_json (exclude_none = True ),
6774 )
75+ # update TTL for task_id
76+ await self ._update_task_id_ttl (task_id )
77+ # clean expired task_ids with certain possibility
78+ if random .random () < CLEAN_EXPIRED_PROBABILITY :
79+ await self ._clean_expired_task_ids ()
6880
6981 self ._background_tasks [task_id ] = asyncio .create_task (
7082 _wrapped_listen_and_relay ()
7183 )
7284 await task_started_event .wait ()
73- await self ._redis .sadd (self ._task_registry_name , task_id )
7485 logger .debug (f'Started to listen and relay events for task { task_id } ' )
7586
7687 async def _remove_task_id (self , task_id : str ) -> bool :
7788 if task_id in self ._background_tasks :
7889 self ._background_tasks [task_id ].cancel (
7990 'task_id is closed: ' + task_id
8091 )
81- return await self ._redis .srem (self ._task_registry_name , task_id ) == 1
92+ return await self ._redis .zrem (self ._task_registry_name , task_id ) == 1
93+
94+ async def _update_task_id_ttl (self , task_id : str ) -> bool :
95+ ret = await self ._redis .zadd (
96+ self ._task_registry_name ,
97+ {task_id : time .time ()},
98+ xx = True
99+ )
100+ return ret is not None
101+
102+ async def _clean_expired_task_ids (self ) -> None :
103+ count = await self ._redis .zremrangebyscore (self ._task_registry_name , 0 , time .time () - self ._task_id_ttl_in_second )
104+ logger .debug (f'Removed { count } expired task ids' )
82105
83106 async def _subscribe_remote_task_events (self , task_id : str ) -> None :
84107 channel_id = self ._task_channel_name (task_id )
@@ -91,11 +114,11 @@ async def _subscribe_remote_task_events(self, task_id: str) -> None:
91114
92115 logger .debug (f"Subscribed for remote events for task { task_id } " )
93116
94- async def _consume_pubsub_messages (self ):
117+ async def _consume_pubsub_messages (self ) -> None :
95118 async for _ in self ._pubsub .listen ():
96119 pass
97120
98- async def _relay_remote_events (self , subscription_event ) -> None :
121+ async def _relay_remote_events (self , subscription_event : dict [ str , Any ] ) -> None :
99122 if 'channel' not in subscription_event or 'data' not in subscription_event :
100123 logger .warning (f"channel or data is absent in subscription event: { subscription_event } " )
101124 return
0 commit comments