Commit d1a7261

fix bugs: fully address bugs in the Redis queue
1 parent 0911ced commit d1a7261

5 files changed: +30 −61 lines

examples/mem_scheduler/api_w_scheduler.py

Lines changed: 0 additions & 3 deletions
@@ -25,9 +25,6 @@ def my_test_handler(messages: list[ScheduleMessageItem]):
     print(f"My test handler received {len(messages)} messages:")
     for msg in messages:
         print(f" my_test_handler - {msg.item_id}: {msg.content}")
-    print(
-        f"{queue._redis_conn.xinfo_groups(queue.stream_key_prefix)} qsize: {queue.qsize()} messages:{messages}"
-    )
 
 
 # 2. Register the handler

src/memos/mem_scheduler/task_schedule_modules/dispatcher.py

Lines changed: 0 additions & 1 deletion
@@ -150,7 +150,6 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
             self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time())
 
         # acknowledge redis messages
-
         if self.use_redis_queue and self.memos_message_queue is not None:
             for msg in messages:
                 redis_message_id = msg.redis_message_id
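
For context on the unchanged lines around this hunk: after a wrapped handler finishes, the dispatcher acknowledges each consumed Redis message. A minimal sketch of that acknowledgment pattern, assuming the ack_message signature declared in redis_queue.py below; the user_id field on ScheduleMessageItem is an assumption, not confirmed by this diff:

import time

# Sketch only: `queue` stands for a SchedulerRedisQueue instance.
def acknowledge_processed(queue, messages):
    for msg in messages:
        redis_message_id = msg.redis_message_id
        if redis_message_id is None:
            continue  # message did not come from Redis; nothing to ack
        queue.ack_message(
            user_id=msg.user_id,          # assumed attribute
            mem_cube_id=msg.mem_cube_id,
            redis_message_id=redis_message_id,
        )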

src/memos/mem_scheduler/task_schedule_modules/local_queue.py

Lines changed: 1 addition & 4 deletions
@@ -79,8 +79,7 @@ def put(
 
     def get(
         self,
-        user_id: str,
-        mem_cube_id: str,
+        stream_key: str,
         block: bool = True,
         timeout: float | None = None,
         batch_size: int | None = None,
@@ -91,8 +90,6 @@ def get(
             )
             return []
 
-        stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id)
-
         # Return empty list if queue does not exist
         if stream_key not in self.queue_streams:
             logger.error(f"Stream {stream_key} does not exist when trying to get messages.")

src/memos/mem_scheduler/task_schedule_modules/redis_queue.py

Lines changed: 17 additions & 15 deletions
@@ -5,6 +5,7 @@
 the local memos_message_queue functionality in BaseScheduler.
 """
 
+import re
 import time
 
 from collections.abc import Callable
@@ -165,8 +166,7 @@ def ack_message(self, user_id, mem_cube_id, redis_message_id) -> None:
 
     def get(
         self,
-        user_id: str,
-        mem_cube_id: str,
+        stream_key: str,
         block: bool = True,
         timeout: float | None = None,
         batch_size: int | None = None,
@@ -175,8 +175,6 @@ def get(
             raise ConnectionError("Not connected to Redis. Redis connection not available.")
 
         try:
-            stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id)
-
             # Calculate timeout for Redis
             redis_timeout = None
             if block and timeout is not None:
@@ -295,17 +293,21 @@ def get_stream_keys(self) -> list[str]:
         if not self._redis_conn:
             return []
 
-        try:
-            # Use match parameter and decode byte strings to regular strings
-            stream_keys = [
-                key.decode("utf-8") if isinstance(key, bytes) else key
-                for key in self._redis_conn.scan_iter(match=f"{self.stream_key_prefix}:*")
-            ]
-            logger.debug(f"get stream_keys from redis: {stream_keys}")
-            return stream_keys
-        except Exception as e:
-            logger.error(f"Failed to list Redis stream keys: {e}")
-            return []
+        # First, get all keys that might match (using Redis pattern matching)
+        redis_pattern = f"{self.stream_key_prefix}:*"
+        raw_keys = [
+            key.decode("utf-8") if isinstance(key, bytes) else key
+            for key in self._redis_conn.scan_iter(match=redis_pattern)
+        ]
+
+        # Second, filter using Python regex to ensure exact prefix match
+        # Escape special regex characters in the prefix, then add :.*
+        escaped_prefix = re.escape(self.stream_key_prefix)
+        regex_pattern = f"^{escaped_prefix}:"
+        stream_keys = [key for key in raw_keys if re.match(regex_pattern, key)]
+
+        logger.debug(f"get stream_keys from redis: {stream_keys}")
+        return stream_keys
 
     def size(self) -> int:
         """

src/memos/mem_scheduler/task_schedule_modules/task_queue.py

Lines changed: 12 additions & 38 deletions
@@ -5,8 +5,6 @@
 the local memos_message_queue functionality in BaseScheduler.
 """
 
-from collections import defaultdict
-
 from memos.log import get_logger
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
 from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
@@ -58,9 +56,10 @@ def debug_mode_on(self):
 
     def get_stream_keys(self) -> list[str]:
         if isinstance(self.memos_message_queue, SchedulerRedisQueue):
-            return self.memos_message_queue.get_stream_keys()
+            stream_keys = self.memos_message_queue.get_stream_keys()
         else:
-            return list(self.memos_message_queue.queue_streams.keys())
+            stream_keys = list(self.memos_message_queue.queue_streams.keys())
+        return stream_keys
 
     def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]):
         """Submit messages to the message queue (either local queue or Redis)."""
@@ -98,46 +97,21 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
             )
 
     def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
-        # Discover all active streams via queue API
-        streams: list[tuple[str, str]] = []
-
         stream_keys = self.get_stream_keys()
-        for stream_key in stream_keys:
-            try:
-                parts = stream_key.split(":")
-                if len(parts) >= 3:
-                    user_id = parts[-2]
-                    mem_cube_id = parts[-1]
-                    streams.append((user_id, mem_cube_id))
-            except Exception as e:
-                logger.debug(f"Failed to parse stream key {stream_key}: {e}")
-
-        if not streams:
+
+        if len(stream_keys) == 0:
             return []
 
         messages: list[ScheduleMessageItem] = []
 
-        # Group by user: {user_id: [mem_cube_id, ...]}
-
-        streams_by_user: dict[str, list[str]] = defaultdict(list)
-        for user_id, mem_cube_id in streams:
-            streams_by_user[user_id].append(mem_cube_id)
-
-        # For each user, fairly consume up to batch_size across their streams
-        for user_id, mem_cube_ids in streams_by_user.items():
-            if not mem_cube_ids:
-                continue
-
-            # First pass: give each stream an equal share for this user
-            for mem_cube_id in mem_cube_ids:
-                fetched = self.memos_message_queue.get(
-                    user_id=user_id,
-                    mem_cube_id=mem_cube_id,
-                    block=False,
-                    batch_size=batch_size,
-                )
+        for stream_key in stream_keys:
+            fetched = self.memos_message_queue.get(
+                stream_key=stream_key,
+                block=False,
+                batch_size=batch_size,
+            )
 
-                messages.extend(fetched)
+            messages.extend(fetched)
 
         logger.info(
             f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}"
