Skip to content

Commit 1fe3ccc

Browse files
author
long.qul
committed
feat(server): implement task queue with TTL and node affiliation
- Add TTL (Time To Live) for task IDs in the global registry - Implement node affiliation for task IDs to prevent message broadcasting to outdated task queues - Update task registration and message relaying logic to support the new features - Add test cases for task ID expiration and node affiliation
1 parent ad5ca3c commit 1fe3ccc

File tree

4 files changed

+149

-80

lines changed

.github/actions/spelling/allow.txt

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -65,17 +65,16 @@ pypistats
6565
pyversions
6666
respx
6767
resub
68+
sadd
69+
sismember
6870
socio
71+
srem
6972
sse
7073
tagwords
7174
taskupdate
7275
testuuid
7376
typeerror
7477
vulnz
75-
sadd
76-
sismember
77-
srem
7878
zadd
79-
zscore
8079
zrem
81-
zremrangebyscore
80+
zremrangebyscorezscore

src/a2a/server/events/event_queue.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -153,10 +153,6 @@ async def close(self) -> None:
153153
await child.close()
154154
# Otherwise, join the queue
155155
else:
156-
# drain the queue or self.queue.join() would wait forever. This makes this piece of code equivalent to self.queue.shutdown() in python 3.13+
157-
while not self.queue.empty():
158-
await self.queue.get()
159-
self.queue.task_done()
160156
tasks = [asyncio.create_task(self.queue.join())]
161157
for child in self._children:
162158
tasks.append(asyncio.create_task(child.close()))
Lines changed: 95 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import logging
3-
import random
4-
import time
3+
import uuid
54

65
from asyncio import Task
76
from typing import Any
@@ -22,8 +21,6 @@
2221
logger = logging.getLogger(__name__)
2322

2423

25-
CLEAN_EXPIRED_PROBABILITY = 0.5
26-
2724
class RedisQueueManager(QueueManager):
2825
"""This implements the `QueueManager` interface using Redis for event.
2926
@@ -33,6 +30,8 @@ class RedisQueueManager(QueueManager):
3330
redis_client(Redis): asyncio redis connection.
3431
relay_channel_key_prefix(str): prefix for pubsub channel key generation.
3532
task_registry_key(str): key for set data where stores active `task_id`s.
33+
task_id_ttl_in_second: TTL for task id in global registry
34+
node_id: A unique id to be associated with task id in global registry. If node id is not matched, events won't be populated to queues in other `RedisQueueManager`s.
3635
"""
3736

