@@ -1,7 +1,6 @@
 import asyncio
 import logging
-import random
-import time
+import uuid
 
 from asyncio import Task
 from typing import Any
@@ -22,8 +21,6 @@
 logger = logging.getLogger(__name__)
 
 
-CLEAN_EXPIRED_PROBABILITY = 0.5
-
 class RedisQueueManager(QueueManager):
     """This implements the `QueueManager` interface using Redis for event.
 
@@ -33,6 +30,8 @@ class RedisQueueManager(QueueManager):
         redis_client(Redis): asyncio redis connection.
         relay_channel_key_prefix(str): prefix for pubsub channel key generation.
         task_registry_key(str): key for set data where stores active `task_id`s.
+        task_id_ttl_in_second(int): TTL in seconds for a task id in the global registry.
+        node_id(str): A unique id associated with this manager's task ids in the global registry. If the node id does not match, events won't be propagated to queues in other `RedisQueueManager`s.
     """
 
     def __init__(
@@ -41,6 +40,7 @@ def __init__(
         relay_channel_key_prefix: str = 'a2a.event.relay.',
         task_registry_key: str = 'a2a.event.registry',
         task_id_ttl_in_second: int = 60 * 60 * 24,
+        node_id: str = str(uuid.uuid4()),
     ):
         self._redis = redis_client
         self._local_queue: dict[str, EventQueue] = {}
@@ -52,31 +52,60 @@ def __init__(
         self._task_registry_name = task_registry_key
         self._pubsub_listener_task: Task | None = None
         self._task_id_ttl_in_second = task_id_ttl_in_second
+        self._node_id = node_id
 
     def _task_channel_name(self, task_id: str) -> str:
         return self._relay_channel_name + task_id
 
     async def _has_task_id(self, task_id: str) -> bool:
-        ret = await self._redis.zscore(self._task_registry_name, task_id)
+        ret = await self._redis.hget(self._task_registry_name, task_id)
         return ret is not None
 
     async def _register_task_id(self, task_id: str) -> None:
-        assert await self._redis.zadd(self._task_registry_name, {task_id: time.time()}, nx=True), 'task_id should not exist in global registry: ' + task_id
+        await self._redis.hsetex(
+            name=self._task_registry_name,
+            key=task_id,
+            value=self._node_id,
+            ex=self._task_id_ttl_in_second,
+        )
+        logger.debug(
+            f'Registered task_id {task_id} to node {self._node_id} in registry.'
+        )
         task_started_event = asyncio.Event()
+
         async def _wrapped_listen_and_relay() -> None:
             task_started_event.set()
             c = EventConsumer(self._local_queue[task_id].tap())
             async for event in c.consume_all():
-                logger.debug(f'Publishing event for task {task_id} in QM {self}: {event}')
-                await self._redis.publish(
-                    self._task_channel_name(task_id),
-                    event.model_dump_json(exclude_none=True),
+                logger.debug(
+                    f'Publishing event for task {task_id} in QM {self}: {event}'
                 )
-                # update TTL for task_id
-                await self._update_task_id_ttl(task_id)
-                # clean expired task_ids with certain possibility
-                if random.random() < CLEAN_EXPIRED_PROBABILITY:
-                    await self._clean_expired_task_ids()
+                expected_node_id = await self._redis.hget(
+                    self._task_registry_name, task_id
+                )
+                expected_node_id = (
+                    expected_node_id.decode('utf-8')
+                    if hasattr(expected_node_id, 'decode')
+                    else expected_node_id
+                )
+                if expected_node_id == self._node_id:
+                    # publish message
+                    await self._redis.publish(
+                        self._task_channel_name(task_id),
+                        event.model_dump_json(exclude_none=True),
+                    )
+                    # update TTL for task_id
+                    await self._redis.hsetex(
+                        name=self._task_registry_name,
+                        key=task_id,
+                        value=self._node_id,
+                        ex=self._task_id_ttl_in_second,
+                    )
+                else:
+                    logger.error(
+                        f'Task {task_id} is not registered on this node. Expected node id: {expected_node_id}'
+                    )
+                    break
 
         self._background_tasks[task_id] = asyncio.create_task(
             _wrapped_listen_and_relay()
@@ -89,55 +118,56 @@ async def _remove_task_id(self, task_id: str) -> bool:
             self._background_tasks[task_id].cancel(
                 'task_id is closed: ' + task_id
             )
-        return await self._redis.zrem(self._task_registry_name, task_id) == 1
-
-    async def _update_task_id_ttl(self, task_id: str) -> bool:
-        ret = await self._redis.zadd(
-            self._task_registry_name,
-            {task_id: time.time()},
-            xx=True
-        )
-        return ret is not None
-
-    async def _clean_expired_task_ids(self) -> None:
-        count = await self._redis.zremrangebyscore(self._task_registry_name, 0, time.time() - self._task_id_ttl_in_second)
-        logger.debug(f'Removed {count} expired task ids')
+        return await self._redis.hdel(self._task_registry_name, task_id) == 1
 
     async def _subscribe_remote_task_events(self, task_id: str) -> None:
         channel_id = self._task_channel_name(task_id)
         await self._pubsub.subscribe(**{channel_id: self._relay_remote_events})
-
         # this is a global listener to handle incoming pubsub events
         if not self._pubsub_listener_task:
             logger.debug('Creating pubsub listener task.')
-            self._pubsub_listener_task = asyncio.create_task(self._consume_pubsub_messages())
-
-        logger.debug(f"Subscribed for remote events for task {task_id}")
+            self._pubsub_listener_task = asyncio.create_task(
+                self._consume_pubsub_messages()
+            )
+        logger.debug(f'Subscribed for remote events for task {task_id}')
 
     async def _consume_pubsub_messages(self) -> None:
         async for _ in self._pubsub.listen():
             pass
 
-    async def _relay_remote_events(self, subscription_event: dict[str, Any]) -> None:
-        if 'channel' not in subscription_event or 'data' not in subscription_event:
-            logger.warning(f"channel or data is absent in subscription event: {subscription_event}")
+    async def _relay_remote_events(
+        self, subscription_event: dict[str, Any]
+    ) -> None:
+        if (
+            'channel' not in subscription_event
+            or 'data' not in subscription_event
+        ):
+            logger.warning(
+                f'channel or data is absent in subscription event: {subscription_event}'
+            )
             return
 
         channel_id: str = subscription_event['channel'].decode('utf-8')
         data_string: str = subscription_event['data'].decode('utf-8')
         task_id = channel_id.split('.')[-1]
         if task_id not in self._proxy_queue:
-            logger.warning(f" task_id {task_id} not found in proxy queue")
+            logger.warning(f' task_id {task_id} not found in proxy queue')
             return
 
         try:
-            logger.debug(f"Received event for task_id {task_id} in QM {self}: {data_string}")
+            logger.debug(
+                f'Received event for task_id {task_id} in QM {self}: {data_string}'
+            )
             event: Event = TypeAdapter(Event).validate_json(data_string)
         except Exception as e:
-            logger.warning(f"Failed to parse event from subscription event: {subscription_event}: {e}")
+            logger.warning(
+                f'Failed to parse event from subscription event: {subscription_event}: {e}'
+            )
             return
 
-        logger.debug(f"Enqueuing event for task_id {task_id} in QM {self}: {event}")
+        logger.debug(
+            f'Enqueuing event for task_id {task_id} in QM {self}: {event}'
+        )
         await self._proxy_queue[task_id].enqueue_event(event)
 
     async def _unsubscribe_remote_task_events(self, task_id: str) -> None:
@@ -148,7 +178,6 @@ async def _unsubscribe_remote_task_events(self, task_id: str) -> None:
             self._pubsub_listener_task.cancel()
             self._pubsub_listener_task = None
 
-
     async def add(self, task_id: str, queue: EventQueue) -> None:
         """Add a new local event queue for the specified task.
 
@@ -159,13 +188,13 @@ async def add(self, task_id: str, queue: EventQueue) -> None:
         Raises:
             TaskQueueExists: If a queue for the task already exists.
         """
-        logger.debug(f" add {task_id}")
+        logger.debug(f' add {task_id}')
         async with self._lock:
             if await self._has_task_id(task_id):
                 raise TaskQueueExists()
             self._local_queue[task_id] = queue
             await self._register_task_id(task_id)
-            logger.debug(f" Local queue is created for task {task_id}")
+            logger.debug(f' Local queue is created for task {task_id}')
 
     async def get(self, task_id: str) -> EventQueue | None:
         """Get the event queue associated with the given task ID.
@@ -180,22 +209,24 @@ async def get(self, task_id: str) -> EventQueue | None:
         Returns:
             EventQueue | None: The event queue if found, otherwise None.
         """
-        logger.debug(f" get {task_id}")
+        logger.debug(f' get {task_id}')
         async with self._lock:
             # lookup locally
             if task_id in self._local_queue:
-                logger.debug(f" Got local queue for task_id {task_id}")
+                logger.debug(f' Got local queue for task_id {task_id}')
                 return self._local_queue[task_id]
             # lookup globally
             if await self._has_task_id(task_id):
                 if task_id not in self._proxy_queue:
-                    logger.debug(f" Creating proxy queue for {task_id}")
+                    logger.debug(f' Creating proxy queue for {task_id}')
                     queue = EventQueue()
                     self._proxy_queue[task_id] = queue
                     await self._subscribe_remote_task_events(task_id)
-                logger.debug(f" Got proxy queue for task_id {task_id}")
+                logger.debug(f' Got proxy queue for task_id {task_id}')
                 return self._proxy_queue[task_id]
-            logger.warning(f"Attempted to get non-existing queue for task {task_id}")
+            logger.warning(
+                f'Attempted to get non-existing queue for task {task_id}'
+            )
             return None
 
     async def tap(self, task_id: str) -> EventQueue | None:
@@ -207,7 +238,7 @@ async def tap(self, task_id: str) -> EventQueue | None:
         Returns:
             EventQueue | None: A new reference to the event queue if it exists, otherwise None.
         """
-        logger.debug(f" tap {task_id}")
+        logger.debug(f' tap {task_id}')
         event_queue = await self.get(task_id)
         if event_queue:
             logger.debug(f'Tapping event queue for task: {task_id}')
@@ -227,15 +258,15 @@ async def close(self, task_id: str) -> None:
         Raises:
             NoTaskQueue: If no queue exists for the given task ID.
         """
-        logger.debug(f" close {task_id}")
+        logger.debug(f' close {task_id}')
        async with self._lock:
             if task_id in self._local_queue:
                 # remove from global registry if a local queue is closed
                 await self._remove_task_id(task_id)
                 # close locally
                 queue = self._local_queue.pop(task_id)
                 await queue.close()
-                logger.debug(f" Closing local queue for task {task_id}")
+                logger.debug(f' Closing local queue for task {task_id}')
                 return
 
             if task_id in self._proxy_queue:
@@ -244,10 +275,12 @@ async def close(self, task_id: str) -> None:
                 await queue.close()
                 # unsubscribe from remote, but don't remove from global registry
                 await self._unsubscribe_remote_task_events(task_id)
-                logger.debug(f" Closing proxy queue for task {task_id}")
+                logger.debug(f' Closing proxy queue for task {task_id}')
                 return
 
-            logger.warning(f"Attempted to close non-existing queue found for task {task_id}")
+            logger.warning(
+                f'Attempted to close non-existing queue found for task {task_id}'
+            )
             raise NoTaskQueue()
 
     async def create_or_tap(self, task_id: str) -> EventQueue:
@@ -262,28 +295,25 @@ async def create_or_tap(self, task_id: str) -> EventQueue:
         Returns:
             EventQueue: An event queue associated with the given task ID.
         """
-        logger.debug(f" create_or_tap {task_id}")
+        logger.debug(f' create_or_tap {task_id}')
         async with self._lock:
             if await self._has_task_id(task_id):
                 # if it's a local queue, tap directly
                 if task_id in self._local_queue:
-                    logger.debug(f" Tapping a local queue for task {task_id}")
+                    logger.debug(f' Tapping a local queue for task {task_id}')
                     return self._local_queue[task_id].tap()
 
                 # if it's a proxy queue, tap the proxy
-                if task_id in self._proxy_queue:
-                    logger.debug(f"Tapping a proxy queue for task {task_id}")
-                    return self._proxy_queue[task_id].tap()
-
-                # if the proxy is not created, create the proxy and return
-                queue = EventQueue()
-                self._proxy_queue[task_id] = queue
-                await self._subscribe_remote_task_events(task_id)
-                logger.debug(f"Creating a proxy queue for task {task_id}")
-                return self._proxy_queue[task_id]
+                if task_id not in self._proxy_queue:
+                    # if the proxy is not created, create the proxy
+                    queue = EventQueue()
+                    self._proxy_queue[task_id] = queue
+                    await self._subscribe_remote_task_events(task_id)
+                logger.debug(f'Tapping a proxy queue for task {task_id}')
+                return self._proxy_queue[task_id].tap()
             # the task doesn't exist before, create a local queue
             queue = EventQueue()
             self._local_queue[task_id] = queue
             await self._register_task_id(task_id)
-            logger.debug(f" Creating a local queue for task {task_id}")
+            logger.debug(f' Creating a local queue for task {task_id}')
             return queue
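
For context, here is a minimal usage sketch of the node-aware registry behaviour introduced in this change. It is illustrative only and not part of the diff: the module path for RedisQueueManager, the EventQueue import path, the Redis URL, the node ids, and the task id are assumptions; RedisQueueManager, add(), get(), and close() are the APIs shown above.

# Illustrative sketch only, not part of this change. Import paths, the
# Redis URL, node ids, and the task id below are placeholder assumptions.
import asyncio

from redis.asyncio import Redis

from a2a.server.events import EventQueue  # assumed import path

from my_app.redis_queue_manager import RedisQueueManager  # assumed import path


async def main() -> None:
    redis_client = Redis.from_url('redis://localhost:6379/0')

    # Two managers sharing one Redis, e.g. running on two different nodes.
    # Each gets its own node_id, so only the node that owns a task_id in
    # the registry relays that task's events over pubsub.
    producer_qm = RedisQueueManager(redis_client, node_id='node-a')
    consumer_qm = RedisQueueManager(redis_client, node_id='node-b')

    task_id = 'task-123'

    # Node A creates the local queue; add() maps task_id -> 'node-a' in the
    # hash registry with the configured TTL.
    await producer_qm.add(task_id, EventQueue())

    # Node B finds task_id in the registry but has no local queue, so get()
    # returns a proxy queue that is fed by the pubsub relay.
    proxy_queue = await consumer_qm.get(task_id)
    assert proxy_queue is not None

    # Closing on the owning node removes task_id from the registry and
    # closes the local queue.
    await producer_qm.close(task_id)

    await redis_client.aclose()


if __name__ == '__main__':
    asyncio.run(main())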