11import asyncio
2+ import logging
23
34from asyncio import Task
45from functools import partial
6+ from typing import Any , Dict , Optional
57
8+ from pydantic import ValidationError , TypeAdapter
69from redis .asyncio import Redis
710
811from a2a .server .events import (
1417 TaskQueueExists ,
1518)
1619
20+ logger = logging .getLogger (__name__ )
21+
1722
1823class RedisQueueManager (QueueManager ):
1924 """This implements the `QueueManager` interface using Redis for event.
@@ -40,14 +45,7 @@ def __init__(
4045 self ._relay_channel_name = relay_channel_key_prefix
4146 self ._background_tasks : dict [str , Task ] = {}
4247 self ._task_registry_name = task_registry_key
43-
44- async def _listen_and_relay (self , task_id : str ) -> None :
45- c = EventConsumer (self ._local_queue [task_id ])
46- async for event in c .consume_all ():
47- await self ._redis .publish (
48- self ._task_channel_name (task_id ),
49- event .model_dump_json (exclude_none = True ),
50- )
48+ self ._pubsub_listener_task : Optional [Task ] = None
5149
def _task_channel_name(self, task_id: str) -> str:
    """Return the Redis pub/sub channel name used for ``task_id``.

    The name is the configured relay channel prefix immediately followed by
    the task id; no separator is inserted between the two.

    Args:
        task_id: Identifier of the task whose channel name is wanted.

    Returns:
        The channel name for the task.
    """
    prefix = self._relay_channel_name
    return prefix + task_id
@@ -57,10 +55,23 @@ async def _has_task_id(self, task_id: str) -> bool:
5755 return ret == 1
5856
async def _register_task_id(self, task_id: str) -> None:
    """Start relaying a local task's events to Redis, then register the task.

    A background task taps the local queue for ``task_id`` and publishes every
    consumed event, serialized as JSON, on the task's channel. The task id is
    added to the Redis registry set only after that background task has
    started running.

    Args:
        task_id: Identifier of a task that already has an entry in
            ``self._local_queue``.
    """
    relay_ready = asyncio.Event()

    async def _relay_local_events() -> None:
        # Signal start-up first. The tap below runs without an intervening
        # await, so it is in place before the waiter resumes.
        relay_ready.set()
        tapped_queue = self._local_queue[task_id].tap()
        consumer = EventConsumer(tapped_queue)
        async for event in consumer.consume_all():
            logger.debug(f'Publishing event for task { task_id } in QM { self } : { event } ')
            channel = self._task_channel_name(task_id)
            payload = event.model_dump_json(exclude_none=True)
            await self._redis.publish(channel, payload)

    relay_task = asyncio.create_task(_relay_local_events())
    # Keep a reference so the relay task can be looked up by task id later.
    self._background_tasks[task_id] = relay_task

    await relay_ready.wait()
    await self._redis.sadd(self._task_registry_name, task_id)
    logger.debug(f'Started to listen and relay events for task { task_id } ')
6475
6576 async def _remove_task_id (self , task_id : str ) -> bool :
6677 if task_id in self ._background_tasks :
@@ -70,21 +81,51 @@ async def _remove_task_id(self, task_id: str) -> bool:
7081 return await self ._redis .srem (self ._task_registry_name , task_id ) == 1
7182
async def _subscribe_remote_task_events(self, task_id: str) -> None:
    """Subscribe to a remote task's channel and ensure a listener is running.

    ``self._relay_remote_events`` is registered as the handler for the task's
    channel. A single shared background task (``self._pubsub_listener_task``)
    iterates the pub/sub connection; it is created here when it does not exist
    or has already finished.

    Args:
        task_id: Identifier of the remote task to receive events for.
    """
    channel_id = self._task_channel_name(task_id)
    await self._pubsub.subscribe(**{channel_id: self._relay_remote_events})

    # This is a global listener that handles incoming pubsub events.
    # A finished (completed, failed or cancelled) Task object is still truthy,
    # so checking truthiness alone would never replace a dead listener and new
    # subscriptions would silently receive nothing.
    # NOTE(review): redis-py's listen() appears to return once no
    # subscriptions remain, which is one way the task finishes - confirm.
    listener = self._pubsub_listener_task
    if listener is None or listener.done():
        logger.debug('Creating pubsub listener task.')
        self._pubsub_listener_task = asyncio.create_task(self._consume_pubsub_messages())

    logger.debug(f"Subscribed for remote events for task { task_id } ")
93+
async def _consume_pubsub_messages(self):
    """Iterate the pub/sub message stream until it ends, discarding messages.

    Nothing is done with the yielded items here.
    NOTE(review): presumably iterating ``listen()`` is what drives the
    per-channel handlers registered at subscribe time - confirm against the
    redis client documentation.
    """
    message_stream = self._pubsub.listen()
    async for _message in message_stream:
        continue
97+
async def _relay_remote_events(self, subscription_event) -> None:
    """Forward one pub/sub message into the proxy queue of its task.

    The task id is recovered from the channel name, the payload is parsed as
    an ``Event`` and the result is enqueued on ``self._proxy_queue[task_id]``.
    Messages that lack the expected fields, belong to a task without a proxy
    queue, or fail to parse are logged and dropped.

    Args:
        subscription_event: Mapping delivered by the pub/sub client. It is
            expected to carry ``'channel'`` and ``'data'`` entries, each
            either ``bytes`` or ``str``.
    """
    if 'channel' not in subscription_event or 'data' not in subscription_event:
        logger.warning(f"channel or data is absent in subscription event: { subscription_event } ")
        return

    raw_channel = subscription_event['channel']
    raw_data = subscription_event['data']
    # Accept both bytes and str: str has no .decode(), so decoding
    # unconditionally raises AttributeError when the client already returns
    # decoded strings.
    channel_id: str = (
        raw_channel.decode('utf-8') if isinstance(raw_channel, bytes) else str(raw_channel)
    )
    data_string: str = (
        raw_data.decode('utf-8') if isinstance(raw_data, bytes) else str(raw_data)
    )

    # Channel names are built as prefix + task_id, so strip exactly that
    # prefix. Splitting on '.' breaks when the task id contains a dot or the
    # prefix does not end with one.
    task_id = channel_id.removeprefix(self._relay_channel_name)
    if task_id not in self._proxy_queue:
        logger.warning(f"task_id { task_id } not found in proxy queue")
        return
    # No await between the membership check and this lookup.
    queue = self._proxy_queue[task_id]

    try:
        logger.debug(f"Received event for task_id { task_id } in QM { self } : { data_string } ")
        event = TypeAdapter(Event).validate_json(data_string)
    except Exception as e:
        # Deliberately broad: a bad payload is logged and dropped here
        # instead of being raised out of the message handler.
        logger.warning(f"Failed to parse event from subscription event: { subscription_event } : { e } ")
        return

    logger.debug(f"Enqueuing event for task_id { task_id } in QM { self } : { event } ")
    await queue.enqueue_event(event)
119+
120+
async def _unsubscribe_remote_task_events(self, task_id: str) -> None:
    """Unsubscribe from a remote task's channel and stop an idle listener.

    After unsubscribing, the shared listener task is cancelled and forgotten
    when the pub/sub connection reports no remaining subscriptions.

    This method must not acquire ``self._lock``: ``close()`` awaits it while
    already holding that lock, and a non-reentrant asyncio lock would then
    wait on itself forever. The check-and-cancel section below contains no
    ``await``, so no other coroutine can interleave with it.

    Args:
        task_id: Identifier of the remote task to stop receiving events for.
    """
    # Unsubscribe the channel for the given task_id.
    await self._pubsub.unsubscribe(self._task_channel_name(task_id))
    # Release the global listener if no channel is subscribed any more.
    if not self._pubsub.subscribed and self._pubsub_listener_task:
        self._pubsub_listener_task.cancel()
        self._pubsub_listener_task = None
88129
89130 async def add (self , task_id : str , queue : EventQueue ) -> None :
90131 """Add a new local event queue for the specified task.
@@ -96,11 +137,13 @@ async def add(self, task_id: str, queue: EventQueue) -> None:
96137 Raises:
97138 TaskQueueExists: If a queue for the task already exists.
98139 """
140+ logger .debug (f"add { task_id } " )
99141 async with self ._lock :
100142 if await self ._has_task_id (task_id ):
101143 raise TaskQueueExists ()
102144 self ._local_queue [task_id ] = queue
103145 await self ._register_task_id (task_id )
146+ logger .debug (f"Local queue is created for task { task_id } " )
104147
105148 async def get (self , task_id : str ) -> EventQueue | None :
106149 """Get the event queue associated with the given task ID.
@@ -115,17 +158,22 @@ async def get(self, task_id: str) -> EventQueue | None:
115158 Returns:
116159 EventQueue | None: The event queue if found, otherwise None.
117160 """
161+ logger .debug (f"get { task_id } " )
118162 async with self ._lock :
119163 # lookup locally
120164 if task_id in self ._local_queue :
165+ logger .debug (f"Got local queue for task_id { task_id } " )
121166 return self ._local_queue [task_id ]
122167 # lookup globally
123168 if await self ._has_task_id (task_id ):
124169 if task_id not in self ._proxy_queue :
170+ logger .debug (f"Creating proxy queue for { task_id } " )
125171 queue = EventQueue ()
126172 self ._proxy_queue [task_id ] = queue
127173 await self ._subscribe_remote_task_events (task_id )
174+ logger .debug (f"Got proxy queue for task_id { task_id } " )
128175 return self ._proxy_queue [task_id ]
176+ logger .warning (f"Attempted to get non-existing queue for task { task_id } " )
129177 return None
130178
131179 async def tap (self , task_id : str ) -> EventQueue | None :
@@ -137,8 +185,10 @@ async def tap(self, task_id: str) -> EventQueue | None:
137185 Returns:
138186 EventQueue | None: A new reference to the event queue if it exists, otherwise None.
139187 """
188+ logger .debug (f"tap { task_id } " )
140189 event_queue = await self .get (task_id )
141190 if event_queue :
191+ logger .debug (f'Tapping event queue for task: { task_id } ' )
142192 return event_queue .tap ()
143193 return None
144194
@@ -155,23 +205,27 @@ async def close(self, task_id: str) -> None:
155205 Raises:
156206 NoTaskQueue: If no queue exists for the given task ID.
157207 """
208+ logger .debug (f"close { task_id } " )
158209 async with self ._lock :
159210 if task_id in self ._local_queue :
160211 # close locally
161212 queue = self ._local_queue .pop (task_id )
162213 await queue .close ()
163214 # remove from global registry if a local queue is closed
164215 await self ._remove_task_id (task_id )
216+ logger .debug (f"Closing local queue for task { task_id } " )
165217 return
166218
167219 if task_id in self ._proxy_queue :
168220 # close proxy queue
169221 queue = self ._proxy_queue .pop (task_id )
170222 await queue .close ()
171223 # unsubscribe from remote, but don't remove from global registry
172- self ._unsubscribe_remote_task_events (task_id )
224+ await self ._unsubscribe_remote_task_events (task_id )
225+ logger .debug (f"Closing proxy queue for task { task_id } " )
173226 return
174227
228+ logger .warning (f"Attempted to close non-existing queue found for task { task_id } " )
175229 raise NoTaskQueue ()
176230
177231 async def create_or_tap (self , task_id : str ) -> EventQueue :
@@ -186,23 +240,28 @@ async def create_or_tap(self, task_id: str) -> EventQueue:
186240 Returns:
187241 EventQueue: An event queue associated with the given task ID.
188242 """
243+ logger .debug (f"create_or_tap { task_id } " )
189244 async with self ._lock :
190245 if await self ._has_task_id (task_id ):
191246 # if it's a local queue, tap directly
192247 if task_id in self ._local_queue :
248+ logger .debug (f"Tapping a local queue for task { task_id } " )
193249 return self ._local_queue [task_id ].tap ()
194250
195251 # if it's a proxy queue, tap the proxy
196252 if task_id in self ._proxy_queue :
253+ logger .debug (f"Tapping a proxy queue for task { task_id } " )
197254 return self ._proxy_queue [task_id ].tap ()
198255
199256 # if the proxy is not created, create the proxy and return
200257 queue = EventQueue ()
201258 self ._proxy_queue [task_id ] = queue
202259 await self ._subscribe_remote_task_events (task_id )
260+ logger .debug (f"Creating a proxy queue for task { task_id } " )
203261 return self ._proxy_queue [task_id ]
204262 # the task doesn't exist before, create a local queue
205263 queue = EventQueue ()
206264 self ._local_queue [task_id ] = queue
207265 await self ._register_task_id (task_id )
266+ logger .debug (f"Creating a local queue for task { task_id } " )
208267 return queue
0 commit comments