3837
def __init__(
@@ -41,6 +40,7 @@ def __init__(
4140
relay_channel_key_prefix: str = 'a2a.event.relay.',
4241
task_registry_key: str = 'a2a.event.registry',
4342
task_id_ttl_in_second: int = 60 * 60 * 24,
43+
node_id: str = str(uuid.uuid4()),
4444
):
4545
self._redis = redis_client
4646
self._local_queue: dict[str, EventQueue] = {}
@@ -52,31 +52,60 @@ def __init__(
5252
self._task_registry_name = task_registry_key
5353
self._pubsub_listener_task: Task | None = None
5454
self._task_id_ttl_in_second = task_id_ttl_in_second
55+
self._node_id = node_id
5556

5657
def _task_channel_name(self, task_id: str) -> str:
5758
return self._relay_channel_name + task_id
5859

5960
async def _has_task_id(self, task_id: str) -> bool:
60-
ret = await self._redis.zscore(self._task_registry_name, task_id)
61+
ret = await self._redis.hget(self._task_registry_name, task_id)
6162
return ret is not None
6263

6364
async def _register_task_id(self, task_id: str) -> None:
64-
assert await self._redis.zadd(self._task_registry_name, {task_id: time.time()}, nx=True), 'task_id should not exist in global registry: ' + task_id
65+
await self._redis.hsetex(
66+
name=self._task_registry_name,
67+
key=task_id,
68+
value=self._node_id,
69+
ex=self._task_id_ttl_in_second,
70+
)
71+
logger.debug(
72+
f'Registered task_id {task_id} to node {self._node_id} in registry.'
73+
)
6574
task_started_event = asyncio.Event()
75+
6676
async def _wrapped_listen_and_relay() -> None:
6777
task_started_event.set()
6878
c = EventConsumer(self._local_queue[task_id].tap())
6979
async for event in c.consume_all():
70-
logger.debug(f'Publishing event for task {task_id} in QM {self}: {event}')
71-
await self._redis.publish(
72-
self._task_channel_name(task_id),
73-
event.model_dump_json(exclude_none=True),
80+
logger.debug(
81+
f'Publishing event for task {task_id} in QM {self}: {event}'
7482
)
75-
# update TTL for task_id
76-
await self._update_task_id_ttl(task_id)
77-
# clean expired task_ids with certain possibility
78-
if random.random() < CLEAN_EXPIRED_PROBABILITY:
79-
await self._clean_expired_task_ids()
83+
expected_node_id = await self._redis.hget(
84+
self._task_registry_name, task_id
85+
)
86+
expected_node_id = (
87+
expected_node_id.decode('utf-8')
88+
if hasattr(expected_node_id, 'decode')
89+
else expected_node_id
90+
)
91+
if expected_node_id == self._node_id:
92+
# publish message
93+
await self._redis.publish(
94+
self._task_channel_name(task_id),
95+
event.model_dump_json(exclude_none=True),
96+
)
97+
# update TTL for task_id
98+
await self._redis.hsetex(
99+
name=self._task_registry_name,
100+
key=task_id,
101+
value=self._node_id,
102+
ex=self._task_id_ttl_in_second,
103+
)
104+
else:
105+
logger.error(
106+
f'Task {task_id} is not registered on this node. Expected node id: {expected_node_id}'
107+
)
108+
break
80109

81110
self._background_tasks[task_id] = asyncio.create_task(
82111
_wrapped_listen_and_relay()
@@ -89,55 +118,56 @@ async def _remove_task_id(self, task_id: str) -> bool:
89118
self._background_tasks[task_id].cancel(
90119
'task_id is closed: ' + task_id
91120
)
92-
return await self._redis.zrem(self._task_registry_name, task_id) == 1
93-
94-
async def _update_task_id_ttl(self, task_id: str) -> bool:
95-
ret = await self._redis.zadd(
96-
self._task_registry_name,
97-
{task_id: time.time()},
98-
xx=True
99-
)
100-
return ret is not None
101-
102-
async def _clean_expired_task_ids(self) -> None:
103-
count = await self._redis.zremrangebyscore(self._task_registry_name, 0, time.time() - self._task_id_ttl_in_second)
104-
logger.debug(f'Removed {count} expired task ids')
121+
return await self._redis.hdel(self._task_registry_name, task_id) == 1
105122

106123
async def _subscribe_remote_task_events(self, task_id: str) -> None:
107124
channel_id = self._task_channel_name(task_id)
108125
await self._pubsub.subscribe(**{channel_id: self._relay_remote_events})
109-
110126
# this is a global listener to handle incoming pubsub events
111127
if not self._pubsub_listener_task:
112128
logger.debug('Creating pubsub listener task.')
113-
self._pubsub_listener_task = asyncio.create_task(self._consume_pubsub_messages())
114-
115-
logger.debug(f"Subscribed for remote events for task {task_id}")
129+
self._pubsub_listener_task = asyncio.create_task(
130+
self._consume_pubsub_messages()
131+
)
132+
logger.debug(f'Subscribed for remote events for task {task_id}')
116133

117134
async def _consume_pubsub_messages(self) -> None:
118135
async for _ in self._pubsub.listen():
119136
pass
120137

121-
async def _relay_remote_events(self, subscription_event: dict[str, Any]) -> None:
122-
if 'channel' not in subscription_event or 'data' not in subscription_event:
123-
logger.warning(f"channel or data is absent in subscription event: {subscription_event}")
138+
async def _relay_remote_events(
139+
self, subscription_event: dict[str, Any]
140+
) -> None:
141+
if (
142+
'channel' not in subscription_event
143+
or 'data' not in subscription_event
144+
):
145+
logger.warning(
146+
f'channel or data is absent in subscription event: {subscription_event}'
147+
)
124148
return
125149

126150
channel_id: str = subscription_event['channel'].decode('utf-8')
127151
data_string: str = subscription_event['data'].decode('utf-8')
128152
task_id = channel_id.split('.')[-1]
129153
if task_id not in self._proxy_queue:
130-
logger.warning(f"task_id {task_id} not found in proxy queue")
154+
logger.warning(f'task_id {task_id} not found in proxy queue')
131155
return
132156

133157
try:
134-
logger.debug(f"Received event for task_id {task_id} in QM {self}: {data_string}")
158+
logger.debug(
159+
f'Received event for task_id {task_id} in QM {self}: {data_string}'
160+
)
135161
event: Event = TypeAdapter(Event).validate_json(data_string)
136162
except Exception as e:
137-
logger.warning(f"Failed to parse event from subscription event: {subscription_event}: {e}")
163+
logger.warning(
164+
f'Failed to parse event from subscription event: {subscription_event}: {e}'
165+
)
138166
return
139167

140-
logger.debug(f"Enqueuing event for task_id {task_id} in QM {self}: {event}")
168+
logger.debug(
169+
f'Enqueuing event for task_id {task_id} in QM {self}: {event}'
170+
)
141171
await self._proxy_queue[task_id].enqueue_event(event)
142172

143173
async def _unsubscribe_remote_task_events(self, task_id: str) -> None:
@@ -148,7 +178,6 @@ async def _unsubscribe_remote_task_events(self, task_id: str) -> None:
148178
self._pubsub_listener_task.cancel()
149179
self._pubsub_listener_task = None
150180

151-
152181
async def add(self, task_id: str, queue: EventQueue) -> None:
153182
"""Add a new local event queue for the specified task.
154183
@@ -159,13 +188,13 @@ async def add(self, task_id: str, queue: EventQueue) -> None:
159188
Raises:
160189
TaskQueueExists: If a queue for the task already exists.
161190
"""
162-
logger.debug(f"add {task_id}")
191+
logger.debug(f'add {task_id}')
163192
async with self._lock:
164193
if await self._has_task_id(task_id):
165194
raise TaskQueueExists()
166195
self._local_queue[task_id] = queue
167196
await self._register_task_id(task_id)
168-
logger.debug(f"Local queue is created for task {task_id}")
197+
logger.debug(f'Local queue is created for task {task_id}')
169198

170199
async def get(self, task_id: str) -> EventQueue | None:
171200
"""Get the event queue associated with the given task ID.
@@ -180,22 +209,24 @@ async def get(self, task_id: str) -> EventQueue | None:
180209
Returns:
181210
EventQueue | None: The event queue if found, otherwise None.
182211
"""
183-
logger.debug(f"get {task_id}")
212+
logger.debug(f'get {task_id}')
184213
async with self._lock:
185214
# lookup locally
186215
if task_id in self._local_queue:
187-
logger.debug(f"Got local queue for task_id {task_id}")
216+
logger.debug(f'Got local queue for task_id {task_id}')
188217
return self._local_queue[task_id]
189218
# lookup globally
190219
if await self._has_task_id(task_id):
191220
if task_id not in self._proxy_queue:
192-
logger.debug(f"Creating proxy queue for {task_id}")
221+
logger.debug(f'Creating proxy queue for {task_id}')
193222
queue = EventQueue()
194223
self._proxy_queue[task_id] = queue
195224
await self._subscribe_remote_task_events(task_id)
196-
logger.debug(f"Got proxy queue for task_id {task_id}")
225+
logger.debug(f'Got proxy queue for task_id {task_id}')
197226
return self._proxy_queue[task_id]
198-
logger.warning(f"Attempted to get non-existing queue for task {task_id}")
227+
logger.warning(
228+
f'Attempted to get non-existing queue for task {task_id}'
229+
)
199230
return None
200231

201232
async def tap(self, task_id: str) -> EventQueue | None:
@@ -207,7 +238,7 @@ async def tap(self, task_id: str) -> EventQueue | None:
207238
Returns:
208239
EventQueue | None: A new reference to the event queue if it exists, otherwise None.
209240
"""
210-
logger.debug(f"tap {task_id}")
241+
logger.debug(f'tap {task_id}')
211242
event_queue = await self.get(task_id)
212243
if event_queue:
213244
logger.debug(f'Tapping event queue for task: {task_id}')
@@ -227,15 +258,15 @@ async def close(self, task_id: str) -> None:
227258
Raises:
228259
NoTaskQueue: If no queue exists for the given task ID.
229260
"""
230-
logger.debug(f"close {task_id}")
261+
logger.debug(f'close {task_id}')
231262
async with self._lock:
232263
if task_id in self._local_queue:
233264
# remove from global registry if a local queue is closed
234265
await self._remove_task_id(task_id)
235266
# close locally
236267
queue = self._local_queue.pop(task_id)
237268
await queue.close()
238-
logger.debug(f"Closing local queue for task {task_id}")
269+
logger.debug(f'Closing local queue for task {task_id}')
239270
return
240271

241272
if task_id in self._proxy_queue:
@@ -244,10 +275,12 @@ async def close(self, task_id: str) -> None:
244275
await queue.close()
245276
# unsubscribe from remote, but don't remove from global registry
246277
await self._unsubscribe_remote_task_events(task_id)
247-
logger.debug(f"Closing proxy queue for task {task_id}")
278+
logger.debug(f'Closing proxy queue for task {task_id}')
248279
return
249280

250-
logger.warning(f"Attempted to close non-existing queue found for task {task_id}")
281+
logger.warning(
282+
f'Attempted to close non-existing queue found for task {task_id}'
283+
)
251284
raise NoTaskQueue()
252285

253286
async def create_or_tap(self, task_id: str) -> EventQueue:
@@ -262,28 +295,25 @@ async def create_or_tap(self, task_id: str) -> EventQueue:
262295
Returns:
263296
EventQueue: An event queue associated with the given task ID.
264297
"""
265-
logger.debug(f"create_or_tap {task_id}")
298+
logger.debug(f'create_or_tap {task_id}')
266299
async with self._lock:
267300
if await self._has_task_id(task_id):
268301
# if it's a local queue, tap directly
269302
if task_id in self._local_queue:
270-
logger.debug(f"Tapping a local queue for task {task_id}")
303+
logger.debug(f'Tapping a local queue for task {task_id}')
271304
return self._local_queue[task_id].tap()
272305

273306
# if it's a proxy queue, tap the proxy
274-
if task_id in self._proxy_queue:
275-
logger.debug(f"Tapping a proxy queue for task {task_id}")
276-
return self._proxy_queue[task_id].tap()
277-
278-
# if the proxy is not created, create the proxy and return
279-
queue = EventQueue()
280-
self._proxy_queue[task_id] = queue
281-
await self._subscribe_remote_task_events(task_id)
282-
logger.debug(f"Creating a proxy queue for task {task_id}")
283-
return self._proxy_queue[task_id]
307+
if task_id not in self._proxy_queue:
308+
# if the proxy is not created, create the proxy
309+
queue = EventQueue()
310+
self._proxy_queue[task_id] = queue
311+
await self._subscribe_remote_task_events(task_id)
312+
logger.debug(f'Tapping a proxy queue for task {task_id}')
313+
return self._proxy_queue[task_id].tap()
284314
# the task doesn't exist before, create a local queue
285315
queue = EventQueue()
286316
self._local_queue[task_id] = queue
287317
await self._register_task_id(task_id)
288-
logger.debug(f"Creating a local queue for task {task_id}")
318+
logger.debug(f'Creating a local queue for task {task_id}')
289319
return queue

0 commit comments

Comments
 (0)